diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py
index 8c39a14390..531dada135 100644
--- a/tensorrt_llm/_torch/attention_backend/interface.py
+++ b/tensorrt_llm/_torch/attention_backend/interface.py
@@ -391,8 +391,7 @@ class RopeParams:
             )
 
         if self.scale_type == RotaryScalingType.yarn:
-            rope_inv_freq = None
-            _, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
+            rope_inv_freq, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
                 self.max_positions,
                 self.dim,
                 self.theta,
diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py
index deb3d019b2..6c326b2c0a 100644
--- a/tensorrt_llm/_torch/attention_backend/trtllm.py
+++ b/tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -110,7 +110,7 @@ class TrtllmAttentionWrapper:
         self.qk_rope_head_dim = None
         self.v_head_dim = None
 
-        self.rotary_inv_freq, self.rotary_cos_sin = rope_params.create_rope_const_params(
+        self.rotary_inv_freq, self.rotary_cos_sin = self.rope_params.create_rope_const_params(
         )
         self.num_heads = num_heads
diff --git a/tensorrt_llm/_torch/models/modeling_nemotron_nas.py b/tensorrt_llm/_torch/models/modeling_nemotron_nas.py
index 333f52532a..146d13f16f 100644
--- a/tensorrt_llm/_torch/models/modeling_nemotron_nas.py
+++ b/tensorrt_llm/_torch/models/modeling_nemotron_nas.py
@@ -4,7 +4,7 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig
 
-from tensorrt_llm.functional import PositionEmbeddingType
+from tensorrt_llm.functional import PositionEmbeddingType, RotaryScalingType
 from tensorrt_llm.lora_manager import HfLoraLoader
 from tensorrt_llm.models.convert_utils import split_matrix_tp
 
@@ -48,10 +48,18 @@ def _create_linear_from_configs(model_config: ModelConfig[PretrainedConfig],
 
 
 class NemotronNASAttention(Attention):
+    NON_NEOX_TYPES = ("mistral_yarn", "rope_llama4")
 
     def __init__(self, model_config: ModelConfig[PretrainedConfig],
                  layer_idx: int):
         config = model_config.pretrained_config
+        is_neox = getattr(model_config.pretrained_config,
+                          "position_embedding_type",
+                          None) not in self.NON_NEOX_TYPES
+        rope = RopeParams.from_config(config)
+        if rope.scale_type == RotaryScalingType.yarn:
+            rope.mscale_all_dim = 0.0
+
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -59,8 +67,9 @@ class NemotronNASAttention(Attention):
             max_position_embeddings=config.max_position_embeddings,
             bias=False,
             pos_embd_params=PositionalEmbeddingParams(
-                type=PositionEmbeddingType.rope_gpt_neox,
-                rope=RopeParams.from_config(config),
+                type=PositionEmbeddingType.rope_gpt_neox
+                if is_neox else PositionEmbeddingType.rope_gptj,
+                rope=rope,
             ),
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
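
For reference, a minimal standalone sketch of the position-embedding selection that the `modeling_nemotron_nas.py` change introduces: NeoX-style RoPE unless the config's `position_embedding_type` is one of the listed non-NeoX types, plus the yarn-specific `mscale_all_dim` override. The helper name `select_nemotron_nas_rope` and its use outside the `NemotronNASAttention` constructor are illustrative assumptions, not part of the patch.

```python
# Sketch only: mirrors the selection logic added to NemotronNASAttention.__init__.
# The helper itself is hypothetical and not part of the patch.
from tensorrt_llm._torch.attention_backend.interface import RopeParams
from tensorrt_llm.functional import PositionEmbeddingType, RotaryScalingType

# position_embedding_type values that should fall back to GPT-J-style RoPE
NON_NEOX_TYPES = ("mistral_yarn", "rope_llama4")


def select_nemotron_nas_rope(pretrained_config):
    """Return (PositionEmbeddingType, RopeParams) the way the patched attention does."""
    is_neox = getattr(pretrained_config, "position_embedding_type",
                      None) not in NON_NEOX_TYPES
    rope = RopeParams.from_config(pretrained_config)
    if rope.scale_type == RotaryScalingType.yarn:
        # As in the patch: zero out mscale_all_dim when yarn scaling is used.
        rope.mscale_all_dim = 0.0
    embd_type = (PositionEmbeddingType.rope_gpt_neox
                 if is_neox else PositionEmbeddingType.rope_gptj)
    return embd_type, rope
```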