Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
Merge branch 'user/sm103_trtllmgen' into feat/b300_cu13
Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
This commit is contained in:
commit 5e7aa76bb4
@@ -3,18 +3,17 @@ from typing import Optional, Tuple
import torch
from torch import nn

from tensorrt_llm._torch.modules.qk_norm_attention import QKNormRoPEAttention
from tensorrt_llm.functional import PositionEmbeddingType

from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import (PositionalEmbeddingParams,
                                           PredefinedAttentionMask, RopeParams)
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import TensorParallelMode
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from ..speculative import SpecMetadata
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
@@ -50,12 +49,11 @@ def check_is_sliding(config: Exaone4Config, layer_idx: int) -> bool:
    return False


class Exaone4Attention(Attention):
class Exaone4Attention(QKNormRoPEAttention):

    def __init__(self,
                 model_config: ModelConfig[Exaone4Config],
                 layer_idx: Optional[int] = None,
                 aux_stream: Optional[torch.cuda.Stream] = None,
                 fuse_qk_norm_rope: bool = False):
        config = model_config.pretrained_config

@@ -73,10 +71,10 @@ class Exaone4Attention(Attention):
            rope=RopeParams.from_config(config),
        )

        self.fuse_qk_norm_rope = (self.is_sliding and fuse_qk_norm_rope)
        fuse_qk_norm_rope = (self.is_sliding and fuse_qk_norm_rope)

        # TODO: Fusing qk norm with rope has an issue that slightly hurts accuracy.
        assert self.fuse_qk_norm_rope is False, "Fusing qk norm and rope is having issue now"
        assert fuse_qk_norm_rope is False, "Fusing qk norm and rope is having issue now"

        super().__init__(
            hidden_size=config.hidden_size,
@@ -85,65 +83,13 @@ class Exaone4Attention(Attention):
            max_position_embeddings=config.max_position_embeddings,
            bias=False,
            pos_embd_params=pos_embd_params,
            rope_fusion=False,
            fuse_qk_norm_rope=fuse_qk_norm_rope,
            skip_rope=self.sliding_window and not self.is_sliding,
            layer_idx=layer_idx,
            dtype=config.torch_dtype,
            config=model_config,
        )

        self.q_norm = RMSNorm(hidden_size=self.head_dim,
                              eps=config.rms_norm_eps,
                              dtype=config.torch_dtype)
        self.k_norm = RMSNorm(hidden_size=self.head_dim,
                              eps=config.rms_norm_eps,
                              dtype=config.torch_dtype)

        self.aux_stream = aux_stream
        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]

    def apply_qk_norm(self, q, k):

        def q_l2norm():
            return self.q_norm(q.reshape(-1, self.head_dim)).reshape(
                -1, self.q_size)

        def k_l2norm():
            return self.k_norm(k.reshape(-1, self.head_dim)).reshape(
                -1, self.kv_size)

        q, k = maybe_execute_in_parallel(
            q_l2norm,
            k_l2norm,
            self.ln_events[0],
            self.ln_events[1],
            self.aux_stream,
        )

        return q, k

    def apply_qk_norm_rope(self, qkv, position_ids):
        torch.ops.trtllm.fused_qk_norm_rope(
            qkv, self.num_heads, self.num_key_value_heads,
            self.num_key_value_heads, self.head_dim,
            self.q_norm.variance_epsilon, self.q_norm.weight,
            self.k_norm.weight, self.pos_embd_params.rope.theta,
            self.pos_embd_params.is_neox, position_ids.view(-1))
        return qkv, None, None

    def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
                   v: Optional[torch.Tensor], position_ids: torch.Tensor):
        if self.fuse_qk_norm_rope:
            assert k is None and v is None, "The input should be a concatenated qkv tensor to apply_qk_norm_rope"
            qkv = q
            return self.apply_qk_norm_rope(qkv, position_ids)

        q, k, v = self.split_qkv(q, k, v)
        q, k = self.apply_qk_norm(q, k)
        if self.sliding_window is None or self.is_sliding:
            return super().apply_rope(q, k, v, position_ids)
        else:
            return q, k, v

    def forward(
        self,
        position_ids: Optional[torch.LongTensor],
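The Exaone4 refactor above folds the old sliding-window branch of apply_rope into the skip_rope flag passed to QKNormRoPEAttention: RoPE is applied on sliding-window layers (and whenever no sliding window is configured at all), while global layers skip it. A minimal sketch of that mapping, using a hypothetical helper name (should_skip_rope) that is not part of the diff:

# Illustrative only: the old branch `if self.sliding_window is None or self.is_sliding`
# applied RoPE and otherwise returned q, k, v untouched; the new code passes the
# inverted condition as skip_rope to the base class.
def should_skip_rope(sliding_window, is_sliding: bool) -> bool:
    return bool(sliding_window) and not is_sliding

assert should_skip_rope(None, False) is False   # no sliding-window config: RoPE applied
assert should_skip_rope(4096, True) is False    # sliding-window layer: RoPE applied
assert should_skip_rope(4096, False) is True    # global layer: RoPE skipped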
@@ -175,7 +121,6 @@ class Exaone4DecoderLayer(DecoderLayer):
        self,
        model_config: ModelConfig[Exaone4Config],
        layer_idx: int,
        aux_stream: Optional[torch.cuda.Stream] = None,
    ):
        super().__init__()
        config = model_config.pretrained_config
@@ -186,7 +131,6 @@ class Exaone4DecoderLayer(DecoderLayer):
        self.self_attn = Exaone4Attention(
            model_config,
            layer_idx=layer_idx,
            aux_stream=aux_stream,
        )

        self.mlp = GatedMLP(
@@ -244,7 +188,6 @@ class Exaone4Model(DecoderModel):
        super().__init__(model_config)
        config = self.model_config.pretrained_config
        self.num_hidden_layers = config.num_hidden_layers
        self.aux_stream = torch.cuda.Stream()
        self.embed_tokens = Embedding(
            config.vocab_size,
            config.hidden_size,
@@ -258,7 +201,6 @@ class Exaone4Model(DecoderModel):
            Exaone4DecoderLayer(
                model_config,
                layer_idx,
                self.aux_stream,
            ) for layer_idx in range(self.num_hidden_layers)
        ])

@@ -7,6 +7,7 @@ from transformers import Gemma3TextConfig

from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
    BaseWeightMapper
from tensorrt_llm._torch.modules.qk_norm_attention import QKNormRoPEAttention
from tensorrt_llm.functional import PositionEmbeddingType, RotaryScalingType
from tensorrt_llm.mapping import Mapping

@@ -15,12 +16,10 @@ from ..attention_backend.interface import (AttentionMask, CustomAttentionMask,
                                           PositionalEmbeddingParams,
                                           PredefinedAttentionMask, RopeParams)
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import TensorParallelMode
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
                             register_auto_model)
@@ -52,7 +51,7 @@ class Gemma3TextScaledWordEmbedding(Embedding):
        return super().forward(input_ids) * self.embed_scale


class Gemma3Attention(Attention):
class Gemma3Attention(QKNormRoPEAttention):

    def __init__(
        self,
@@ -82,20 +81,13 @@ class Gemma3Attention(Attention):
            max_position_embeddings=config.max_position_embeddings,
            bias=False,
            pos_embd_params=pos_embd_params,
            fuse_qk_norm_rope=False,
            layer_idx=layer_idx,
            dtype=config.torch_dtype,
            dense_bias=False,
            config=model_config,
            q_scaling=q_scaling,
        )
        self.q_norm = RMSNorm(hidden_size=config.head_dim,
                              eps=config.rms_norm_eps,
                              dtype=config.torch_dtype)
        self.k_norm = RMSNorm(hidden_size=config.head_dim,
                              eps=config.rms_norm_eps,
                              dtype=config.torch_dtype)
        self.aux_stream = torch.cuda.Stream()
        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]

    @torch.inference_mode()
    def forward(
@@ -121,33 +113,6 @@ class Gemma3Attention(Attention):
            attention_mask_data=attention_mask_data,
            **kwargs)

    def apply_qk_norm(self, q, k):

        def q_l2norm():
            return self.q_norm(q.reshape(-1, self.head_dim)).reshape(
                -1, self.q_size)

        def k_l2norm():
            return self.k_norm(k.reshape(-1, self.head_dim)).reshape(
                -1, self.kv_size)

        q, k = maybe_execute_in_parallel(
            q_l2norm,
            k_l2norm,
            self.ln_events[0],
            self.ln_events[1],
            self.aux_stream,
        )

        return q, k

    def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
                   v: Optional[torch.Tensor], position_ids: torch.Tensor):
        # Gemma3 applies QK norm before RoPE.
        q, k, v = self.split_qkv(q, k, v)
        q, k = self.apply_qk_norm(q, k)
        return super().apply_rope(q, k, v, position_ids)


# This function is written to be compatible with TRTLLM's GatedMLP class.
def pytorch_gelu_tanh(gate_x: torch.Tensor) -> torch.Tensor:
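Across Exaone4, Gemma3 and Qwen3, the removed apply_qk_norm bodies all rely on the same reshape trick: the flattened q/k projections are viewed as (tokens * heads, head_dim), RMS-normalized per head, then flattened back, with the q and k norms optionally overlapped on an auxiliary CUDA stream via maybe_execute_in_parallel. A minimal plain-PyTorch sketch of that reshape (the rms_norm helper and the sizes below are illustrative assumptions, not TRT-LLM code):

import torch

num_heads, num_kv_heads, head_dim, tokens = 8, 2, 64, 4
q = torch.randn(tokens, num_heads * head_dim)        # q_size  = num_heads * head_dim
k = torch.randn(tokens, num_kv_heads * head_dim)      # kv_size = num_kv_heads * head_dim

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Weightless stand-in for the RMSNorm module used in the diff.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

# Normalize every head independently, then restore the flattened layout,
# mirroring q.reshape(-1, head_dim)...reshape(-1, q_size) in the diff above.
q_normed = rms_norm(q.reshape(-1, head_dim)).reshape(-1, num_heads * head_dim)
k_normed = rms_norm(k.reshape(-1, head_dim)).reshape(-1, num_kv_heads * head_dim)
assert q_normed.shape == q.shape and k_normed.shape == k.shape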
@@ -1,133 +1,26 @@
import math
from typing import Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig, Qwen3Config
from transformers import Qwen3Config

from tensorrt_llm.functional import PositionEmbeddingType

from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import TensorParallelMode
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.qk_norm_attention import QKNormRoPEAttention
from ..modules.rms_norm import RMSNorm
from ..speculative import SpecMetadata
from .modeling_speculative import SpecDecOneEngineForCausalLM
from .modeling_utils import DecoderModel, register_auto_model


# Move out from this class
def compute_yarn_parameters(
        config: PretrainedConfig, ) -> tuple[float, float, float, float]:
    """
    Refer to https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L197C1-L288C1
    Computes the inverse frequencies with NTK scaling. Please refer to the
    [original paper](https://huggingface.co/papers/2309.00071)
    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
    Returns:
        factor: float, the scaling factor for the RoPE embeddings
        low: float, the lower bound of the dimension range
        high: float, the upper bound of the dimension range
        attention_factor: float, the post-processing scaling factor applied to the computed cos/sin
    """

    # The config does not contain rope_scaling, which means the model is not using yarn
    rope_scaling = getattr(config, "rope_scaling", None)
    if rope_scaling is None:
        return 1.0, 0, 0, 1.0

    base = config.rope_theta
    partial_rotary_factor = config.partial_rotary_factor if hasattr(
        config, "partial_rotary_factor") else 1.0
    head_dim = getattr(config, "head_dim",
                       config.hidden_size // config.num_attention_heads)
    dim = int(head_dim * partial_rotary_factor)
    factor = getattr(rope_scaling, "factor", 1.0)
    attention_factor = rope_scaling.get("attention_factor")
    mscale = rope_scaling.get("mscale")
    mscale_all_dim = rope_scaling.get("mscale_all_dim")

    if "original_max_position_embeddings" in rope_scaling:
        original_max_position_embeddings = rope_scaling[
            "original_max_position_embeddings"]
        factor = config.max_position_embeddings / original_max_position_embeddings
    else:
        original_max_position_embeddings = config.max_position_embeddings

    def get_mscale(scale, mscale=1):
        if scale <= 1:
            return 1.0
        return 0.1 * mscale * math.log(scale) + 1.0

    # Sets the attention factor as suggested in the paper
    if attention_factor is None:
        if mscale and mscale_all_dim:
            attention_factor = float(
                get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
        else:
            attention_factor = get_mscale(factor)

    # Optional config options
    # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
    beta_fast = rope_scaling.get("beta_fast") or 32
    beta_slow = rope_scaling.get("beta_slow") or 1

    # Compute the inverse frequencies
    def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
        """Inverse dimension formula to find the dimension based on the number of rotations"""
        return (dim *
                math.log(max_position_embeddings /
                         (num_rotations * 2 * math.pi))) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base,
                              max_position_embeddings, truncate):
        """Find dimension range bounds based on rotations"""
        low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
        high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
        if truncate:
            low = math.floor(low)
            high = math.ceil(high)
        return max(low, 0), min(high, dim - 1)

    truncate = rope_scaling.get("truncate", True)
    low, high = find_correction_range(beta_fast, beta_slow, dim, base,
                                      original_max_position_embeddings,
                                      truncate)

    # These parts are implemented in the fusedQKNormRopeKernel.cu
    # # def linear_ramp_factor(min, max, dim):
    # # if min == max:
    # # max += 0.001  # Prevent singularity

    # # linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
    # # ramp_func = torch.clamp(linear_func, 0, 1)
    # # return ramp_func

    # # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
    # # to expand the possible context length. In other words, interpolation = apply scaling factor.
    # # pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim)
    # # inv_freq_extrapolation = 1.0 / pos_freqs
    # # inv_freq_interpolation = 1.0 / (factor * pos_freqs)

    # # # Get n-dimensional rotational scaling corrected for extrapolation
    # # inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float)
    # # inv_freq = (
    # #     inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
    # #     + inv_freq_extrapolation * inv_freq_extrapolation_factor
    # # )
    # # return inv_freq, attention_factor
    return factor, low, high, attention_factor


class Qwen3Attention(Attention):
class Qwen3Attention(QKNormRoPEAttention):

    def __init__(
        self,
@@ -136,7 +29,6 @@ class Qwen3Attention(Attention):
        fuse_qk_norm_rope: bool = True,
    ):
        config = model_config.pretrained_config
        self.pretrained_config = config

        if getattr(config, "rope_scaling", None) is not None:
            if "type" in config.rope_scaling:
@@ -156,8 +48,6 @@ class Qwen3Attention(Attention):
            rope=RopeParams.from_config(config),
        )

        self.fuse_qk_norm_rope = fuse_qk_norm_rope

        super().__init__(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
@@ -165,69 +55,13 @@ class Qwen3Attention(Attention):
            max_position_embeddings=config.max_position_embeddings,
            bias=config.attention_bias,
            pos_embd_params=pos_embd_params,
            rope_fusion=not self.
            fuse_qk_norm_rope,  # If fuse_qk_norm_rope is true, do not apply fused RoPE in attention OP, and self.rotary_emb will be skipped in the overridden apply_rope.
            fuse_qk_norm_rope=fuse_qk_norm_rope,
            layer_idx=layer_idx,
            dtype=config.torch_dtype,
            dense_bias=config.attention_bias,
            config=model_config,
        )

        self.q_norm = RMSNorm(hidden_size=self.head_dim,
                              eps=1e-6,
                              dtype=config.torch_dtype,
                              has_weights=True)
        self.k_norm = RMSNorm(hidden_size=self.head_dim,
                              eps=1e-6,
                              dtype=config.torch_dtype,
                              has_weights=True)
        self.aux_stream = torch.cuda.Stream()
        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]

    def apply_qk_norm(self, q, k):

        def q_l2norm():
            return self.q_norm(q.reshape(-1, self.head_dim)).reshape(
                -1, self.q_size)

        def k_l2norm():
            return self.k_norm(k.reshape(-1, self.head_dim)).reshape(
                -1, self.kv_size)

        q, k = maybe_execute_in_parallel(
            q_l2norm,
            k_l2norm,
            self.ln_events[0],
            self.ln_events[1],
            self.aux_stream,
        )

        return q, k

    def apply_qk_norm_rope(self, qkv, position_ids):
        factor, low, high, attention_factor = compute_yarn_parameters(
            self.pretrained_config)
        torch.ops.trtllm.fused_qk_norm_rope(
            qkv, self.num_heads, self.num_key_value_heads,
            self.num_key_value_heads, self.head_dim,
            self.q_norm.variance_epsilon, self.q_norm.weight,
            self.k_norm.weight,
            self.pos_embd_params.rope.theta, self.pos_embd_params.is_neox,
            position_ids.view(-1), factor, low, high, attention_factor)
        return qkv, None, None

    def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
                   v: Optional[torch.Tensor], position_ids: torch.Tensor):
        # Qwen3 applies QK norm before RoPE.
        if not self.fuse_qk_norm_rope:
            q, k, v = self.split_qkv(q, k, v)
            q, k = self.apply_qk_norm(q, k)
            return super().apply_rope(q, k, v, position_ids)

        assert k is None and v is None, "The input should be a concatenated qkv tensor to apply_qk_norm_rope"
        qkv = q
        return self.apply_qk_norm_rope(qkv, position_ids)

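For reference, the yarn path removed above (and re-added in the new module below) is driven entirely by the pretrained config's rope_scaling entry; the keys read by compute_yarn_parameters look roughly like the following. The concrete numbers are made-up placeholders, not taken from any particular checkpoint:

# Assumed example values; only the keys mirror what compute_yarn_parameters reads.
rope_scaling = {
    "rope_type": "yarn",                        # older configs may use "type" instead
    "factor": 4.0,                              # context-extension scaling factor
    "original_max_position_embeddings": 32768,  # when present, factor is recomputed from it
    "beta_fast": 32,                            # defaults to 32 when omitted
    "beta_slow": 1,                             # defaults to 1 when omitted
    # "attention_factor", "mscale", "mscale_all_dim" are optional and fall back
    # to the mscale formula suggested in the YaRN paper.
}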
class Qwen3DecoderLayer(DecoderLayer):

@@ -235,7 +69,7 @@ class Qwen3DecoderLayer(DecoderLayer):
        self,
        model_config: ModelConfig[Qwen3Config],
        layer_idx: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
    ):
        super().__init__()
        self.layer_idx = layer_idx
        config = model_config.pretrained_config

tensorrt_llm/_torch/modules/qk_norm_attention.py (new file, 243 lines)
@@ -0,0 +1,243 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional

import torch
from transformers import PretrainedConfig

from ..attention_backend.interface import PositionalEmbeddingParams
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm


# Move out from this class
def compute_yarn_parameters(
        config: PretrainedConfig, ) -> tuple[float, float, float, float]:
    """
    Refer to https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L197C1-L288C1
    Computes the inverse frequencies with NTK scaling. Please refer to the
    [original paper](https://huggingface.co/papers/2309.00071)
    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
    Returns:
        factor: float, the scaling factor for the RoPE embeddings
        low: float, the lower bound of the dimension range
        high: float, the upper bound of the dimension range
        attention_factor: float, the post-processing scaling factor applied to the computed cos/sin
    """

    # If config does not contain rope_scaling or rope_type is not yarn, it means the model is not using yarn
    rope_scaling = getattr(config, "rope_scaling", None)
    if rope_scaling is None or getattr(rope_scaling, "rope_type",
                                       None) != "yarn":
        return 1.0, 0, 0, 1.0

    base = config.rope_theta
    partial_rotary_factor = config.partial_rotary_factor if hasattr(
        config, "partial_rotary_factor") else 1.0
    head_dim = getattr(config, "head_dim",
                       config.hidden_size // config.num_attention_heads)
    dim = int(head_dim * partial_rotary_factor)
    factor = getattr(rope_scaling, "factor", 1.0)
    attention_factor = rope_scaling.get("attention_factor")
    mscale = rope_scaling.get("mscale")
    mscale_all_dim = rope_scaling.get("mscale_all_dim")

    if "original_max_position_embeddings" in rope_scaling:
        original_max_position_embeddings = rope_scaling[
            "original_max_position_embeddings"]
        factor = config.max_position_embeddings / original_max_position_embeddings
    else:
        original_max_position_embeddings = config.max_position_embeddings

    def get_mscale(scale, mscale=1):
        if scale <= 1:
            return 1.0
        return 0.1 * mscale * math.log(scale) + 1.0

    # Sets the attention factor as suggested in the paper
    if attention_factor is None:
        if mscale and mscale_all_dim:
            attention_factor = float(
                get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
        else:
            attention_factor = get_mscale(factor)

    # Optional config options
    # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
    beta_fast = rope_scaling.get("beta_fast") or 32
    beta_slow = rope_scaling.get("beta_slow") or 1

    # Compute the inverse frequencies
    def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
        """Inverse dimension formula to find the dimension based on the number of rotations"""
        return (dim *
                math.log(max_position_embeddings /
                         (num_rotations * 2 * math.pi))) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base,
                              max_position_embeddings, truncate):
        """Find dimension range bounds based on rotations"""
        low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
        high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
        if truncate:
            low = math.floor(low)
            high = math.ceil(high)
        return max(low, 0), min(high, dim - 1)

    truncate = rope_scaling.get("truncate", True)
    low, high = find_correction_range(beta_fast, beta_slow, dim, base,
                                      original_max_position_embeddings,
                                      truncate)

    # These parts are implemented in the fusedQKNormRopeKernel.cu
    # # def linear_ramp_factor(min, max, dim):
    # # if min == max:
    # # max += 0.001  # Prevent singularity

    # # linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
    # # ramp_func = torch.clamp(linear_func, 0, 1)
    # # return ramp_func

    # # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
    # # to expand the possible context length. In other words, interpolation = apply scaling factor.
    # # pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim)
    # # inv_freq_extrapolation = 1.0 / pos_freqs
    # # inv_freq_interpolation = 1.0 / (factor * pos_freqs)

    # # # Get n-dimensional rotational scaling corrected for extrapolation
    # # inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float)
    # # inv_freq = (
    # #     inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
    # #     + inv_freq_extrapolation * inv_freq_extrapolation_factor
    # # )
    # # return inv_freq, attention_factor
    return factor, low, high, attention_factor


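As a sanity check on the correction-range math above, the condensed sketch below (truncate assumed True and helper signatures simplified; the head_dim/theta/context values are assumptions, not taken from the diff) reproduces find_correction_dim/find_correction_range and prints the (low, high) dimension band that the fused kernel blends between interpolation and extrapolation:

import math

def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
    # Inverse dimension formula: which rotary dimension completes `num_rotations`
    # full rotations over the original context length.
    return (dim * math.log(max_position_embeddings /
                           (num_rotations * 2 * math.pi))) / (2 * math.log(base))

def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
    low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)

# Assumed values: head_dim=128, rope_theta=10000, original context 4096,
# beta_fast=32, beta_slow=1 (the defaults used above).
low, high = find_correction_range(32, 1, 128, 10000.0, 4096)
print(low, high)        # 20 46 for these values
assert 0 <= low < high <= 127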
class QKNormRoPEAttention(Attention):
    """
    QKNormRoPEAttention is a custom attention layer that applies QK norm and RoPE to the input tensor.
    It is used in the ExaOne4, Gemma3 and Qwen3 models.
    It is a subclass of Attention, and overrides the apply_rope method to apply QK norm and RoPE.
    """

    def __init__(
        self,
        *,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        max_position_embeddings: int,
        bias: bool,
        pos_embd_params: Optional[PositionalEmbeddingParams] = None,
        skip_rope: bool = False,
        fuse_qk_norm_rope: bool = True,
        layer_idx: Optional[int] = None,
        dtype: torch.dtype = None,
        dense_bias: Optional[bool] = None,
        config: ModelConfig,
        q_scaling: float = 1.0,
    ):
        self.pretrained_config = config.pretrained_config

        self.fuse_qk_norm_rope = fuse_qk_norm_rope
        self.skip_rope = skip_rope
        assert not (fuse_qk_norm_rope and skip_rope
                    ), "Fusing qk norm and skipping rope is not supported"

        super().__init__(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            max_position_embeddings=max_position_embeddings,
            bias=bias,
            pos_embd_params=pos_embd_params,
            # If fuse_qk_norm_rope is true, do not apply fused RoPE in attention OP,
            # and self.rotary_emb will be skipped in the overridden apply_rope.
            rope_fusion=not self.fuse_qk_norm_rope and not skip_rope,
            layer_idx=layer_idx,
            dtype=dtype,
            dense_bias=dense_bias,
            config=config,
            q_scaling=q_scaling,
        )

        self.q_norm = RMSNorm(hidden_size=self.head_dim,
                              eps=self.pretrained_config.rms_norm_eps,
                              dtype=self.pretrained_config.torch_dtype,
                              has_weights=True)
        self.k_norm = RMSNorm(hidden_size=self.head_dim,
                              eps=self.pretrained_config.rms_norm_eps,
                              dtype=self.pretrained_config.torch_dtype,
                              has_weights=True)
        self.aux_stream = torch.cuda.Stream()
        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]

    def apply_qk_norm(self, q, k):

        def q_l2norm():
            return self.q_norm(q.reshape(-1, self.head_dim)).reshape(
                -1, self.q_size)

        def k_l2norm():
            return self.k_norm(k.reshape(-1, self.head_dim)).reshape(
                -1, self.kv_size)

        q, k = maybe_execute_in_parallel(
            q_l2norm,
            k_l2norm,
            self.ln_events[0],
            self.ln_events[1],
            self.aux_stream,
        )

        return q, k

    def apply_qk_norm_rope(self, qkv, position_ids):
        factor, low, high, attention_factor = compute_yarn_parameters(
            self.pretrained_config)
        torch.ops.trtllm.fused_qk_norm_rope(
            qkv, self.num_heads, self.num_key_value_heads,
            self.num_key_value_heads, self.head_dim,
            self.q_norm.variance_epsilon, self.q_norm.weight,
            self.k_norm.weight,
            self.pos_embd_params.rope.theta, self.pos_embd_params.is_neox,
            position_ids.view(-1), factor, low, high, attention_factor)
        return qkv, None, None

    def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
                   v: Optional[torch.Tensor], position_ids: torch.Tensor):
        """
        The apply_rope method is called in the forward method of the Attention class.
        The apply_rope method is overridden in this class to apply QK norm and RoPE to the input tensor.
        """
        # Apply QK norm before RoPE.
        if not self.fuse_qk_norm_rope:
            q, k, v = self.split_qkv(q, k, v)
            q, k = self.apply_qk_norm(q, k)
            if not self.skip_rope:
                return super().apply_rope(q, k, v, position_ids)
            else:
                return q, k, v

        assert k is None and v is None, "The input should be a concatenated qkv tensor to apply_qk_norm_rope"
        qkv = q
        return self.apply_qk_norm_rope(qkv, position_ids)
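To illustrate how the three model files above consume this new base class, here is a hypothetical subclass sketch. MyAttention is a placeholder, not a class from the diff; the keyword arguments mirror the __init__ signature shown above, and a real caller would also build pos_embd_params from RopeParams.from_config as the model files do:

from tensorrt_llm._torch.modules.qk_norm_attention import QKNormRoPEAttention


class MyAttention(QKNormRoPEAttention):
    # Hypothetical model attention layer built on the shared QK-norm + RoPE base class.

    def __init__(self, model_config, layer_idx=None):
        config = model_config.pretrained_config
        super().__init__(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            max_position_embeddings=config.max_position_embeddings,
            bias=False,
            pos_embd_params=None,     # or PositionalEmbeddingParams(...) as in the diff
            fuse_qk_norm_rope=False,  # unfused path: apply_qk_norm + regular RoPE
            skip_rope=False,          # set True for layers that apply no RoPE at all
            layer_idx=layer_idx,
            dtype=config.torch_dtype,
            dense_bias=False,
            config=model_config,
        )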
@@ -504,7 +504,6 @@ def create_py_executor_instance(
    resources,
    mapping,
    pytorch_backend_config,
    executor_config,
    ctx_chunk_config,
    model_engine,
    start_worker,
@@ -515,13 +514,19 @@ def create_py_executor_instance(
    garbage_collection_gen0_threshold: Optional[int] = None,
    kv_connector_manager: Optional[KvCacheConnectorManager] = None,
    max_seq_len: Optional[int] = None,
    max_batch_size: Optional[int] = None,
    max_beam_width: Optional[int] = None,
    max_num_tokens: Optional[int] = None,
    peft_cache_config: Optional[trtllm.PeftCacheConfig] = None,
    scheduler_config: Optional[trtllm.SchedulerConfig] = None,
    cache_transceiver_config: Optional[trtllm.CacheTransceiverConfig] = None,
) -> PyExecutor:
    kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)

    spec_config = model_engine.spec_config

    logger.info(
        f"max_seq_len={executor_config.max_seq_len}, max_num_requests={executor_config.max_batch_size}, max_num_tokens={executor_config.max_num_tokens}, max_batch_size={executor_config.max_batch_size}"
        f"max_seq_len={max_seq_len}, max_num_requests={max_batch_size}, max_num_tokens={max_num_tokens}, max_batch_size={max_batch_size}"
    )

    for key, value in pytorch_backend_config.extra_resource_managers.items():
@@ -578,16 +583,15 @@ def create_py_executor_instance(
        len(lora_config.lora_target_modules + lora_config.missing_qkv_modules)

    peft_cache_config_model = PeftCacheConfig.from_pybind(
        executor_config.peft_cache_config
    ) if executor_config.peft_cache_config is not None else PeftCacheConfig(
    )
        peft_cache_config
    ) if peft_cache_config is not None else PeftCacheConfig()
    if lora_config.max_loras is not None:
        peft_cache_config_model.num_device_module_layer = \
            max_lora_rank * num_lora_modules * lora_config.max_loras
    if lora_config.max_cpu_loras is not None:
        peft_cache_config_model.num_host_module_layer = \
            max_lora_rank * num_lora_modules * lora_config.max_cpu_loras
    executor_config.peft_cache_config = peft_cache_config_model._to_pybind()
    peft_cache_config = peft_cache_config_model._to_pybind()

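The PEFT cache override above sizes the device and host module-layer counts as rank x module count x number of resident LoRAs; a quick worked example with assumed values (not taken from the diff):

# Assumed values for illustration only.
max_lora_rank = 64
num_lora_modules = 7                  # e.g. len(lora_target_modules + missing_qkv_modules)
max_loras, max_cpu_loras = 4, 16

num_device_module_layer = max_lora_rank * num_lora_modules * max_loras      # 1792
num_host_module_layer = max_lora_rank * num_lora_modules * max_cpu_loras    # 7168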
    from tensorrt_llm.bindings import WorldConfig
    world_config = WorldConfig(
@@ -598,7 +602,7 @@ def create_py_executor_instance(
        gpus_per_node=dist.mapping.gpus_per_node,
    )
    peft_cache_manager = PeftCacheManager(
        peft_cache_config=executor_config.peft_cache_config,
        peft_cache_config=peft_cache_config,
        lora_config=lora_config,
        model_config=model_binding_config,
        world_config=world_config,
@@ -609,7 +613,7 @@ def create_py_executor_instance(
        lora_config.trtllm_modules_to_hf_modules,
        lora_config.swap_gate_up_proj_lora_b_weight)

    max_num_sequences = executor_config.max_batch_size * mapping.pp_size
    max_num_sequences = max_batch_size * mapping.pp_size

    resources[ResourceManagerType.SEQ_SLOT_MANAGER] = SeqSlotManager(
        max_num_sequences)
@@ -632,17 +636,15 @@ def create_py_executor_instance(
        scheduler_capacity,
        kv_cache_manager.impl if kv_cache_manager is not None else None,
        peft_cache_manager.impl if peft_cache_manager is not None else None,
        executor_config.scheduler_config.capacity_scheduler_policy,
        scheduler_config.capacity_scheduler_policy,
        two_step_lookahead=mapping.has_pp())
    mb_scheduler = BindMicroBatchScheduler(executor_config.max_batch_size,
                                           executor_config.max_num_tokens,
    mb_scheduler = BindMicroBatchScheduler(max_batch_size, max_num_tokens,
                                           ctx_chunk_config)
    scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler)

    config = model_engine.model.model_config.pretrained_config
    attention_type = AttentionTypeCpp.MLA if is_mla(
        config) else AttentionTypeCpp.DEFAULT
    cache_transceiver_config = executor_config.cache_transceiver_config
    kv_cache_transceiver = create_kv_cache_transceiver(
        mapping, kv_cache_manager, attention_type, cache_transceiver_config)
    return PyExecutor(
@@ -655,8 +657,8 @@ def create_py_executor_instance(
        max_num_sequences=max_num_sequences,
        disable_overlap_scheduler=pytorch_backend_config.
        disable_overlap_scheduler,
        max_batch_size=executor_config.max_batch_size,
        max_beam_width=executor_config.max_beam_width,
        max_batch_size=max_batch_size,
        max_beam_width=max_beam_width,
        max_draft_len=spec_config.max_draft_len
        if spec_config is not None else 0,
        kv_cache_transceiver=kv_cache_transceiver,
@@ -664,7 +666,8 @@ def create_py_executor_instance(
        start_worker=start_worker,
        garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
        kv_connector_manager=kv_connector_manager,
        max_seq_len=max_seq_len)
        max_seq_len=max_seq_len,
        peft_cache_config=peft_cache_config)


def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,

@@ -25,8 +25,8 @@ from tensorrt_llm._utils import (customized_gc_thresholds, global_mpi_rank,
from tensorrt_llm.bindings.executor import (DisServingRequestStats,
                                            FinishReason, InflightBatchingStats,
                                            IterationStats, KvCacheStats,
                                            RequestStage, RequestStats,
                                            SpecDecodingStats,
                                            PeftCacheConfig, RequestStage,
                                            RequestStats, SpecDecodingStats,
                                            StaticBatchingStats)
from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType,
                                                          ReqIdsSet)
@@ -157,11 +157,14 @@ class PyExecutor:
                 garbage_collection_gen0_threshold: Optional[int] = None,
                 start_worker: bool = True,
                 kv_connector_manager: Optional[KvCacheConnectorManager] = None,
                 max_seq_len: Optional[int] = None):
                 max_seq_len: Optional[int] = None,
                 peft_cache_config: Optional[PeftCacheConfig] = None):
        super(PyExecutor, self).__init__()
        self.device_id = torch.cuda.current_device()
        self.global_rank = global_mpi_rank()

        self.peft_cache_config = peft_cache_config

        # profile config
        self.profile_start_iters, self.profile_stop_iters = _load_iteration_indexes(
            PROFILE_START_STOP_ENV_VAR_NAME)

@@ -226,7 +226,7 @@ def create_py_executor(
    mapping = _get_mapping(executor_config)

    dist = MPIDist(mapping=mapping)

    cache_transceiver_config = executor_config.cache_transceiver_config
    spec_config = executor_config.speculative_config
    has_draft_model_engine = False
    has_spec_drafter = False
@@ -508,7 +508,6 @@ def create_py_executor(
        resources=resources,
        mapping=mapping,
        pytorch_backend_config=pytorch_backend_config,
        executor_config=executor_config,
        ctx_chunk_config=ctx_chunk_config,
        model_engine=model_engine,
        start_worker=False,
@@ -520,7 +519,16 @@ def create_py_executor(
        kv_connector_manager=kv_connector_manager
        if not estimating_kv_cache else None,
        max_seq_len=executor_config.max_seq_len,
        max_batch_size=executor_config.max_batch_size,
        max_beam_width=executor_config.max_beam_width,
        max_num_tokens=executor_config.max_num_tokens,
        peft_cache_config=executor_config.peft_cache_config,
        scheduler_config=executor_config.scheduler_config,
        cache_transceiver_config=cache_transceiver_config,
    )
    # Modify the executor_config.peft_cache_config which might be mutated
    # inside create_py_executor_instance
    executor_config.peft_cache_config = py_executor.peft_cache_config

    if estimating_kv_cache:
        assert kv_cache_creator is not None
@@ -553,7 +561,6 @@ def create_py_executor(
        resources=resources,
        mapping=mapping,
        pytorch_backend_config=pytorch_backend_config,
        executor_config=executor_config,
        ctx_chunk_config=ctx_chunk_config,
        model_engine=model_engine,
        start_worker=False,
@@ -565,6 +572,12 @@ def create_py_executor(
        garbage_collection_gen0_threshold,
        kv_connector_manager=kv_connector_manager,
        max_seq_len=executor_config.max_seq_len,
        max_batch_size=executor_config.max_batch_size,
        max_beam_width=executor_config.max_beam_width,
        max_num_tokens=executor_config.max_num_tokens,
        peft_cache_config=executor_config.peft_cache_config,
        scheduler_config=executor_config.scheduler_config,
        cache_transceiver_config=cache_transceiver_config,
    )

    _adjust_torch_mem_fraction(executor_config.pytorch_backend_config)

@@ -353,3 +353,5 @@ test_e2e.py::test_multi_nodes_eval[llama4-models/nvidia/Llama-4-Maverick-17B-128
accuracy/test_cli_flow.py::TestLlama3_8BInstructGradient1048k::test_long_context_ppl SKIP (https://nvbugs/5413362)
stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-GUARANTEED_NO_EVICT-pytorch-stress-test] SKIP (https://nvbugs/5474169)
test_e2e.py::test_trtllm_bench_iteration_log[TRT-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B] SKIP (https://nvbugs/5448523)
cpp/test_unit_tests.py::test_unit_tests[kernels-80] SKIP (https://nvbugs/5504078)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=vanilla-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5504086)

@@ -389,11 +389,6 @@ def test_attention_mla(scenario: Scenario, context_sequence_lengths: List[int],
    dtype = scenario.dtype
    kv_cache_dtype = scenario.kv_cache_dtype

    FAILED_CSL = [777, 912, 431, 42, 266, 989, 524]
    if (kv_cache_dtype is torch.float8_e4m3fn
            and context_sequence_lengths == FAILED_CSL):
        pytest.skip("https://nvbugs/5453806")

    print(
        f"--------------------------------Test for scenario: {scenario} start--------------------------------"
    )

@@ -16,7 +16,6 @@ from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))


@pytest.mark.skip(reason="https://nvbugs/5461761")
@pytest.mark.parametrize(
    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill",
    [

@@ -594,6 +594,7 @@ def llm_for_sampling_params():
    llm.shutdown()


@pytest.mark.skip(reason="https://nvbugs/5504095")
@pytest.mark.part0
def test_user_specify_workspace():
    user_specified_ws_path = '/tmp/specified_workspace'