Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
feat: Add support for YARN in NemotronNAS models (#4906)
Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
parent a985c0b7e6
commit de9779900c
@@ -391,8 +391,7 @@ class RopeParams:
         )
 
         if self.scale_type == RotaryScalingType.yarn:
-            rope_inv_freq = None
-            _, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
+            rope_inv_freq, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
                 self.max_positions,
                 self.dim,
                 self.theta,
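Context for the hunk above: the YARN helper's inverse-frequency tensor is now kept instead of being discarded and forced to None, so the caller returns both tables. A minimal runnable sketch of the two tables involved, using a toy stand-in rather than RopeEmbeddingUtils (plain, unscaled RoPE only; the real YARN helper additionally rescales the frequencies and applies mscale corrections):

import torch

def toy_rope_tables(max_positions: int, dim: int, theta: float):
    # Toy stand-in for the real helper: returns (inv_freq, cos_sin), matching the
    # shape of the '+' line above. Plain RoPE math only; YARN's frequency
    # interpolation and mscale corrections are intentionally omitted.
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = torch.outer(torch.arange(max_positions, dtype=torch.float32), inv_freq)
    cos_sin = torch.stack((angles.cos(), angles.sin()), dim=-1)
    return inv_freq, cos_sin

# After the change both values are kept; previously inv_freq was forced to None.
rope_inv_freq, rope_cos_sin = toy_rope_tables(max_positions=4096, dim=128, theta=10000.0)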
@@ -110,7 +110,7 @@ class TrtllmAttentionWrapper:
         self.qk_rope_head_dim = None
         self.v_head_dim = None
 
-        self.rotary_inv_freq, self.rotary_cos_sin = rope_params.create_rope_const_params(
+        self.rotary_inv_freq, self.rotary_cos_sin = self.rope_params.create_rope_const_params(
         )
 
         self.num_heads = num_heads
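The one-line fix above re-spells the lookup through the attribute stored on the wrapper (self.rope_params) rather than a bare local name. A toy sketch of that pattern, with illustrative names only, since the diff does not show the enclosing scope of the real wrapper:

class ToyRopeParams:
    def create_rope_const_params(self):
        # Stand-in for the real table construction.
        return "inv_freq", "cos_sin"

class ToyAttentionWrapper:
    def __init__(self, rope_params):
        self.rope_params = rope_params
        # Reference the stored attribute, as in the '+' line above; a bare
        # `rope_params` only works while that local name is in scope.
        self.rotary_inv_freq, self.rotary_cos_sin = (
            self.rope_params.create_rope_const_params())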
@@ -4,7 +4,7 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig
 
-from tensorrt_llm.functional import PositionEmbeddingType
+from tensorrt_llm.functional import PositionEmbeddingType, RotaryScalingType
 from tensorrt_llm.lora_manager import HfLoraLoader
 from tensorrt_llm.models.convert_utils import split_matrix_tp
 
@@ -48,10 +48,18 @@ def _create_linear_from_configs(model_config: ModelConfig[PretrainedConfig],
 
 
 class NemotronNASAttention(Attention):
+    NON_NEOX_TYPES = ("mistral_yarn", "rope_llama4")
+
     def __init__(self, model_config: ModelConfig[PretrainedConfig],
                  layer_idx: int):
         config = model_config.pretrained_config
+        is_neox = getattr(model_config.pretrained_config,
+                          "position_embedding_type",
+                          None) not in self.NON_NEOX_TYPES
+        rope = RopeParams.from_config(config)
+        if rope.scale_type == RotaryScalingType.yarn:
+            rope.mscale_all_dim = 0.0
 
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
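A toy re-statement of the new branch above, using a stand-in config object; it only exercises the is_neox rule, while the mscale_all_dim override shown in the hunk is left to the real RopeParams:

from types import SimpleNamespace

NON_NEOX_TYPES = ("mistral_yarn", "rope_llama4")

def is_neox(pretrained_config) -> bool:
    # Configs that declare a "mistral_yarn" or "rope_llama4" embedding type get the
    # GPT-J-style rotary layout; anything else (or no attribute at all) stays NeoX-style.
    return getattr(pretrained_config, "position_embedding_type", None) not in NON_NEOX_TYPES

print(is_neox(SimpleNamespace(position_embedding_type="mistral_yarn")))  # False
print(is_neox(SimpleNamespace(position_embedding_type="rope_llama4")))   # False
print(is_neox(SimpleNamespace()))                                        # True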
@@ -59,8 +67,9 @@ class NemotronNASAttention(Attention):
             max_position_embeddings=config.max_position_embeddings,
             bias=False,
             pos_embd_params=PositionalEmbeddingParams(
-                type=PositionEmbeddingType.rope_gpt_neox,
-                rope=RopeParams.from_config(config),
+                type=PositionEmbeddingType.rope_gpt_neox
+                if is_neox else PositionEmbeddingType.rope_gptj,
+                rope=rope,
             ),
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
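The final hunk makes the position-embedding type conditional and reuses the rope object built in __init__. A toy version of that selection, with a stand-in enum for tensorrt_llm.functional.PositionEmbeddingType:

from enum import Enum, auto

class ToyPositionEmbeddingType(Enum):
    rope_gpt_neox = auto()  # NeoX-style rotary layout (the previous default)
    rope_gptj = auto()      # GPT-J-style layout, used for the YARN/Llama4 embedding types

def pick_embedding_type(is_neox: bool) -> ToyPositionEmbeddingType:
    return (ToyPositionEmbeddingType.rope_gpt_neox
            if is_neox else ToyPositionEmbeddingType.rope_gptj)

assert pick_embedding_type(True) is ToyPositionEmbeddingType.rope_gpt_neox
assert pick_embedding_type(False) is ToyPositionEmbeddingType.rope_gptj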