Merge branch 'user/sm103_trtllmgen' into feat/b300_cu13

Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
Xiwen Yu 2025-09-06 00:49:23 +08:00
commit 5e7aa76bb4
11 changed files with 300 additions and 300 deletions

View File

@@ -3,18 +3,17 @@ from typing import Optional, Tuple
import torch
from torch import nn
from tensorrt_llm._torch.modules.qk_norm_attention import QKNormRoPEAttention
from tensorrt_llm.functional import PositionEmbeddingType
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import (PositionalEmbeddingParams,
PredefinedAttentionMask, RopeParams)
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import TensorParallelMode
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from ..speculative import SpecMetadata
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
@@ -50,12 +49,11 @@ def check_is_sliding(config: Exaone4Config, layer_idx: int) -> bool:
return False
class Exaone4Attention(Attention):
class Exaone4Attention(QKNormRoPEAttention):
def __init__(self,
model_config: ModelConfig[Exaone4Config],
layer_idx: Optional[int] = None,
aux_stream: Optional[torch.cuda.Stream] = None,
fuse_qk_norm_rope: bool = False):
config = model_config.pretrained_config
@@ -73,10 +71,10 @@ class Exaone4Attention(Attention):
rope=RopeParams.from_config(config),
)
self.fuse_qk_norm_rope = (self.is_sliding and fuse_qk_norm_rope)
fuse_qk_norm_rope = (self.is_sliding and fuse_qk_norm_rope)
# TODO: Fusing qk norm with rope has an issue that slightly hurts accuracy.
assert self.fuse_qk_norm_rope is False, "Fusing qk norm and rope currently has an accuracy issue"
assert fuse_qk_norm_rope is False, "Fusing qk norm and rope currently has an accuracy issue"
super().__init__(
hidden_size=config.hidden_size,
@@ -85,65 +83,13 @@ class Exaone4Attention(Attention):
max_position_embeddings=config.max_position_embeddings,
bias=False,
pos_embd_params=pos_embd_params,
rope_fusion=False,
fuse_qk_norm_rope=fuse_qk_norm_rope,
skip_rope=self.sliding_window and not self.is_sliding,
layer_idx=layer_idx,
dtype=config.torch_dtype,
config=model_config,
)
self.q_norm = RMSNorm(hidden_size=self.head_dim,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.k_norm = RMSNorm(hidden_size=self.head_dim,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.aux_stream = aux_stream
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
def apply_qk_norm(self, q, k):
def q_l2norm():
return self.q_norm(q.reshape(-1, self.head_dim)).reshape(
-1, self.q_size)
def k_l2norm():
return self.k_norm(k.reshape(-1, self.head_dim)).reshape(
-1, self.kv_size)
q, k = maybe_execute_in_parallel(
q_l2norm,
k_l2norm,
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)
return q, k
def apply_qk_norm_rope(self, qkv, position_ids):
torch.ops.trtllm.fused_qk_norm_rope(
qkv, self.num_heads, self.num_key_value_heads,
self.num_key_value_heads, self.head_dim,
self.q_norm.variance_epsilon, self.q_norm.weight,
self.k_norm.weight, self.pos_embd_params.rope.theta,
self.pos_embd_params.is_neox, position_ids.view(-1))
return qkv, None, None
def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
v: Optional[torch.Tensor], position_ids: torch.Tensor):
if self.fuse_qk_norm_rope:
assert k is None and v is None, "The input should be a concatenated qkv tensor to apply_qk_norm_rope"
qkv = q
return self.apply_qk_norm_rope(qkv, position_ids)
q, k, v = self.split_qkv(q, k, v)
q, k = self.apply_qk_norm(q, k)
if self.sliding_window is None or self.is_sliding:
return super().apply_rope(q, k, v, position_ids)
else:
return q, k, v
def forward(
self,
position_ids: Optional[torch.LongTensor],
@@ -175,7 +121,6 @@ class Exaone4DecoderLayer(DecoderLayer):
self,
model_config: ModelConfig[Exaone4Config],
layer_idx: int,
aux_stream: Optional[torch.cuda.Stream] = None,
):
super().__init__()
config = model_config.pretrained_config
@@ -186,7 +131,6 @@ class Exaone4DecoderLayer(DecoderLayer):
self.self_attn = Exaone4Attention(
model_config,
layer_idx=layer_idx,
aux_stream=aux_stream,
)
self.mlp = GatedMLP(
@@ -244,7 +188,6 @@ class Exaone4Model(DecoderModel):
super().__init__(model_config)
config = self.model_config.pretrained_config
self.num_hidden_layers = config.num_hidden_layers
self.aux_stream = torch.cuda.Stream()
self.embed_tokens = Embedding(
config.vocab_size,
config.hidden_size,
@@ -258,7 +201,6 @@ class Exaone4Model(DecoderModel):
Exaone4DecoderLayer(
model_config,
layer_idx,
self.aux_stream,
) for layer_idx in range(self.num_hidden_layers)
])
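The net effect of the Exaone4 hunks above is that QK norm, the aux stream, and RoPE skipping all move into the shared base class, driven by the skip_rope flag. A minimal sketch of how that flag behaves (hypothetical layer settings; the real local/global pattern comes from check_is_sliding):

def uses_rope(sliding_window, is_sliding):
    # Mirrors skip_rope=self.sliding_window and not self.is_sliding above:
    # in a hybrid model, global-attention layers skip RoPE, sliding layers keep it.
    skip_rope = bool(sliding_window) and not is_sliding
    return not skip_rope

assert uses_rope(sliding_window=4096, is_sliding=True)       # local layer: RoPE applied
assert not uses_rope(sliding_window=4096, is_sliding=False)  # global layer: RoPE skipped
assert uses_rope(sliding_window=None, is_sliding=False)      # non-hybrid model: RoPE everywhere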

View File

@@ -7,6 +7,7 @@ from transformers import Gemma3TextConfig
from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
BaseWeightMapper
from tensorrt_llm._torch.modules.qk_norm_attention import QKNormRoPEAttention
from tensorrt_llm.functional import PositionEmbeddingType, RotaryScalingType
from tensorrt_llm.mapping import Mapping
@@ -15,12 +16,10 @@ from ..attention_backend.interface import (AttentionMask, CustomAttentionMask,
PositionalEmbeddingParams,
PredefinedAttentionMask, RopeParams)
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import TensorParallelMode
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
register_auto_model)
@@ -52,7 +51,7 @@ class Gemma3TextScaledWordEmbedding(Embedding):
return super().forward(input_ids) * self.embed_scale
class Gemma3Attention(Attention):
class Gemma3Attention(QKNormRoPEAttention):
def __init__(
self,
@@ -82,20 +81,13 @@ class Gemma3Attention(Attention):
max_position_embeddings=config.max_position_embeddings,
bias=False,
pos_embd_params=pos_embd_params,
fuse_qk_norm_rope=False,
layer_idx=layer_idx,
dtype=config.torch_dtype,
dense_bias=False,
config=model_config,
q_scaling=q_scaling,
)
self.q_norm = RMSNorm(hidden_size=config.head_dim,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.k_norm = RMSNorm(hidden_size=config.head_dim,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.aux_stream = torch.cuda.Stream()
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
@torch.inference_mode()
def forward(
@@ -121,33 +113,6 @@ class Gemma3Attention(Attention):
attention_mask_data=attention_mask_data,
**kwargs)
def apply_qk_norm(self, q, k):
def q_l2norm():
return self.q_norm(q.reshape(-1, self.head_dim)).reshape(
-1, self.q_size)
def k_l2norm():
return self.k_norm(k.reshape(-1, self.head_dim)).reshape(
-1, self.kv_size)
q, k = maybe_execute_in_parallel(
q_l2norm,
k_l2norm,
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)
return q, k
def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
v: Optional[torch.Tensor], position_ids: torch.Tensor):
# Gemma3 applies QK norm before RoPE.
q, k, v = self.split_qkv(q, k, v)
q, k = self.apply_qk_norm(q, k)
return super().apply_rope(q, k, v, position_ids)
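The apply_qk_norm helper removed here (and re-homed in QKNormRoPEAttention) relies on a reshape trick: q and k arrive flattened as (tokens, q_size) and (tokens, kv_size) tensors and are viewed as one row per head, so RMSNorm normalizes each head independently. A standalone sketch of that shape handling, with hypothetical sizes and a plain-PyTorch norm standing in for the TRT-LLM RMSNorm module:

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Per-row RMS normalization over the last dimension (head_dim).
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

num_tokens, num_heads, head_dim = 4, 8, 64
q_size = num_heads * head_dim
q = torch.randn(num_tokens, q_size)
weight = torch.ones(head_dim)

# q.reshape(-1, head_dim) flattens (tokens, heads) into rows, so every head is
# normalized on its own; the trailing reshape restores the (tokens, q_size) view.
q_normed = rms_norm(q.reshape(-1, head_dim), weight).reshape(-1, q_size)
assert q_normed.shape == q.shape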
# This function is written to be compatible with TRTLLM's GatedMLP class.
def pytorch_gelu_tanh(gate_x: torch.Tensor) -> torch.Tensor:

View File

@@ -1,133 +1,26 @@
import math
from typing import Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig, Qwen3Config
from transformers import Qwen3Config
from tensorrt_llm.functional import PositionEmbeddingType
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import TensorParallelMode
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.qk_norm_attention import QKNormRoPEAttention
from ..modules.rms_norm import RMSNorm
from ..speculative import SpecMetadata
from .modeling_speculative import SpecDecOneEngineForCausalLM
from .modeling_utils import DecoderModel, register_auto_model
# Move out from this class
def compute_yarn_parameters(
config: PretrainedConfig, ) -> tuple[float, float, float, float]:
"""
Refer to https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L197C1-L288C1
Computes the inverse frequencies with NTK scaling. Please refer to the
[original paper](https://huggingface.co/papers/2309.00071)
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
Returns:
factor: float, the scaling factor for the RoPE embeddings
low: float, the lower bound of the dimension range
high: float, the upper bound of the dimension range
attention_factor: float, the post-processing scaling factor applied to the computed cos/sin
"""
# The config does not contain rope_scaling, which means the model is not using yarn
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is None:
return 1.0, 0, 0, 1.0
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(
config, "partial_rotary_factor") else 1.0
head_dim = getattr(config, "head_dim",
config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
factor = getattr(rope_scaling, "factor", 1.0)
attention_factor = rope_scaling.get("attention_factor")
mscale = rope_scaling.get("mscale")
mscale_all_dim = rope_scaling.get("mscale_all_dim")
if "original_max_position_embeddings" in rope_scaling:
original_max_position_embeddings = rope_scaling[
"original_max_position_embeddings"]
factor = config.max_position_embeddings / original_max_position_embeddings
else:
original_max_position_embeddings = config.max_position_embeddings
def get_mscale(scale, mscale=1):
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
# Sets the attention factor as suggested in the paper
if attention_factor is None:
if mscale and mscale_all_dim:
attention_factor = float(
get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
else:
attention_factor = get_mscale(factor)
# Optional config options
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
beta_fast = rope_scaling.get("beta_fast") or 32
beta_slow = rope_scaling.get("beta_slow") or 1
# Compute the inverse frequencies
def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
"""Inverse dimension formula to find the dimension based on the number of rotations"""
return (dim *
math.log(max_position_embeddings /
(num_rotations * 2 * math.pi))) / (2 * math.log(base))
def find_correction_range(low_rot, high_rot, dim, base,
max_position_embeddings, truncate):
"""Find dimension range bounds based on rotations"""
low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
if truncate:
low = math.floor(low)
high = math.ceil(high)
return max(low, 0), min(high, dim - 1)
truncate = rope_scaling.get("truncate", True)
low, high = find_correction_range(beta_fast, beta_slow, dim, base,
original_max_position_embeddings,
truncate)
# These parts are implemented in the fusedQKNormRopeKernel.cu
# # def linear_ramp_factor(min, max, dim):
# # if min == max:
# # max += 0.001 # Prevent singularity
# # linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
# # ramp_func = torch.clamp(linear_func, 0, 1)
# # return ramp_func
# # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
# # to expand the possible context length. In other words, interpolation = apply scaling factor.
# # pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim)
# # inv_freq_extrapolation = 1.0 / pos_freqs
# # inv_freq_interpolation = 1.0 / (factor * pos_freqs)
# # # Get n-dimensional rotational scaling corrected for extrapolation
# # inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float)
# # inv_freq = (
# # inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
# # + inv_freq_extrapolation * inv_freq_extrapolation_factor
# # )
# # return inv_freq, attention_factor
return factor, low, high, attention_factor
class Qwen3Attention(Attention):
class Qwen3Attention(QKNormRoPEAttention):
def __init__(
self,
@@ -136,7 +29,6 @@ class Qwen3Attention(Attention):
fuse_qk_norm_rope: bool = True,
):
config = model_config.pretrained_config
self.pretrained_config = config
if getattr(config, "rope_scaling", None) is not None:
if "type" in config.rope_scaling:
@@ -156,8 +48,6 @@ class Qwen3Attention(Attention):
rope=RopeParams.from_config(config),
)
self.fuse_qk_norm_rope = fuse_qk_norm_rope
super().__init__(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
@@ -165,69 +55,13 @@ class Qwen3Attention(Attention):
max_position_embeddings=config.max_position_embeddings,
bias=config.attention_bias,
pos_embd_params=pos_embd_params,
rope_fusion=not self.fuse_qk_norm_rope,  # If fuse_qk_norm_rope is true, do not apply fused RoPE in attention OP, and self.rotary_emb will be skipped in the overridden apply_rope.
fuse_qk_norm_rope=fuse_qk_norm_rope,
layer_idx=layer_idx,
dtype=config.torch_dtype,
dense_bias=config.attention_bias,
config=model_config,
)
self.q_norm = RMSNorm(hidden_size=self.head_dim,
eps=1e-6,
dtype=config.torch_dtype,
has_weights=True)
self.k_norm = RMSNorm(hidden_size=self.head_dim,
eps=1e-6,
dtype=config.torch_dtype,
has_weights=True)
self.aux_stream = torch.cuda.Stream()
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
def apply_qk_norm(self, q, k):
def q_l2norm():
return self.q_norm(q.reshape(-1, self.head_dim)).reshape(
-1, self.q_size)
def k_l2norm():
return self.k_norm(k.reshape(-1, self.head_dim)).reshape(
-1, self.kv_size)
q, k = maybe_execute_in_parallel(
q_l2norm,
k_l2norm,
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)
return q, k
def apply_qk_norm_rope(self, qkv, position_ids):
factor, low, high, attention_factor = compute_yarn_parameters(
self.pretrained_config)
torch.ops.trtllm.fused_qk_norm_rope(
qkv, self.num_heads, self.num_key_value_heads,
self.num_key_value_heads, self.head_dim,
self.q_norm.variance_epsilon, self.q_norm.weight,
self.k_norm.weight,
self.pos_embd_params.rope.theta, self.pos_embd_params.is_neox,
position_ids.view(-1), factor, low, high, attention_factor)
return qkv, None, None
def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
v: Optional[torch.Tensor], position_ids: torch.Tensor):
# Qwen3 applies QK norm before RoPE.
if not self.fuse_qk_norm_rope:
q, k, v = self.split_qkv(q, k, v)
q, k = self.apply_qk_norm(q, k)
return super().apply_rope(q, k, v, position_ids)
assert k is None and v is None, "The input should be a concatenated qkv tensor to apply_qk_norm_rope"
qkv = q
return self.apply_qk_norm_rope(qkv, position_ids)
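Since Qwen3 keeps fuse_qk_norm_rope=True, the overridden apply_rope above never splits q/k/v: it forwards the packed qkv projection output to the fused kernel. A sizing sketch with hypothetical Qwen3-like dimensions of what "concatenated qkv tensor" means here (the per-token q|k|v packing is an assumption inferred from the assert above, not a spec of the kernel):

num_heads, num_kv_heads, head_dim = 32, 8, 128
q_size = num_heads * head_dim        # 4096
kv_size = num_kv_heads * head_dim    # 1024
qkv_width = q_size + 2 * kv_size     # 6144 elements per token: q | k | v
# The fused op then normalizes the q and k slices per head and applies RoPE in
# place, mirroring the call in apply_qk_norm_rope above:
# torch.ops.trtllm.fused_qk_norm_rope(qkv, num_heads, num_kv_heads, num_kv_heads,
#                                     head_dim, eps, q_norm_weight, k_norm_weight,
#                                     rope_theta, is_neox, position_ids.view(-1),
#                                     factor, low, high, attention_factor)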
class Qwen3DecoderLayer(DecoderLayer):
@@ -235,7 +69,7 @@ class Qwen3DecoderLayer(DecoderLayer):
self,
model_config: ModelConfig[Qwen3Config],
layer_idx: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
):
super().__init__()
self.layer_idx = layer_idx
config = model_config.pretrained_config

View File

@@ -0,0 +1,243 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
from transformers import PretrainedConfig
from ..attention_backend.interface import PositionalEmbeddingParams
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
# Move out from this class
def compute_yarn_parameters(
config: PretrainedConfig, ) -> tuple[float, float, float, float]:
"""
Refer to https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L197C1-L288C1
Computes the inverse frequencies with NTK scaling. Please refer to the
[original paper](https://huggingface.co/papers/2309.00071)
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
Returns:
factor: float, the scaling factor for the RoPE embeddings
low: float, the lower bound of the dimension range
high: float, the upper bound of the dimension range
attention_factor: float, the post-processing scaling factor applied to the computed cos/sin
"""
# If the config has no rope_scaling, or its rope_type is not "yarn", the model is not using YaRN
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is None or rope_scaling.get("rope_type") != "yarn":
return 1.0, 0, 0, 1.0
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(
config, "partial_rotary_factor") else 1.0
head_dim = getattr(config, "head_dim",
config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
factor = rope_scaling.get("factor", 1.0)
attention_factor = rope_scaling.get("attention_factor")
mscale = rope_scaling.get("mscale")
mscale_all_dim = rope_scaling.get("mscale_all_dim")
if "original_max_position_embeddings" in rope_scaling:
original_max_position_embeddings = rope_scaling[
"original_max_position_embeddings"]
factor = config.max_position_embeddings / original_max_position_embeddings
else:
original_max_position_embeddings = config.max_position_embeddings
def get_mscale(scale, mscale=1):
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
# Sets the attention factor as suggested in the paper
if attention_factor is None:
if mscale and mscale_all_dim:
attention_factor = float(
get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
else:
attention_factor = get_mscale(factor)
# Optional config options
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (respectively)
beta_fast = rope_scaling.get("beta_fast") or 32
beta_slow = rope_scaling.get("beta_slow") or 1
# Compute the inverse frequencies
def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
"""Inverse dimension formula to find the dimension based on the number of rotations"""
return (dim *
math.log(max_position_embeddings /
(num_rotations * 2 * math.pi))) / (2 * math.log(base))
def find_correction_range(low_rot, high_rot, dim, base,
max_position_embeddings, truncate):
"""Find dimension range bounds based on rotations"""
low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
if truncate:
low = math.floor(low)
high = math.ceil(high)
return max(low, 0), min(high, dim - 1)
truncate = rope_scaling.get("truncate", True)
low, high = find_correction_range(beta_fast, beta_slow, dim, base,
original_max_position_embeddings,
truncate)
# These parts are implemented in the fusedQKNormRopeKernel.cu
# # def linear_ramp_factor(min, max, dim):
# # if min == max:
# # max += 0.001 # Prevent singularity
# # linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
# # ramp_func = torch.clamp(linear_func, 0, 1)
# # return ramp_func
# # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
# # to expand the possible context length. In other words, interpolation = apply scaling factor.
# # pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim)
# # inv_freq_extrapolation = 1.0 / pos_freqs
# # inv_freq_interpolation = 1.0 / (factor * pos_freqs)
# # # Get n-dimensional rotational scaling corrected for extrapolation
# # inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float)
# # inv_freq = (
# # inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
# # + inv_freq_extrapolation * inv_freq_extrapolation_factor
# # )
# # return inv_freq, attention_factor
return factor, low, high, attention_factor
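To make the return values concrete, here is a small usage sketch with a hypothetical YaRN configuration (illustrative numbers, not from any shipped checkpoint): head_dim 128, rope_theta 1e6, and a 4x extension from 32768 to 131072 positions.

from types import SimpleNamespace

from tensorrt_llm._torch.modules.qk_norm_attention import compute_yarn_parameters

# Hypothetical config object; real callers pass the model's PretrainedConfig.
config = SimpleNamespace(
    rope_theta=1_000_000.0,
    hidden_size=4096,
    num_attention_heads=32,
    head_dim=128,
    max_position_embeddings=131072,
    rope_scaling={
        "rope_type": "yarn",
        "factor": 4.0,
        "original_max_position_embeddings": 32768,
    },
)

factor, low, high, attention_factor = compute_yarn_parameters(config)
# factor == 4.0 (131072 / 32768); low and high bracket the rotary dimensions to
# interpolate (roughly 23 and 40 with these numbers); attention_factor is about
# 1.14, i.e. 0.1 * ln(4) + 1 from get_mscale.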
class QKNormRoPEAttention(Attention):
"""
QKNormRoPEAttention applies RMS normalization to the query and key heads (QK norm) before the rotary position embedding.
It is shared by the Exaone4, Gemma3 and Qwen3 models.
It subclasses Attention and overrides apply_rope so that QK norm runs before RoPE, optionally fused with RoPE into a single kernel.
"""
def __init__(
self,
*,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
max_position_embeddings: int,
bias: bool,
pos_embd_params: Optional[PositionalEmbeddingParams] = None,
skip_rope: bool = False,
fuse_qk_norm_rope: bool = True,
layer_idx: Optional[int] = None,
dtype: torch.dtype = None,
dense_bias: Optional[bool] = None,
config: ModelConfig,
q_scaling: float = 1.0,
):
self.pretrained_config = config.pretrained_config
self.fuse_qk_norm_rope = fuse_qk_norm_rope
self.skip_rope = skip_rope
assert not (fuse_qk_norm_rope and skip_rope
), "Fusing qk norm and skipping rope is not supported"
super().__init__(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
max_position_embeddings=max_position_embeddings,
bias=bias,
pos_embd_params=pos_embd_params,
# If fuse_qk_norm_rope is true, do not apply fused RoPE in attention OP,
# and self.rotary_emb will be skipped in the overridden apply_rope.
rope_fusion=not self.fuse_qk_norm_rope and not skip_rope,
layer_idx=layer_idx,
dtype=dtype,
dense_bias=dense_bias,
config=config,
q_scaling=q_scaling,
)
self.q_norm = RMSNorm(hidden_size=self.head_dim,
eps=self.pretrained_config.rms_norm_eps,
dtype=self.pretrained_config.torch_dtype,
has_weights=True)
self.k_norm = RMSNorm(hidden_size=self.head_dim,
eps=self.pretrained_config.rms_norm_eps,
dtype=self.pretrained_config.torch_dtype,
has_weights=True)
self.aux_stream = torch.cuda.Stream()
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
def apply_qk_norm(self, q, k):
def q_l2norm():
return self.q_norm(q.reshape(-1, self.head_dim)).reshape(
-1, self.q_size)
def k_l2norm():
return self.k_norm(k.reshape(-1, self.head_dim)).reshape(
-1, self.kv_size)
q, k = maybe_execute_in_parallel(
q_l2norm,
k_l2norm,
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)
return q, k
def apply_qk_norm_rope(self, qkv, position_ids):
factor, low, high, attention_factor = compute_yarn_parameters(
self.pretrained_config)
torch.ops.trtllm.fused_qk_norm_rope(
qkv, self.num_heads, self.num_key_value_heads,
self.num_key_value_heads, self.head_dim,
self.q_norm.variance_epsilon, self.q_norm.weight,
self.k_norm.weight,
self.pos_embd_params.rope.theta, self.pos_embd_params.is_neox,
position_ids.view(-1), factor, low, high, attention_factor)
return qkv, None, None
def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
v: Optional[torch.Tensor], position_ids: torch.Tensor):
"""
apply_rope is called from Attention.forward. It is overridden here to apply QK norm before RoPE,
either through the fused qk-norm + RoPE kernel or through the unfused split/norm/RoPE path.
"""
# Apply QK norm before RoPE.
if not self.fuse_qk_norm_rope:
q, k, v = self.split_qkv(q, k, v)
q, k = self.apply_qk_norm(q, k)
if not self.skip_rope:
return super().apply_rope(q, k, v, position_ids)
else:
return q, k, v
assert k is None and v is None, "The input should be a concatenated qkv tensor to apply_qk_norm_rope"
qkv = q
return self.apply_qk_norm_rope(qkv, position_ids)
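Putting it together, the three model refactors in this commit reduce to choosing constructor flags on this base class. A minimal, hypothetical subclass sketch (names are illustrative; see the Exaone4, Gemma3 and Qwen3 hunks above for the real wiring):

from tensorrt_llm._torch.attention_backend.interface import (PositionalEmbeddingParams, RopeParams)
from tensorrt_llm._torch.modules.qk_norm_attention import QKNormRoPEAttention
from tensorrt_llm.functional import PositionEmbeddingType

class MyModelAttention(QKNormRoPEAttention):
    """Hypothetical model attention built on the new shared base class."""

    def __init__(self, model_config, layer_idx=None):
        config = model_config.pretrained_config
        super().__init__(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            max_position_embeddings=config.max_position_embeddings,
            bias=False,
            pos_embd_params=PositionalEmbeddingParams(
                type=PositionEmbeddingType.rope_gpt_neox,  # assumed; this field is elided in the hunks above
                rope=RopeParams.from_config(config),
            ),
            fuse_qk_norm_rope=True,  # False selects the unfused path: split qkv, per-head RMSNorm, then RoPE
            skip_rope=False,         # True applies QK norm only (e.g. Exaone4 global-attention layers)
            layer_idx=layer_idx,
            dtype=config.torch_dtype,
            config=model_config,
        )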

View File

@@ -504,7 +504,6 @@ def create_py_executor_instance(
resources,
mapping,
pytorch_backend_config,
executor_config,
ctx_chunk_config,
model_engine,
start_worker,
@@ -515,13 +514,19 @@ def create_py_executor_instance(
garbage_collection_gen0_threshold: Optional[int] = None,
kv_connector_manager: Optional[KvCacheConnectorManager] = None,
max_seq_len: Optional[int] = None,
max_batch_size: Optional[int] = None,
max_beam_width: Optional[int] = None,
max_num_tokens: Optional[int] = None,
peft_cache_config: Optional[trtllm.PeftCacheConfig] = None,
scheduler_config: Optional[trtllm.SchedulerConfig] = None,
cache_transceiver_config: Optional[trtllm.CacheTransceiverConfig] = None,
) -> PyExecutor:
kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)
spec_config = model_engine.spec_config
logger.info(
f"max_seq_len={executor_config.max_seq_len}, max_num_requests={executor_config.max_batch_size}, max_num_tokens={executor_config.max_num_tokens}, max_batch_size={executor_config.max_batch_size}"
f"max_seq_len={max_seq_len}, max_num_requests={max_batch_size}, max_num_tokens={max_num_tokens}, max_batch_size={max_batch_size}"
)
for key, value in pytorch_backend_config.extra_resource_managers.items():
@@ -578,16 +583,15 @@ def create_py_executor_instance(
len(lora_config.lora_target_modules + lora_config.missing_qkv_modules)
peft_cache_config_model = PeftCacheConfig.from_pybind(
executor_config.peft_cache_config
) if executor_config.peft_cache_config is not None else PeftCacheConfig(
)
peft_cache_config
) if peft_cache_config is not None else PeftCacheConfig()
if lora_config.max_loras is not None:
peft_cache_config_model.num_device_module_layer = \
max_lora_rank * num_lora_modules * lora_config.max_loras
if lora_config.max_cpu_loras is not None:
peft_cache_config_model.num_host_module_layer = \
max_lora_rank * num_lora_modules * lora_config.max_cpu_loras
executor_config.peft_cache_config = peft_cache_config_model._to_pybind()
peft_cache_config = peft_cache_config_model._to_pybind()
from tensorrt_llm.bindings import WorldConfig
world_config = WorldConfig(
@@ -598,7 +602,7 @@ def create_py_executor_instance(
gpus_per_node=dist.mapping.gpus_per_node,
)
peft_cache_manager = PeftCacheManager(
peft_cache_config=executor_config.peft_cache_config,
peft_cache_config=peft_cache_config,
lora_config=lora_config,
model_config=model_binding_config,
world_config=world_config,
@@ -609,7 +613,7 @@ def create_py_executor_instance(
lora_config.trtllm_modules_to_hf_modules,
lora_config.swap_gate_up_proj_lora_b_weight)
max_num_sequences = executor_config.max_batch_size * mapping.pp_size
max_num_sequences = max_batch_size * mapping.pp_size
resources[ResourceManagerType.SEQ_SLOT_MANAGER] = SeqSlotManager(
max_num_sequences)
@@ -632,17 +636,15 @@ def create_py_executor_instance(
scheduler_capacity,
kv_cache_manager.impl if kv_cache_manager is not None else None,
peft_cache_manager.impl if peft_cache_manager is not None else None,
executor_config.scheduler_config.capacity_scheduler_policy,
scheduler_config.capacity_scheduler_policy,
two_step_lookahead=mapping.has_pp())
mb_scheduler = BindMicroBatchScheduler(executor_config.max_batch_size,
executor_config.max_num_tokens,
mb_scheduler = BindMicroBatchScheduler(max_batch_size, max_num_tokens,
ctx_chunk_config)
scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler)
config = model_engine.model.model_config.pretrained_config
attention_type = AttentionTypeCpp.MLA if is_mla(
config) else AttentionTypeCpp.DEFAULT
cache_transceiver_config = executor_config.cache_transceiver_config
kv_cache_transceiver = create_kv_cache_transceiver(
mapping, kv_cache_manager, attention_type, cache_transceiver_config)
return PyExecutor(
@@ -655,8 +657,8 @@ def create_py_executor_instance(
max_num_sequences=max_num_sequences,
disable_overlap_scheduler=pytorch_backend_config.
disable_overlap_scheduler,
max_batch_size=executor_config.max_batch_size,
max_beam_width=executor_config.max_beam_width,
max_batch_size=max_batch_size,
max_beam_width=max_beam_width,
max_draft_len=spec_config.max_draft_len
if spec_config is not None else 0,
kv_cache_transceiver=kv_cache_transceiver,
@@ -664,7 +666,8 @@ def create_py_executor_instance(
start_worker=start_worker,
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
kv_connector_manager=kv_connector_manager,
max_seq_len=max_seq_len)
max_seq_len=max_seq_len,
peft_cache_config=peft_cache_config)
def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,

View File

@@ -25,8 +25,8 @@ from tensorrt_llm._utils import (customized_gc_thresholds, global_mpi_rank,
from tensorrt_llm.bindings.executor import (DisServingRequestStats,
FinishReason, InflightBatchingStats,
IterationStats, KvCacheStats,
RequestStage, RequestStats,
SpecDecodingStats,
PeftCacheConfig, RequestStage,
RequestStats, SpecDecodingStats,
StaticBatchingStats)
from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType,
ReqIdsSet)
@@ -157,11 +157,14 @@ class PyExecutor:
garbage_collection_gen0_threshold: Optional[int] = None,
start_worker: bool = True,
kv_connector_manager: Optional[KvCacheConnectorManager] = None,
max_seq_len: Optional[int] = None):
max_seq_len: Optional[int] = None,
peft_cache_config: Optional[PeftCacheConfig] = None):
super(PyExecutor, self).__init__()
self.device_id = torch.cuda.current_device()
self.global_rank = global_mpi_rank()
self.peft_cache_config = peft_cache_config
# profile config
self.profile_start_iters, self.profile_stop_iters = _load_iteration_indexes(
PROFILE_START_STOP_ENV_VAR_NAME)

View File

@@ -226,7 +226,7 @@ def create_py_executor(
mapping = _get_mapping(executor_config)
dist = MPIDist(mapping=mapping)
cache_transceiver_config = executor_config.cache_transceiver_config
spec_config = executor_config.speculative_config
has_draft_model_engine = False
has_spec_drafter = False
@@ -508,7 +508,6 @@ def create_py_executor(
resources=resources,
mapping=mapping,
pytorch_backend_config=pytorch_backend_config,
executor_config=executor_config,
ctx_chunk_config=ctx_chunk_config,
model_engine=model_engine,
start_worker=False,
@@ -520,7 +519,16 @@ def create_py_executor(
kv_connector_manager=kv_connector_manager
if not estimating_kv_cache else None,
max_seq_len=executor_config.max_seq_len,
max_batch_size=executor_config.max_batch_size,
max_beam_width=executor_config.max_beam_width,
max_num_tokens=executor_config.max_num_tokens,
peft_cache_config=executor_config.peft_cache_config,
scheduler_config=executor_config.scheduler_config,
cache_transceiver_config=cache_transceiver_config,
)
# Write back executor_config.peft_cache_config, which may have been mutated
# inside create_py_executor_instance
executor_config.peft_cache_config = py_executor.peft_cache_config
if estimating_kv_cache:
assert kv_cache_creator is not None
@@ -553,7 +561,6 @@ def create_py_executor(
resources=resources,
mapping=mapping,
pytorch_backend_config=pytorch_backend_config,
executor_config=executor_config,
ctx_chunk_config=ctx_chunk_config,
model_engine=model_engine,
start_worker=False,
@@ -565,6 +572,12 @@ def create_py_executor(
garbage_collection_gen0_threshold,
kv_connector_manager=kv_connector_manager,
max_seq_len=executor_config.max_seq_len,
max_batch_size=executor_config.max_batch_size,
max_beam_width=executor_config.max_beam_width,
max_num_tokens=executor_config.max_num_tokens,
peft_cache_config=executor_config.peft_cache_config,
scheduler_config=executor_config.scheduler_config,
cache_transceiver_config=cache_transceiver_config,
)
_adjust_torch_mem_fraction(executor_config.pytorch_backend_config)

View File

@@ -353,3 +353,5 @@ test_e2e.py::test_multi_nodes_eval[llama4-models/nvidia/Llama-4-Maverick-17B-128
accuracy/test_cli_flow.py::TestLlama3_8BInstructGradient1048k::test_long_context_ppl SKIP (https://nvbugs/5413362)
stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-GUARANTEED_NO_EVICT-pytorch-stress-test] SKIP (https://nvbugs/5474169)
test_e2e.py::test_trtllm_bench_iteration_log[TRT-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B] SKIP (https://nvbugs/5448523)
cpp/test_unit_tests.py::test_unit_tests[kernels-80] SKIP (https://nvbugs/5504078)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=vanilla-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5504086)

View File

@@ -389,11 +389,6 @@ def test_attention_mla(scenario: Scenario, context_sequence_lengths: List[int],
dtype = scenario.dtype
kv_cache_dtype = scenario.kv_cache_dtype
FAILED_CSL = [777, 912, 431, 42, 266, 989, 524]
if (kv_cache_dtype is torch.float8_e4m3fn
and context_sequence_lengths == FAILED_CSL):
pytest.skip("https://nvbugs/5453806")
print(
f"--------------------------------Test for scenario: {scenario} start--------------------------------"
)

View File

@@ -16,7 +16,6 @@ from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
@pytest.mark.skip(reason="https://nvbugs/5461761")
@pytest.mark.parametrize(
"use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill",
[

View File

@@ -594,6 +594,7 @@ def llm_for_sampling_params():
llm.shutdown()
@pytest.mark.skip(reason="https://nvbugs/5504095")
@pytest.mark.part0
def test_user_specify_workspace():
user_specified_ws_path = '/tmp/specified_workspace'