diff --git a/tensorrt_llm/_torch/models/modeling_exaone4.py b/tensorrt_llm/_torch/models/modeling_exaone4.py index f766eb4ea1..86147787aa 100644 --- a/tensorrt_llm/_torch/models/modeling_exaone4.py +++ b/tensorrt_llm/_torch/models/modeling_exaone4.py @@ -3,18 +3,17 @@ from typing import Optional, Tuple import torch from torch import nn +from tensorrt_llm._torch.modules.qk_norm_attention import QKNormRoPEAttention from tensorrt_llm.functional import PositionEmbeddingType from ..attention_backend import AttentionMetadata from ..attention_backend.interface import (PositionalEmbeddingParams, PredefinedAttentionMask, RopeParams) from ..model_config import ModelConfig -from ..modules.attention import Attention from ..modules.decoder_layer import DecoderLayer from ..modules.embedding import Embedding from ..modules.gated_mlp import GatedMLP from ..modules.linear import TensorParallelMode -from ..modules.multi_stream_utils import maybe_execute_in_parallel from ..modules.rms_norm import RMSNorm from ..speculative import SpecMetadata from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, @@ -50,12 +49,11 @@ def check_is_sliding(config: Exaone4Config, layer_idx: int) -> bool: return False -class Exaone4Attention(Attention): +class Exaone4Attention(QKNormRoPEAttention): def __init__(self, model_config: ModelConfig[Exaone4Config], layer_idx: Optional[int] = None, - aux_stream: Optional[torch.cuda.Stream] = None, fuse_qk_norm_rope: bool = False): config = model_config.pretrained_config @@ -73,10 +71,10 @@ class Exaone4Attention(Attention): rope=RopeParams.from_config(config), ) - self.fuse_qk_norm_rope = (self.is_sliding and fuse_qk_norm_rope) + fuse_qk_norm_rope = (self.is_sliding and fuse_qk_norm_rope) # TODO: Fusing qk norm with rope has an issue that slightly hurts accuracy. 
- assert self.fuse_qk_norm_rope is False, "Fusing qk norm and rope is having issue now" + assert fuse_qk_norm_rope is False, "Fusing qk norm and rope is having issue now" super().__init__( hidden_size=config.hidden_size, @@ -85,65 +83,13 @@ class Exaone4Attention(Attention): max_position_embeddings=config.max_position_embeddings, bias=False, pos_embd_params=pos_embd_params, - rope_fusion=False, + fuse_qk_norm_rope=fuse_qk_norm_rope, + skip_rope=self.sliding_window and not self.is_sliding, layer_idx=layer_idx, dtype=config.torch_dtype, config=model_config, ) - self.q_norm = RMSNorm(hidden_size=self.head_dim, - eps=config.rms_norm_eps, - dtype=config.torch_dtype) - self.k_norm = RMSNorm(hidden_size=self.head_dim, - eps=config.rms_norm_eps, - dtype=config.torch_dtype) - - self.aux_stream = aux_stream - self.ln_events = [torch.cuda.Event(), torch.cuda.Event()] - - def apply_qk_norm(self, q, k): - - def q_l2norm(): - return self.q_norm(q.reshape(-1, self.head_dim)).reshape( - -1, self.q_size) - - def k_l2norm(): - return self.k_norm(k.reshape(-1, self.head_dim)).reshape( - -1, self.kv_size) - - q, k = maybe_execute_in_parallel( - q_l2norm, - k_l2norm, - self.ln_events[0], - self.ln_events[1], - self.aux_stream, - ) - - return q, k - - def apply_qk_norm_rope(self, qkv, position_ids): - torch.ops.trtllm.fused_qk_norm_rope( - qkv, self.num_heads, self.num_key_value_heads, - self.num_key_value_heads, self.head_dim, - self.q_norm.variance_epsilon, self.q_norm.weight, - self.k_norm.weight, self.pos_embd_params.rope.theta, - self.pos_embd_params.is_neox, position_ids.view(-1)) - return qkv, None, None - - def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor], - v: Optional[torch.Tensor], position_ids: torch.Tensor): - if self.fuse_qk_norm_rope: - assert k is None and v is None, "The input should be a concatenated qkv tensor to apply_qk_norm_rope" - qkv = q - return self.apply_qk_norm_rope(qkv, position_ids) - - q, k, v = self.split_qkv(q, k, v) - q, k = self.apply_qk_norm(q, k) - if self.sliding_window is None or self.is_sliding: - return super().apply_rope(q, k, v, position_ids) - else: - return q, k, v - def forward( self, position_ids: Optional[torch.LongTensor], @@ -175,7 +121,6 @@ class Exaone4DecoderLayer(DecoderLayer): self, model_config: ModelConfig[Exaone4Config], layer_idx: int, - aux_stream: Optional[torch.cuda.Stream] = None, ): super().__init__() config = model_config.pretrained_config @@ -186,7 +131,6 @@ class Exaone4DecoderLayer(DecoderLayer): self.self_attn = Exaone4Attention( model_config, layer_idx=layer_idx, - aux_stream=aux_stream, ) self.mlp = GatedMLP( @@ -244,7 +188,6 @@ class Exaone4Model(DecoderModel): super().__init__(model_config) config = self.model_config.pretrained_config self.num_hidden_layers = config.num_hidden_layers - self.aux_stream = torch.cuda.Stream() self.embed_tokens = Embedding( config.vocab_size, config.hidden_size, @@ -258,7 +201,6 @@ class Exaone4Model(DecoderModel): Exaone4DecoderLayer( model_config, layer_idx, - self.aux_stream, ) for layer_idx in range(self.num_hidden_layers) ]) diff --git a/tensorrt_llm/_torch/models/modeling_gemma3.py b/tensorrt_llm/_torch/models/modeling_gemma3.py index ccbe0165ca..1667b755b2 100644 --- a/tensorrt_llm/_torch/models/modeling_gemma3.py +++ b/tensorrt_llm/_torch/models/modeling_gemma3.py @@ -7,6 +7,7 @@ from transformers import Gemma3TextConfig from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \ BaseWeightMapper +from tensorrt_llm._torch.modules.qk_norm_attention import QKNormRoPEAttention 
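The Exaone4Attention change above folds the hybrid sliding/global handling into two constructor flags of the shared base class. The helper below is a minimal pure-Python sketch (the function name and boolean inputs are hypothetical stand-ins for self.is_sliding and config.sliding_window) showing how a layer's role maps onto those flags.

# Illustrative only: mirrors the flag logic in Exaone4Attention.__init__ above.
def map_exaone4_rope_flags(is_sliding: bool,
                           has_sliding_window: bool,
                           fuse_qk_norm_rope: bool = False) -> dict:
    return {
        # Fusion is only considered on sliding-window (local) layers, and is
        # currently asserted off because of the accuracy issue in the TODO above.
        "fuse_qk_norm_rope": is_sliding and fuse_qk_norm_rope,
        # In a hybrid model, global-attention layers apply QK norm but skip RoPE.
        "skip_rope": has_sliding_window and not is_sliding,
    }

# A global layer of a hybrid (sliding-window) model: QK norm, no RoPE.
assert map_exaone4_rope_flags(is_sliding=False, has_sliding_window=True) == {
    "fuse_qk_norm_rope": False,
    "skip_rope": True,
}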
from tensorrt_llm.functional import PositionEmbeddingType, RotaryScalingType from tensorrt_llm.mapping import Mapping @@ -15,12 +16,10 @@ from ..attention_backend.interface import (AttentionMask, CustomAttentionMask, PositionalEmbeddingParams, PredefinedAttentionMask, RopeParams) from ..model_config import ModelConfig -from ..modules.attention import Attention from ..modules.decoder_layer import DecoderLayer from ..modules.embedding import Embedding from ..modules.gated_mlp import GatedMLP from ..modules.linear import TensorParallelMode -from ..modules.multi_stream_utils import maybe_execute_in_parallel from ..modules.rms_norm import RMSNorm from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, register_auto_model) @@ -52,7 +51,7 @@ class Gemma3TextScaledWordEmbedding(Embedding): return super().forward(input_ids) * self.embed_scale -class Gemma3Attention(Attention): +class Gemma3Attention(QKNormRoPEAttention): def __init__( self, @@ -82,20 +81,13 @@ class Gemma3Attention(Attention): max_position_embeddings=config.max_position_embeddings, bias=False, pos_embd_params=pos_embd_params, + fuse_qk_norm_rope=False, layer_idx=layer_idx, dtype=config.torch_dtype, dense_bias=False, config=model_config, q_scaling=q_scaling, ) - self.q_norm = RMSNorm(hidden_size=config.head_dim, - eps=config.rms_norm_eps, - dtype=config.torch_dtype) - self.k_norm = RMSNorm(hidden_size=config.head_dim, - eps=config.rms_norm_eps, - dtype=config.torch_dtype) - self.aux_stream = torch.cuda.Stream() - self.ln_events = [torch.cuda.Event(), torch.cuda.Event()] @torch.inference_mode() def forward( @@ -121,33 +113,6 @@ class Gemma3Attention(Attention): attention_mask_data=attention_mask_data, **kwargs) - def apply_qk_norm(self, q, k): - - def q_l2norm(): - return self.q_norm(q.reshape(-1, self.head_dim)).reshape( - -1, self.q_size) - - def k_l2norm(): - return self.k_norm(k.reshape(-1, self.head_dim)).reshape( - -1, self.kv_size) - - q, k = maybe_execute_in_parallel( - q_l2norm, - k_l2norm, - self.ln_events[0], - self.ln_events[1], - self.aux_stream, - ) - - return q, k - - def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor], - v: Optional[torch.Tensor], position_ids: torch.Tensor): - # Gemma3 applies QK norm before RoPE. - q, k, v = self.split_qkv(q, k, v) - q, k = self.apply_qk_norm(q, k) - return super().apply_rope(q, k, v, position_ids) - # This function is written to be compatible with TRTLLM's GatedMLP class. 
def pytorch_gelu_tanh(gate_x: torch.Tensor) -> torch.Tensor: diff --git a/tensorrt_llm/_torch/models/modeling_qwen3.py b/tensorrt_llm/_torch/models/modeling_qwen3.py index 48a73e85f1..cbd2ebb983 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen3.py +++ b/tensorrt_llm/_torch/models/modeling_qwen3.py @@ -1,133 +1,26 @@ -import math from typing import Optional, Tuple import torch from torch import nn -from transformers import PretrainedConfig, Qwen3Config +from transformers import Qwen3Config from tensorrt_llm.functional import PositionEmbeddingType from ..attention_backend import AttentionMetadata from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams from ..model_config import ModelConfig -from ..modules.attention import Attention from ..modules.decoder_layer import DecoderLayer from ..modules.embedding import Embedding from ..modules.gated_mlp import GatedMLP from ..modules.linear import TensorParallelMode -from ..modules.multi_stream_utils import maybe_execute_in_parallel +from ..modules.qk_norm_attention import QKNormRoPEAttention from ..modules.rms_norm import RMSNorm from ..speculative import SpecMetadata from .modeling_speculative import SpecDecOneEngineForCausalLM from .modeling_utils import DecoderModel, register_auto_model -# Move out from this class -def compute_yarn_parameters( - config: PretrainedConfig, ) -> tuple[float, float, float, float]: - """ - Refer to https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L197C1-L288C1 - Computes the inverse frequencies with NTK scaling. Please refer to the - [original paper](https://huggingface.co/papers/2309.00071) - Args: - config ([`~transformers.PretrainedConfig`]): - The model configuration. - Returns: - factor: float, the scaling factor for the RoPE embeddings - low: float, the lower bound of the dimension range - high: float, the upper bound of the dimension range - attention_factor: float, the post-processing scaling factor applied to the computed cos/sin - """ - - # The config does not contain rope_scaling, which means the model is not using yarn - rope_scaling = getattr(config, "rope_scaling", None) - if rope_scaling is None: - return 1.0, 0, 0, 1.0 - - base = config.rope_theta - partial_rotary_factor = config.partial_rotary_factor if hasattr( - config, "partial_rotary_factor") else 1.0 - head_dim = getattr(config, "head_dim", - config.hidden_size // config.num_attention_heads) - dim = int(head_dim * partial_rotary_factor) - factor = getattr(rope_scaling, "factor", 1.0) - attention_factor = rope_scaling.get("attention_factor") - mscale = rope_scaling.get("mscale") - mscale_all_dim = rope_scaling.get("mscale_all_dim") - - if "original_max_position_embeddings" in rope_scaling: - original_max_position_embeddings = rope_scaling[ - "original_max_position_embeddings"] - factor = config.max_position_embeddings / original_max_position_embeddings - else: - original_max_position_embeddings = config.max_position_embeddings - - def get_mscale(scale, mscale=1): - if scale <= 1: - return 1.0 - return 0.1 * mscale * math.log(scale) + 1.0 - - # Sets the attention factor as suggested in the paper - if attention_factor is None: - if mscale and mscale_all_dim: - attention_factor = float( - get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim)) - else: - attention_factor = get_mscale(factor) - - # Optional config options - # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) - beta_fast = rope_scaling.get("beta_fast") or 32 - beta_slow = 
rope_scaling.get("beta_slow") or 1 - - # Compute the inverse frequencies - def find_correction_dim(num_rotations, dim, base, max_position_embeddings): - """Inverse dimension formula to find the dimension based on the number of rotations""" - return (dim * - math.log(max_position_embeddings / - (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - - def find_correction_range(low_rot, high_rot, dim, base, - max_position_embeddings, truncate): - """Find dimension range bounds based on rotations""" - low = find_correction_dim(low_rot, dim, base, max_position_embeddings) - high = find_correction_dim(high_rot, dim, base, max_position_embeddings) - if truncate: - low = math.floor(low) - high = math.ceil(high) - return max(low, 0), min(high, dim - 1) - - truncate = rope_scaling.get("truncate", True) - low, high = find_correction_range(beta_fast, beta_slow, dim, base, - original_max_position_embeddings, - truncate) - - # These parts are implemented in the fusedQKNormRopeKernel.cu - # # def linear_ramp_factor(min, max, dim): - # # if min == max: - # # max += 0.001 # Prevent singularity - - # # linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) - # # ramp_func = torch.clamp(linear_func, 0, 1) - # # return ramp_func - - # # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs - # # to expand the possible context length. In other words, interpolation = apply scaling factor. - # # pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim) - # # inv_freq_extrapolation = 1.0 / pos_freqs - # # inv_freq_interpolation = 1.0 / (factor * pos_freqs) - - # # # Get n-dimensional rotational scaling corrected for extrapolation - # # inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float) - # # inv_freq = ( - # # inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) - # # + inv_freq_extrapolation * inv_freq_extrapolation_factor - # # ) - # # return inv_freq, attention_factor - return factor, low, high, attention_factor - - -class Qwen3Attention(Attention): +class Qwen3Attention(QKNormRoPEAttention): def __init__( self, @@ -136,7 +29,6 @@ class Qwen3Attention(Attention): fuse_qk_norm_rope: bool = True, ): config = model_config.pretrained_config - self.pretrained_config = config if getattr(config, "rope_scaling", None) is not None: if "type" in config.rope_scaling: @@ -156,8 +48,6 @@ class Qwen3Attention(Attention): rope=RopeParams.from_config(config), ) - self.fuse_qk_norm_rope = fuse_qk_norm_rope - super().__init__( hidden_size=config.hidden_size, num_attention_heads=config.num_attention_heads, @@ -165,69 +55,13 @@ class Qwen3Attention(Attention): max_position_embeddings=config.max_position_embeddings, bias=config.attention_bias, pos_embd_params=pos_embd_params, - rope_fusion=not self. - fuse_qk_norm_rope, # If fuse_qk_norm_rope is true, do not apply fused RoPE in attention OP, and self.rotary_emb will be skipped in the overridden apply_rope. 
+ fuse_qk_norm_rope=fuse_qk_norm_rope, layer_idx=layer_idx, dtype=config.torch_dtype, dense_bias=config.attention_bias, config=model_config, ) - self.q_norm = RMSNorm(hidden_size=self.head_dim, - eps=1e-6, - dtype=config.torch_dtype, - has_weights=True) - self.k_norm = RMSNorm(hidden_size=self.head_dim, - eps=1e-6, - dtype=config.torch_dtype, - has_weights=True) - self.aux_stream = torch.cuda.Stream() - self.ln_events = [torch.cuda.Event(), torch.cuda.Event()] - - def apply_qk_norm(self, q, k): - - def q_l2norm(): - return self.q_norm(q.reshape(-1, self.head_dim)).reshape( - -1, self.q_size) - - def k_l2norm(): - return self.k_norm(k.reshape(-1, self.head_dim)).reshape( - -1, self.kv_size) - - q, k = maybe_execute_in_parallel( - q_l2norm, - k_l2norm, - self.ln_events[0], - self.ln_events[1], - self.aux_stream, - ) - - return q, k - - def apply_qk_norm_rope(self, qkv, position_ids): - factor, low, high, attention_factor = compute_yarn_parameters( - self.pretrained_config) - torch.ops.trtllm.fused_qk_norm_rope( - qkv, self.num_heads, self.num_key_value_heads, - self.num_key_value_heads, self.head_dim, - self.q_norm.variance_epsilon, self.q_norm.weight, - self.k_norm.weight, - self.pos_embd_params.rope.theta, self.pos_embd_params.is_neox, - position_ids.view(-1), factor, low, high, attention_factor) - return qkv, None, None - - def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor], - v: Optional[torch.Tensor], position_ids: torch.Tensor): - # Qwen3 applies QK norm before RoPE. - if not self.fuse_qk_norm_rope: - q, k, v = self.split_qkv(q, k, v) - q, k = self.apply_qk_norm(q, k) - return super().apply_rope(q, k, v, position_ids) - - assert k is None and v is None, "The input should be a concatenated qkv tensor to apply_qk_norm_rope" - qkv = q - return self.apply_qk_norm_rope(qkv, position_ids) - class Qwen3DecoderLayer(DecoderLayer): @@ -235,7 +69,7 @@ class Qwen3DecoderLayer(DecoderLayer): self, model_config: ModelConfig[Qwen3Config], layer_idx: int, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ): super().__init__() self.layer_idx = layer_idx config = model_config.pretrained_config diff --git a/tensorrt_llm/_torch/modules/qk_norm_attention.py b/tensorrt_llm/_torch/modules/qk_norm_attention.py new file mode 100644 index 0000000000..6c146e4bfc --- /dev/null +++ b/tensorrt_llm/_torch/modules/qk_norm_attention.py @@ -0,0 +1,243 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
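The compute_yarn_parameters helper that this patch moves out of modeling_qwen3.py (and into the module below) reduces to four numbers: the context-scaling factor, the low/high correction dims, and the cos/sin attention factor. The standalone sketch below re-derives them for a hypothetical Qwen3-style YaRN setup (rope_theta 1e6, head_dim 128, a 32k context stretched to 131k); all concrete values are assumptions for illustration, not taken from any shipped config.

import math

# Hypothetical YaRN configuration, for illustration only.
base = 1_000_000.0      # rope_theta
dim = 128               # rotary dim = head_dim * partial_rotary_factor
orig_max_pos = 32_768   # original_max_position_embeddings
max_pos = 131_072       # max_position_embeddings after scaling
beta_fast, beta_slow = 32, 1

# Scaling factor: how much the original context window is stretched.
factor = max_pos / orig_max_pos                  # 4.0

# Default attention factor (get_mscale with mscale == 1).
attention_factor = 0.1 * math.log(factor) + 1.0  # ~1.139


# Dimension index whose frequency completes `num_rotations` turns over the
# original context length (find_correction_dim in the module below).
def correction_dim(num_rotations: float) -> float:
    return dim * math.log(orig_max_pos /
                          (num_rotations * 2 * math.pi)) / (2 * math.log(base))


# Ramp boundaries: dims below `low` keep the original (extrapolated)
# frequencies, dims above `high` use the interpolated (scaled) ones.
low = max(math.floor(correction_dim(beta_fast)), 0)
high = min(math.ceil(correction_dim(beta_slow)), dim - 1)

print(factor, low, high, round(attention_factor, 3))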
+ +import math +from typing import Optional + +import torch +from transformers import PretrainedConfig + +from ..attention_backend.interface import PositionalEmbeddingParams +from ..model_config import ModelConfig +from ..modules.attention import Attention +from ..modules.multi_stream_utils import maybe_execute_in_parallel +from ..modules.rms_norm import RMSNorm + + +# Move out from this class +def compute_yarn_parameters( + config: PretrainedConfig, ) -> tuple[float, float, float, float]: + """ + Refer to https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L197C1-L288C1 + Computes the inverse frequencies with NTK scaling. Please refer to the + [original paper](https://huggingface.co/papers/2309.00071) + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + Returns: + factor: float, the scaling factor for the RoPE embeddings + low: float, the lower bound of the dimension range + high: float, the upper bound of the dimension range + attention_factor: float, the post-processing scaling factor applied to the computed cos/sin + """ + + # If the config has no rope_scaling, or its rope_type is not "yarn", the model does not use YaRN + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is None or rope_scaling.get("rope_type", + None) != "yarn": + return 1.0, 0, 0, 1.0 + + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr( + config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", + config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + factor = rope_scaling.get("factor", 1.0) + attention_factor = rope_scaling.get("attention_factor") + mscale = rope_scaling.get("mscale") + mscale_all_dim = rope_scaling.get("mscale_all_dim") + + if "original_max_position_embeddings" in rope_scaling: + original_max_position_embeddings = rope_scaling[ + "original_max_position_embeddings"] + factor = config.max_position_embeddings / original_max_position_embeddings + else: + original_max_position_embeddings = config.max_position_embeddings + + def get_mscale(scale, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + # Sets the attention factor as suggested in the paper + if attention_factor is None: + if mscale and mscale_all_dim: + attention_factor = float( + get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim)) + else: + attention_factor = get_mscale(factor) + + # Optional config options + # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (respectively) + beta_fast = rope_scaling.get("beta_fast") or 32 + beta_slow = rope_scaling.get("beta_slow") or 1 + + # Compute the inverse frequencies + def find_correction_dim(num_rotations, dim, base, max_position_embeddings): + """Inverse dimension formula to find the dimension based on the number of rotations""" + return (dim * + math.log(max_position_embeddings / + (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + def find_correction_range(low_rot, high_rot, dim, base, + max_position_embeddings, truncate): + """Find dimension range bounds based on rotations""" + low = find_correction_dim(low_rot, dim, base, max_position_embeddings) + high = find_correction_dim(high_rot, dim, base, max_position_embeddings) + if truncate: + low = math.floor(low) + high = math.ceil(high) + return max(low, 0), min(high, dim - 1) + + truncate = rope_scaling.get("truncate", True) + low, high = 
find_correction_range(beta_fast, beta_slow, dim, base, + original_max_position_embeddings, + truncate) + + # These parts are implemented in the fusedQKNormRopeKernel.cu + # # def linear_ramp_factor(min, max, dim): + # # if min == max: + # # max += 0.001 # Prevent singularity + + # # linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + # # ramp_func = torch.clamp(linear_func, 0, 1) + # # return ramp_func + + # # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs + # # to expand the possible context length. In other words, interpolation = apply scaling factor. + # # pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim) + # # inv_freq_extrapolation = 1.0 / pos_freqs + # # inv_freq_interpolation = 1.0 / (factor * pos_freqs) + + # # # Get n-dimensional rotational scaling corrected for extrapolation + # # inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float) + # # inv_freq = ( + # # inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + # # + inv_freq_extrapolation * inv_freq_extrapolation_factor + # # ) + # # return inv_freq, attention_factor + return factor, low, high, attention_factor + + +class QKNormRoPEAttention(Attention): + """ + QKNormRoPEAttention is a custom attention layer that applies QK norm and RoPE to the input tensor. + It is used in the ExaOne4, Gemma3 and Qwen3 models. + It is a subclass of Attention, and overrides the apply_rope method to apply QK norm and RoPE. + """ + + def __init__( + self, + *, + hidden_size: int, + num_attention_heads: int, + num_key_value_heads: int, + max_position_embeddings: int, + bias: bool, + pos_embd_params: Optional[PositionalEmbeddingParams] = None, + skip_rope: bool = False, + fuse_qk_norm_rope: bool = True, + layer_idx: Optional[int] = None, + dtype: torch.dtype = None, + dense_bias: Optional[bool] = None, + config: ModelConfig, + q_scaling: float = 1.0, + ): + self.pretrained_config = config.pretrained_config + + self.fuse_qk_norm_rope = fuse_qk_norm_rope + self.skip_rope = skip_rope + assert not (fuse_qk_norm_rope and skip_rope + ), "Fusing qk norm and skipping rope is not supported" + + super().__init__( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + max_position_embeddings=max_position_embeddings, + bias=bias, + pos_embd_params=pos_embd_params, + # If fuse_qk_norm_rope is true, do not apply fused RoPE in attention OP, + # and self.rotary_emb will be skipped in the overridden apply_rope. 
+ rope_fusion=not self.fuse_qk_norm_rope and not skip_rope, + layer_idx=layer_idx, + dtype=dtype, + dense_bias=dense_bias, + config=config, + q_scaling=q_scaling, + ) + + self.q_norm = RMSNorm(hidden_size=self.head_dim, + eps=self.pretrained_config.rms_norm_eps, + dtype=self.pretrained_config.torch_dtype, + has_weights=True) + self.k_norm = RMSNorm(hidden_size=self.head_dim, + eps=self.pretrained_config.rms_norm_eps, + dtype=self.pretrained_config.torch_dtype, + has_weights=True) + self.aux_stream = torch.cuda.Stream() + self.ln_events = [torch.cuda.Event(), torch.cuda.Event()] + + def apply_qk_norm(self, q, k): + + def q_l2norm(): + return self.q_norm(q.reshape(-1, self.head_dim)).reshape( + -1, self.q_size) + + def k_l2norm(): + return self.k_norm(k.reshape(-1, self.head_dim)).reshape( + -1, self.kv_size) + + q, k = maybe_execute_in_parallel( + q_l2norm, + k_l2norm, + self.ln_events[0], + self.ln_events[1], + self.aux_stream, + ) + + return q, k + + def apply_qk_norm_rope(self, qkv, position_ids): + factor, low, high, attention_factor = compute_yarn_parameters( + self.pretrained_config) + torch.ops.trtllm.fused_qk_norm_rope( + qkv, self.num_heads, self.num_key_value_heads, + self.num_key_value_heads, self.head_dim, + self.q_norm.variance_epsilon, self.q_norm.weight, + self.k_norm.weight, + self.pos_embd_params.rope.theta, self.pos_embd_params.is_neox, + position_ids.view(-1), factor, low, high, attention_factor) + return qkv, None, None + + def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor], + v: Optional[torch.Tensor], position_ids: torch.Tensor): + """ + The apply_rope method is called in the forward method of the Attention class. + The apply_rope method is overridden in this class to apply QK norm and RoPE to the input tensor. + """ + # Apply QK norm before RoPE. 
+ if not self.fuse_qk_norm_rope: + q, k, v = self.split_qkv(q, k, v) + q, k = self.apply_qk_norm(q, k) + if not self.skip_rope: + return super().apply_rope(q, k, v, position_ids) + else: + return q, k, v + + assert k is None and v is None, "The input should be a concatenated qkv tensor to apply_qk_norm_rope" + qkv = q + return self.apply_qk_norm_rope(qkv, position_ids) diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index e3021797ac..d1e079df2f 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -504,7 +504,6 @@ def create_py_executor_instance( resources, mapping, pytorch_backend_config, - executor_config, ctx_chunk_config, model_engine, start_worker, @@ -515,13 +514,19 @@ def create_py_executor_instance( garbage_collection_gen0_threshold: Optional[int] = None, kv_connector_manager: Optional[KvCacheConnectorManager] = None, max_seq_len: Optional[int] = None, + max_batch_size: Optional[int] = None, + max_beam_width: Optional[int] = None, + max_num_tokens: Optional[int] = None, + peft_cache_config: Optional[trtllm.PeftCacheConfig] = None, + scheduler_config: Optional[trtllm.SchedulerConfig] = None, + cache_transceiver_config: Optional[trtllm.CacheTransceiverConfig] = None, ) -> PyExecutor: kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None) spec_config = model_engine.spec_config logger.info( - f"max_seq_len={executor_config.max_seq_len}, max_num_requests={executor_config.max_batch_size}, max_num_tokens={executor_config.max_num_tokens}, max_batch_size={executor_config.max_batch_size}" + f"max_seq_len={max_seq_len}, max_num_requests={max_batch_size}, max_num_tokens={max_num_tokens}, max_batch_size={max_batch_size}" ) for key, value in pytorch_backend_config.extra_resource_managers.items(): @@ -578,16 +583,15 @@ def create_py_executor_instance( len(lora_config.lora_target_modules + lora_config.missing_qkv_modules) peft_cache_config_model = PeftCacheConfig.from_pybind( - executor_config.peft_cache_config - ) if executor_config.peft_cache_config is not None else PeftCacheConfig( - ) + peft_cache_config + ) if peft_cache_config is not None else PeftCacheConfig() if lora_config.max_loras is not None: peft_cache_config_model.num_device_module_layer = \ max_lora_rank * num_lora_modules * lora_config.max_loras if lora_config.max_cpu_loras is not None: peft_cache_config_model.num_host_module_layer = \ max_lora_rank * num_lora_modules * lora_config.max_cpu_loras - executor_config.peft_cache_config = peft_cache_config_model._to_pybind() + peft_cache_config = peft_cache_config_model._to_pybind() from tensorrt_llm.bindings import WorldConfig world_config = WorldConfig( @@ -598,7 +602,7 @@ def create_py_executor_instance( gpus_per_node=dist.mapping.gpus_per_node, ) peft_cache_manager = PeftCacheManager( - peft_cache_config=executor_config.peft_cache_config, + peft_cache_config=peft_cache_config, lora_config=lora_config, model_config=model_binding_config, world_config=world_config, @@ -609,7 +613,7 @@ def create_py_executor_instance( lora_config.trtllm_modules_to_hf_modules, lora_config.swap_gate_up_proj_lora_b_weight) - max_num_sequences = executor_config.max_batch_size * mapping.pp_size + max_num_sequences = max_batch_size * mapping.pp_size resources[ResourceManagerType.SEQ_SLOT_MANAGER] = SeqSlotManager( max_num_sequences) @@ -632,17 +636,15 @@ def create_py_executor_instance( scheduler_capacity, kv_cache_manager.impl if kv_cache_manager is not None else None, peft_cache_manager.impl 
if peft_cache_manager is not None else None, - executor_config.scheduler_config.capacity_scheduler_policy, + scheduler_config.capacity_scheduler_policy, two_step_lookahead=mapping.has_pp()) - mb_scheduler = BindMicroBatchScheduler(executor_config.max_batch_size, - executor_config.max_num_tokens, + mb_scheduler = BindMicroBatchScheduler(max_batch_size, max_num_tokens, ctx_chunk_config) scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler) config = model_engine.model.model_config.pretrained_config attention_type = AttentionTypeCpp.MLA if is_mla( config) else AttentionTypeCpp.DEFAULT - cache_transceiver_config = executor_config.cache_transceiver_config kv_cache_transceiver = create_kv_cache_transceiver( mapping, kv_cache_manager, attention_type, cache_transceiver_config) return PyExecutor( @@ -655,8 +657,8 @@ def create_py_executor_instance( max_num_sequences=max_num_sequences, disable_overlap_scheduler=pytorch_backend_config. disable_overlap_scheduler, - max_batch_size=executor_config.max_batch_size, - max_beam_width=executor_config.max_beam_width, + max_batch_size=max_batch_size, + max_beam_width=max_beam_width, max_draft_len=spec_config.max_draft_len if spec_config is not None else 0, kv_cache_transceiver=kv_cache_transceiver, @@ -664,7 +666,8 @@ def create_py_executor_instance( start_worker=start_worker, garbage_collection_gen0_threshold=garbage_collection_gen0_threshold, kv_connector_manager=kv_connector_manager, - max_seq_len=max_seq_len) + max_seq_len=max_seq_len, + peft_cache_config=peft_cache_config) def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping, diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index cc0159643b..4070f210e1 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -25,8 +25,8 @@ from tensorrt_llm._utils import (customized_gc_thresholds, global_mpi_rank, from tensorrt_llm.bindings.executor import (DisServingRequestStats, FinishReason, InflightBatchingStats, IterationStats, KvCacheStats, - RequestStage, RequestStats, - SpecDecodingStats, + PeftCacheConfig, RequestStage, + RequestStats, SpecDecodingStats, StaticBatchingStats) from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType, ReqIdsSet) @@ -157,11 +157,14 @@ class PyExecutor: garbage_collection_gen0_threshold: Optional[int] = None, start_worker: bool = True, kv_connector_manager: Optional[KvCacheConnectorManager] = None, - max_seq_len: Optional[int] = None): + max_seq_len: Optional[int] = None, + peft_cache_config: Optional[PeftCacheConfig] = None): super(PyExecutor, self).__init__() self.device_id = torch.cuda.current_device() self.global_rank = global_mpi_rank() + self.peft_cache_config = peft_cache_config + # profile config self.profile_start_iters, self.profile_stop_iters = _load_iteration_indexes( PROFILE_START_STOP_ENV_VAR_NAME) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index d8c76a35b0..e5f17640fb 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -226,7 +226,7 @@ def create_py_executor( mapping = _get_mapping(executor_config) dist = MPIDist(mapping=mapping) - + cache_transceiver_config = executor_config.cache_transceiver_config spec_config = executor_config.speculative_config has_draft_model_engine = False has_spec_drafter = False @@ -508,7 +508,6 @@ def create_py_executor( 
resources=resources, mapping=mapping, pytorch_backend_config=pytorch_backend_config, - executor_config=executor_config, ctx_chunk_config=ctx_chunk_config, model_engine=model_engine, start_worker=False, @@ -520,7 +519,16 @@ def create_py_executor( kv_connector_manager=kv_connector_manager if not estimating_kv_cache else None, max_seq_len=executor_config.max_seq_len, + max_batch_size=executor_config.max_batch_size, + max_beam_width=executor_config.max_beam_width, + max_num_tokens=executor_config.max_num_tokens, + peft_cache_config=executor_config.peft_cache_config, + scheduler_config=executor_config.scheduler_config, + cache_transceiver_config=cache_transceiver_config, ) + # Modify the executor_config.peft_cache_config which might be mutated + # inside create_py_executor_instance + executor_config.peft_cache_config = py_executor.peft_cache_config if estimating_kv_cache: assert kv_cache_creator is not None @@ -553,7 +561,6 @@ def create_py_executor( resources=resources, mapping=mapping, pytorch_backend_config=pytorch_backend_config, - executor_config=executor_config, ctx_chunk_config=ctx_chunk_config, model_engine=model_engine, start_worker=False, @@ -565,6 +572,12 @@ def create_py_executor( garbage_collection_gen0_threshold, kv_connector_manager=kv_connector_manager, max_seq_len=executor_config.max_seq_len, + max_batch_size=executor_config.max_batch_size, + max_beam_width=executor_config.max_beam_width, + max_num_tokens=executor_config.max_num_tokens, + peft_cache_config=executor_config.peft_cache_config, + scheduler_config=executor_config.scheduler_config, + cache_transceiver_config=cache_transceiver_config, ) _adjust_torch_mem_fraction(executor_config.pytorch_backend_config) diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 02701a066f..359b1fbd22 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -353,3 +353,5 @@ test_e2e.py::test_multi_nodes_eval[llama4-models/nvidia/Llama-4-Maverick-17B-128 accuracy/test_cli_flow.py::TestLlama3_8BInstructGradient1048k::test_long_context_ppl SKIP (https://nvbugs/5413362) stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-GUARANTEED_NO_EVICT-pytorch-stress-test] SKIP (https://nvbugs/5474169) test_e2e.py::test_trtllm_bench_iteration_log[TRT-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B] SKIP (https://nvbugs/5448523) +cpp/test_unit_tests.py::test_unit_tests[kernels-80] SKIP (https://nvbugs/5504078) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=vanilla-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5504086) diff --git a/tests/unittest/_torch/attention/test_attention_mla.py b/tests/unittest/_torch/attention/test_attention_mla.py index 72bfc78b9d..ec445acb64 100644 --- a/tests/unittest/_torch/attention/test_attention_mla.py +++ b/tests/unittest/_torch/attention/test_attention_mla.py @@ -389,11 +389,6 @@ def test_attention_mla(scenario: Scenario, context_sequence_lengths: List[int], dtype = scenario.dtype kv_cache_dtype = scenario.kv_cache_dtype - FAILED_CSL = [777, 912, 431, 42, 266, 989, 524] - if (kv_cache_dtype is torch.float8_e4m3fn - and context_sequence_lengths == FAILED_CSL): - pytest.skip("https://nvbugs/5453806") - print( f"--------------------------------Test for scenario: {scenario} start--------------------------------" ) diff --git 
a/tests/unittest/_torch/speculative/test_eagle3.py b/tests/unittest/_torch/speculative/test_eagle3.py index 8de0ac8642..038cea3c4f 100644 --- a/tests/unittest/_torch/speculative/test_eagle3.py +++ b/tests/unittest/_torch/speculative/test_eagle3.py @@ -16,7 +16,6 @@ from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig, sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -@pytest.mark.skip(reason="https://nvbugs/5461761") @pytest.mark.parametrize( "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill", [ diff --git a/tests/unittest/llmapi/test_llm.py b/tests/unittest/llmapi/test_llm.py index 4f7488205a..aa39e9acfa 100644 --- a/tests/unittest/llmapi/test_llm.py +++ b/tests/unittest/llmapi/test_llm.py @@ -594,6 +594,7 @@ def llm_for_sampling_params(): llm.shutdown() +@pytest.mark.skip(reason="https://nvbugs/5504095") @pytest.mark.part0 def test_user_specify_workspace(): user_specified_ws_path = '/tmp/specified_workspace'
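Taken together, the model-side changes mean a new architecture only has to subclass QKNormRoPEAttention and pick its flags instead of re-implementing apply_qk_norm/apply_rope. Below is a minimal sketch of that pattern, assuming the constructor signature added in qk_norm_attention.py; the class name MyQKNormAttention, the rope_gpt_neox embedding type, and the config fields are illustrative assumptions, not part of this patch.

from typing import Optional

from tensorrt_llm._torch.attention_backend.interface import (
    PositionalEmbeddingParams, RopeParams)
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.modules.qk_norm_attention import QKNormRoPEAttention
from tensorrt_llm.functional import PositionEmbeddingType


class MyQKNormAttention(QKNormRoPEAttention):
    """Hypothetical attention layer reusing the shared QK-norm + RoPE path."""

    def __init__(self,
                 model_config: ModelConfig,
                 layer_idx: Optional[int] = None):
        config = model_config.pretrained_config
        pos_embd_params = PositionalEmbeddingParams(
            type=PositionEmbeddingType.rope_gpt_neox,
            rope=RopeParams.from_config(config),
        )
        super().__init__(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            max_position_embeddings=config.max_position_embeddings,
            bias=False,
            pos_embd_params=pos_embd_params,
            # Unfused path: RMSNorm on Q/K first, then RoPE inside the
            # attention op. Pass True to use the fused_qk_norm_rope kernel.
            fuse_qk_norm_rope=False,
            layer_idx=layer_idx,
            dtype=config.torch_dtype,
            config=model_config,
        )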