feat: Add support for YARN in NemotronNAS models (#4906)

Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
Author:  amirkl94 (committed by GitHub)
Date:    2025-06-29 09:45:49 +03:00
Parent:  a985c0b7e6
Commit:  de9779900c
3 changed files with 14 additions and 6 deletions


@@ -391,8 +391,7 @@ class RopeParams:
         )
         if self.scale_type == RotaryScalingType.yarn:
-            rope_inv_freq = None
-            _, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
+            rope_inv_freq, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
                 self.max_positions,
                 self.dim,
                 self.theta,
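Note: with this change, `create_sinusoidal_positions_yarn` is expected to hand back the YaRN inverse frequencies as well as the cos/sin cache instead of leaving `rope_inv_freq` as `None`. For readers unfamiliar with what that cache contains, below is a minimal, self-contained sketch of a YaRN-style table builder. It is illustrative only: the `yarn_rope_tables` helper, the simplified linear ramp, and the `mscale` handling are assumptions for exposition, not the RopeEmbeddingUtils implementation.

# Illustrative sketch only -- simplified YaRN-style rotary tables.
import numpy as np

def yarn_rope_tables(max_positions, dim, theta, scale, mscale=1.0):
    # Base RoPE inverse frequencies: 1 / theta^(2i / dim).
    base_inv_freq = 1.0 / (theta ** (np.arange(0, dim, 2) / dim))
    # Position-interpolated frequencies for the extended context window.
    scaled_inv_freq = base_inv_freq / scale
    # Blend the two per dimension; real YaRN derives this ramp from
    # wavelength thresholds, a plain linear ramp is used here for brevity.
    ramp = np.linspace(0.0, 1.0, dim // 2)
    inv_freq = base_inv_freq * (1.0 - ramp) + scaled_inv_freq * ramp
    # cos/sin cache for every position, optionally magnitude-scaled.
    angles = np.arange(max_positions)[:, None] * inv_freq[None, :]
    cos_sin = np.stack([np.cos(angles), np.sin(angles)], axis=-1) * mscale
    return inv_freq, cos_sin

inv_freq, cos_sin = yarn_rope_tables(max_positions=4096, dim=128,
                                     theta=10000.0, scale=4.0)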


@@ -110,7 +110,7 @@ class TrtllmAttentionWrapper:
         self.qk_rope_head_dim = None
         self.v_head_dim = None
-        self.rotary_inv_freq, self.rotary_cos_sin = rope_params.create_rope_const_params(
+        self.rotary_inv_freq, self.rotary_cos_sin = self.rope_params.create_rope_const_params(
         )
         self.num_heads = num_heads
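The wrapper-side change is bookkeeping: the RoPE constants are now built from the parameters kept on the wrapper (`self.rope_params`) rather than from a local variable. A minimal sketch of that pattern follows; only `create_rope_const_params` comes from the diff, every other name here (FakeRopeParams, AttentionWrapperSketch) is a hypothetical stand-in, not a TensorRT-LLM class.

# Sketch of the caching pattern; stand-in classes, not TensorRT-LLM code.
from dataclasses import dataclass
import numpy as np

@dataclass
class FakeRopeParams:
    max_positions: int = 16
    dim: int = 8

    def create_rope_const_params(self):
        # Plain (non-YaRN) RoPE tables, just enough to exercise the pattern.
        inv_freq = 1.0 / (10000.0 ** (np.arange(0, self.dim, 2) / self.dim))
        angles = np.arange(self.max_positions)[:, None] * inv_freq[None, :]
        return inv_freq, np.stack([np.cos(angles), np.sin(angles)], axis=-1)

class AttentionWrapperSketch:
    def __init__(self, num_heads, rope_params):
        # Keep the parameters on the wrapper so the constants are always
        # regenerated from the same source of truth.
        self.rope_params = rope_params
        self.rotary_inv_freq, self.rotary_cos_sin = \
            self.rope_params.create_rope_const_params()
        self.num_heads = num_heads

wrapper = AttentionWrapperSketch(num_heads=8, rope_params=FakeRopeParams())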


@@ -4,7 +4,7 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig
-from tensorrt_llm.functional import PositionEmbeddingType
+from tensorrt_llm.functional import PositionEmbeddingType, RotaryScalingType
 from tensorrt_llm.lora_manager import HfLoraLoader
 from tensorrt_llm.models.convert_utils import split_matrix_tp
@@ -48,10 +48,18 @@ def _create_linear_from_configs(model_config: ModelConfig[PretrainedConfig],
 class NemotronNASAttention(Attention):
+    NON_NEOX_TYPES = ("mistral_yarn", "rope_llama4")
 
     def __init__(self, model_config: ModelConfig[PretrainedConfig],
                  layer_idx: int):
         config = model_config.pretrained_config
+        is_neox = getattr(model_config.pretrained_config,
+                          "position_embedding_type",
+                          None) not in self.NON_NEOX_TYPES
+        rope = RopeParams.from_config(config)
+        if rope.scale_type == RotaryScalingType.yarn:
+            rope.mscale_all_dim = 0.0
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -59,8 +67,9 @@ class NemotronNASAttention(Attention):
             max_position_embeddings=config.max_position_embeddings,
             bias=False,
             pos_embd_params=PositionalEmbeddingParams(
-                type=PositionEmbeddingType.rope_gpt_neox,
-                rope=RopeParams.from_config(config),
+                type=PositionEmbeddingType.rope_gpt_neox
+                if is_neox else PositionEmbeddingType.rope_gptj,
+                rope=rope,
             ),
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
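The `is_neox` switch matters because the two position-embedding layouts pair head dimensions differently: NeoX-style RoPE rotates the first half of each head against the second half, while GPT-J-style RoPE rotates interleaved even/odd pairs, which is why the types listed in `NON_NEOX_TYPES` are routed to `rope_gptj` above. A small standalone comparison of the two application styles (PyTorch; illustrative, not TensorRT-LLM kernel code):

# Illustrative comparison of the two rotary layouts selected by is_neox.
import torch

def rotate_neox(x, cos, sin):
    # x: [..., head_dim]; cos/sin: [..., head_dim // 2]
    x1, x2 = x.chunk(2, dim=-1)                 # first half vs. second half
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

def rotate_gptj(x, cos, sin):
    x1, x2 = x[..., 0::2], x[..., 1::2]         # interleaved even/odd pairs
    out = torch.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
    return out.flatten(-2)

x = torch.randn(1, 8)                           # one token, head_dim = 8
cos, sin = torch.ones(4), torch.zeros(4)        # identity rotation
assert torch.allclose(rotate_neox(x, cos, sin), x)
assert torch.allclose(rotate_gptj(x, cos, sin), x)

Using the layout the checkpoint was trained with is exactly what the position_embedding_type check above selects.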