From 085ac221a32bca4d9ef4d325444b9713412a7ddf Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 28 May 2026 19:29:12 +0100 Subject: [PATCH] Deprecate `JAISLMHeadModel` (#43784) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- docs/models/supported_models.md | 1 - tests/distributed/test_pipeline_parallel.py | 1 - tests/models/registry.py | 1 - vllm/model_executor/models/jais.py | 401 -------------------- vllm/model_executor/models/registry.py | 2 +- vllm/transformers_utils/config.py | 1 - vllm/transformers_utils/configs/__init__.py | 2 - vllm/transformers_utils/configs/jais.py | 243 ------------ 8 files changed, 1 insertion(+), 651 deletions(-) delete mode 100644 vllm/model_executor/models/jais.py delete mode 100644 vllm/transformers_utils/configs/jais.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 1f38200a786..dec3975ed1a 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -428,7 +428,6 @@ th { | `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | | `IQuestCoderForCausalLM` | IQuestCoderV1 | `IQuestLab/IQuest-Coder-V1-40B-Instruct`, etc. | | | | `IQuestLoopCoderForCausalLM` | IQuestLoopCoderV1 | `IQuestLab/IQuest-Coder-V1-40B-Loop-Instruct`, etc. | | | -| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | | `Jais2ForCausalLM` | Jais2 | `inceptionai/Jais-2-8B-Chat`, `inceptionai/Jais-2-70B-Chat`, etc. | | ✅︎ | | `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | | `KimiLinearForCausalLM` | Kimi-Linear-48B-A3B-Base, Kimi-Linear-48B-A3B-Instruct | `moonshotai/Kimi-Linear-48B-A3B-Base`, `moonshotai/Kimi-Linear-48B-A3B-Instruct` | | ✅︎ | diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 2742663093f..c2dda1b51cf 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -127,7 +127,6 @@ TEXT_GENERATION_MODELS = { # Uses Llama # "internlm/internlm-chat-7b": PPTestSettings.fast(), "internlm/internlm2-chat-7b": PPTestSettings.fast(), - "inceptionai/jais-13b-chat": PPTestSettings.fast(), "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(), "pfnet/plamo-2-1b": PPTestSettings.fast(), "pfnet/plamo-3-nict-2b-base": PPTestSettings.fast(), diff --git a/tests/models/registry.py b/tests/models/registry.py index 226eb80aca1..154eb3e479d 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -363,7 +363,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "IQuestLoopCoderForCausalLM": _HfExamplesInfo( "IQuestLab/IQuest-Coder-V1-40B-Loop-Instruct", trust_remote_code=True ), - "JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"), "Jais2ForCausalLM": _HfExamplesInfo( "inceptionai/Jais-2-8B-Chat", min_transformers_version="4.58" ), diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py deleted file mode 100644 index cc0c1aa01ba..00000000000 --- a/vllm/model_executor/models/jais.py +++ /dev/null @@ -1,401 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Adapted from -# https://huggingface.co/inceptionai/jais-30b-chat-v3/blob/main/modeling_jais.py -# Copyright 2023 The vLLM team. -# Copyright 2023 the Jais authors and HuggingFace Inc. team. All rights -# reserved. -# Copyright 2023 Cerebras Systems. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Inference-only Jais model compatible with HuggingFace weights.""" - -import math -from collections.abc import Iterable -from itertools import islice - -import torch -from torch import nn - -from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import ( - get_pp_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) -from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear, -) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, - VocabParallelEmbedding, -) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.configs.jais import JAISConfig - -from .interfaces import SupportsPP -from .utils import ( - AutoWeightsLoader, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, - make_layers, - maybe_prefix, -) - - -class SwiGLUActivation(nn.Module): - def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - return x1 * nn.functional.silu(x2) - - -def _get_alibi_slopes(n): - def get_slopes_power_of_2(n): - start = 2 ** (-(2 ** -(math.log2(n) - 3))) - ratio = start - return [start * ratio**i for i in range(n)] - - if math.log2(n).is_integer(): - return get_slopes_power_of_2(n) - else: - closest_power_of_2 = 2 ** math.floor(math.log2(n)) - return ( - get_slopes_power_of_2(closest_power_of_2) - + _get_alibi_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] - ) - - -class JAISAttention(nn.Module): - def __init__( - self, - config: JAISConfig, - cache_config: CacheConfig | None = None, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ): - super().__init__() - self.hidden_size = config.hidden_size - total_num_heads = config.num_attention_heads - tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() - assert total_num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = total_num_heads // tensor_model_parallel_world_size - self.head_dim = self.hidden_size // total_num_heads - if hasattr(config, "scale_qk_dot_by_d"): - config.mup_scale_qk_dot_by_d = config.scale_qk_dot_by_d - self.attn_scale_power = 1.0 if config.mup_scale_qk_dot_by_d else 0.5 - self.scale = self.head_dim**-self.attn_scale_power - - self.c_attn = QKVParallelLinear( - self.hidden_size, - self.head_dim, - total_num_heads, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.c_attn", - ) - self.c_proj = RowParallelLinear( - self.hidden_size, - self.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.c_proj", - ) - - self.use_alibi = config.position_embedding_type == "alibi" - alibi_slopes = None - if self.use_alibi: - tp_rank = get_tensor_model_parallel_rank() - head_start = tp_rank * self.num_heads - head_end = (tp_rank + 1) * self.num_heads - alibi_slopes = _get_alibi_slopes(total_num_heads) - alibi_slopes = alibi_slopes[head_start:head_end] - self.attn = Attention( - self.num_heads, - self.head_dim, - scale=self.scale, - alibi_slopes=alibi_slopes, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - ) - - def forward( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor: - qkv, _ = self.c_attn(hidden_states) - q, k, v = qkv.chunk(chunks=3, dim=-1) - attn_output = self.attn(q, k, v) - attn_output, _ = self.c_proj(attn_output) - return attn_output - - -class JAISMLP(nn.Module): - def __init__( - self, - intermediate_size: int, - config: JAISConfig, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ): - super().__init__() - hidden_size = config.hidden_size - self.swiglu = config.activation_function == "swiglu" - self.c_fc = ColumnParallelLinear( - hidden_size, - intermediate_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.c_fc", - ) - self.c_fc2 = ( - ColumnParallelLinear( - hidden_size, - intermediate_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.c_fc2", - ) - if self.swiglu - else None - ) - self.c_proj = RowParallelLinear( - intermediate_size, - hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.c_proj", - ) - - self.act = SwiGLUActivation() - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - if self.swiglu: - hidden_states2, _ = self.c_fc2(hidden_states) - hidden_states, _ = self.c_fc(hidden_states) - hidden_states = ( - self.act(hidden_states, hidden_states2) - if self.swiglu - else self.act(hidden_states) - ) - hidden_states, _ = self.c_proj(hidden_states) - return hidden_states - - -class JAISBlock(nn.Module): - def __init__( - self, - config: JAISConfig, - cache_config: CacheConfig | None = None, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ): - super().__init__() - hidden_size = config.hidden_size - inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size - - self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = JAISAttention( - config, cache_config, quant_config, prefix=f"{prefix}.attn" - ) - self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = JAISMLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp") - - def forward( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor: - residual = hidden_states - hidden_states = self.ln_1(hidden_states) - attn_output = self.attn( - hidden_states=hidden_states, - ) - # residual connection - hidden_states = attn_output + residual - - residual = hidden_states - hidden_states = self.ln_2(hidden_states) - feed_forward_hidden_states = self.mlp(hidden_states) - # residual connection - hidden_states = residual + feed_forward_hidden_states - return hidden_states - - -@support_torch_compile -class JAISModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - - self.config = config - assert not config.scale_attn_by_inverse_layer_idx - assert not config.reorder_and_upcast_attn - self.embed_dim = config.hidden_size - self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) - self.wpe = ( - nn.Embedding(config.max_position_embeddings, self.embed_dim) - if config.position_embedding_type != "alibi" - else None - ) - if hasattr(config, "embeddings_scale"): - self.embeddings_scale = config.embeddings_scale - else: - self.embeddings_scale = config.mup_embeddings_scale - - self.start_layer, self.end_layer, self.h = make_layers( - config.num_hidden_layers, - lambda prefix: JAISBlock( - config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix, - ), - prefix=f"{prefix}.h", - ) - - self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) - self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( - ["hidden_states"], config.n_embd - ) - - def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.wte(input_ids) - - def forward( - self, - input_ids: torch.Tensor | None, - position_ids: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - ) -> IntermediateTensors | torch.Tensor: - if get_pp_group().is_first_rank: - if inputs_embeds is None: - inputs_embeds = self.embed_input_ids(input_ids) - if self.wpe is not None: - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds - else: - hidden_states = inputs_embeds - hidden_states *= torch.tensor( - float(self.embeddings_scale), dtype=hidden_states.dtype - ) - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - - for layer in islice(self.h, self.start_layer, self.end_layer): - hidden_states = layer(hidden_states) - - if not get_pp_group().is_last_rank: - return IntermediateTensors({"hidden_states": hidden_states}) - - hidden_states = self.ln_f(hidden_states) - return hidden_states - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if ".attn.bias" in name or ".attn.masked_bias" in name: - # Skip attention mask. - # NOTE: "c_attn.bias" should not be skipped. - continue - if "relative_pe" in name: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - # The HF's GPT-2 implementation uses Conv1D instead of Linear. - # Because of this, we need to transpose the weights. - # Note(zhuohan): the logic below might break quantized models. - for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: - if conv1d_weight_name not in name: - continue - if not name.endswith(".weight"): - continue - loaded_weight = loaded_weight.t() - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class JAISLMHeadModel(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - self.transformer = JAISModel( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer") - ) - if self.config.tie_word_embeddings: - self.lm_head = self.transformer.wte - else: - self.lm_head = ParallelLMHead( - self.config.vocab_size, - self.config.hidden_size, - prefix=maybe_prefix(prefix, "lm_head"), - ) - if hasattr(config, "width_scale"): - self.output_logits_scale = config.width_scale - else: - self.output_logits_scale = config.mup_output_alpha * config.mup_width_scale - self.logits_processor = LogitsProcessor( - vocab_size=config.vocab_size, scale=self.output_logits_scale - ) - self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors - ) - - def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.transformer.embed_input_ids(input_ids) - - def forward( - self, - input_ids: torch.Tensor | None, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - ) -> IntermediateTensors | torch.Tensor: - hidden_states = self.transformer( - input_ids, positions, intermediate_tensors, inputs_embeds - ) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor | None: - logits = self.logits_processor(self.lm_head, hidden_states) - return logits - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), - ) - return loader.load_weights(weights) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 5e6f106150a..62c0149a3cf 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -146,7 +146,6 @@ _TEXT_GENERATION_MODELS = { "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"), "IQuestCoderForCausalLM": ("llama", "LlamaForCausalLM"), "IQuestLoopCoderForCausalLM": ("iquest_loopcoder", "IQuestLoopCoderForCausalLM"), - "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"), "JambaForCausalLM": ("jamba", "JambaForCausalLM"), "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"), @@ -705,6 +704,7 @@ _PREVIOUSLY_SUPPORTED_MODELS = { "Phi3SmallForCausalLM": "0.9.2", "Phi4FlashForCausalLM": "0.10.2", "Phi4MultimodalForCausalLM": "0.12.0", + "JAISLMHeadModel": "0.22.0", # encoder-decoder models except whisper # have been removed for V0 deprecation. "DonutForConditionalGeneration": "0.10.2", diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 0969d816902..ffe5941bcc0 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -113,7 +113,6 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( kimi_k25="KimiK25Config", RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct) RefinedWebModel="RWConfig", # For tiiuae/falcon-7b(-instruct) - jais="JAISConfig", mlp_speculator="MLPSpeculatorConfig", medusa="MedusaConfig", midashenglm="MiDashengLMConfig", diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 4a90655dcc0..5998e61dfd8 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -46,7 +46,6 @@ _CLASS_TO_MODULE: dict[str, str] = { # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. "RWConfig": "vllm.transformers_utils.configs.falcon", - "JAISConfig": "vllm.transformers_utils.configs.jais", "LagunaConfig": "vllm.transformers_utils.configs.laguna", "Lfm2MoeConfig": "vllm.transformers_utils.configs.lfm2_moe", "MedusaConfig": "vllm.transformers_utils.configs.medusa", @@ -115,7 +114,6 @@ __all__ = [ "HyperCLOVAXConfig", "IsaacConfig", "RWConfig", - "JAISConfig", "LagunaConfig", "Lfm2MoeConfig", "MedusaConfig", diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py deleted file mode 100644 index 6b581bf1877..00000000000 --- a/vllm/transformers_utils/configs/jais.py +++ /dev/null @@ -1,243 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Copyright 2023 The OpenAI Team Authors and HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# Copyright 2023 Cerebras Systems. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""JAIS configuration""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - -logger = logging.get_logger(__name__) - - -class JAISConfig(PretrainedConfig): - """ - This is the configuration class to store the configuration of a - [`JAISModel`]. It is used to instantiate a JAIS model according to the - specified arguments, defining the model architecture. - - Configuration objects inherit from [`PretrainedConfig`] and can be used - to control the model outputs. Read the documentation from - [`PretrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 50257): - Vocabulary size of the JAIS model. Defines the number of different - tokens that can be represented by the - `inputs_ids` passed when calling [`JAISModel`]. - n_positions (`int`, *optional*, defaults to 1024): - The maximum sequence length that this model might ever be used - with. Typically set this to something large just in case - (e.g., 512 or 1024 or 2048). - n_embd (`int`, *optional*, defaults to 768): - Dimensionality of the embeddings and hidden states. - n_layer (`int`, *optional*, defaults to 12): - Number of hidden layers in the Transformer encoder. - n_head (`int`, *optional*, defaults to 12): - Number of attention heads for each attention layer in the - Transformer encoder. - n_inner (`int`, *optional*, defaults to None): - Dimensionality of the inner feed-forward layers. `None` will set - it to 4 times n_embd - activation_function (`str`, *optional*, defaults to `"gelu"`): - Activation function, to be selected in the list - `["relu", "silu", "gelu", "tanh", "gelu_new", "swiglu"]`. - resid_pdrop (`float`, *optional*, defaults to 0.1): - The dropout probability for all fully connected layers in - the embeddings, encoder, and pooler. - embd_pdrop (`float`, *optional*, defaults to 0.1): - The dropout ratio for the embeddings. - attn_pdrop (`float`, *optional*, defaults to 0.1): - The dropout ratio for the attention. - layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): - The epsilon to use in the layer normalization layers. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for - initializing all weight matrices. - scale_attn_weights (`bool`, *optional*, defaults to `True`): - Scale attention weights by dividing by sqrt(hidden_size).. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values - attentions (not used by all models). - scale_attn_by_inverse_layer_idx (`bool`, *optional*, default `True`): - Whether to additionally scale attention weights - by `1 / layer_idx + 1`. - reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): - Whether to scale keys (K) prior to computing attention - (dot-product) - and upcast attention dot-product/softmax to float() when training - with mixed precision. - position_embedding_type (`str`, *optional*, defaults to `"learned"`): - Positional embedding can be either `"alibi"` or `"learned"`. - mup_width_scale (`float`, *optional*, defaults to 1.0): - muP parameter to scale learning rate and initializers. Calculated - as (`d_model,0 / d_model`), where - `d_model` is the model's width and `d_model,0` is the proxy - model's width. - mup_embeddings_scale (`float`, *optional*, defaults to 1.0): - muP parameter to scale token and position embeddings. - mup_output_alpha (`float`, *optional*, defaults to 1.0): - muP parameter to scale output logits - (`output_logits_scale = mup_output_alpha * mup_width_scale`). - mup_scale_qk_dot_by_d (`bool`, *optional*, defaults to `False`): - Scale attention weights by dividing by hidden_size instead of - sqrt(hidden_size). Need to set scale_attn_weights to `True` as - well. - alibi_scaling (`dict`, *optional*): - Dictionary containing the scaling configuration for ALiBi - embeddings. Currently only supports linear - scaling strategy. Can specify either the scaling `factor` (must be - a float greater than 1) for fixed scaling - or `train_seq_len` for dynamic scaling on input samples with - sequence length > `train_seq_len`. The expected - formats are `{"type": strategy name, "factor": scaling factor}` or - `{"type": strategy name, - "train_seq_len": training sequence length}`. - architectures (`list`, *optional*, defaults to ['JAISLMHeadModel']): - architecture names for Jais. - - Example: - - ```python - >>> from transformers import JAISConfig, JAISModel - - >>> # Initializing a JAIS configuration - >>> configuration = JAISConfig() - - >>> # Initializing a model (with random weights) from the configuration - >>> model = JAISModel(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "jais" - keys_to_ignore_at_inference = ["past_key_values"] - attribute_map = { - "hidden_size": "n_embd", - "max_position_embeddings": "n_positions", - "num_attention_heads": "n_head", - "num_hidden_layers": "n_layer", - } - - def __init__( - self, - vocab_size=50257, - n_positions=1024, - n_embd=768, - n_layer=12, - n_head=12, - n_inner=None, - activation_function="gelu_new", - resid_pdrop=0.1, - embd_pdrop=0.1, - attn_pdrop=0.1, - layer_norm_epsilon=1e-5, - initializer_range=0.02, - scale_attn_weights=True, - use_cache=True, - bos_token_id=50256, - eos_token_id=50256, - scale_attn_by_inverse_layer_idx=False, - reorder_and_upcast_attn=False, - position_embedding_type="learned", - mup_width_scale=1.0, - mup_embeddings_scale=1.0, - mup_output_alpha=1.0, - mup_scale_qk_dot_by_d=False, - alibi_scaling=None, - architectures=None, - **kwargs, - ): - self.vocab_size = vocab_size - self.n_positions = n_positions - self.n_embd = n_embd - self.n_layer = n_layer - self.n_head = n_head - self.n_inner = n_inner - self.activation_function = activation_function - self.resid_pdrop = resid_pdrop - self.embd_pdrop = embd_pdrop - self.attn_pdrop = attn_pdrop - self.layer_norm_epsilon = layer_norm_epsilon - self.initializer_range = initializer_range - self.scale_attn_weights = scale_attn_weights - self.use_cache = use_cache - self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx - self.reorder_and_upcast_attn = reorder_and_upcast_attn - - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - - self.position_embedding_type = position_embedding_type - self.mup_width_scale = mup_width_scale - self.mup_embeddings_scale = mup_embeddings_scale - self.mup_output_alpha = mup_output_alpha - self.mup_scale_qk_dot_by_d = mup_scale_qk_dot_by_d - - self.alibi_scaling = alibi_scaling - self._alibi_scaling_validation() - if architectures is None: - architectures = ["JAISLMHeadModel"] - - super().__init__( - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - architectures=architectures, - **kwargs, - ) - - def _alibi_scaling_validation(self): - """ - Validate the `alibi_scaling` configuration. - """ - if self.alibi_scaling is None: - return - - if not isinstance(self.alibi_scaling, dict) or len(self.alibi_scaling) != 2: - raise ValueError( - "`alibi_scaling` must be a dictionary with two fields, " - "`type` and `factor` or `type` and `train_seq_len`, " - f"got {self.alibi_scaling}" - ) - alibi_scaling_type = self.alibi_scaling.get("type", None) - alibi_scaling_factor = self.alibi_scaling.get("factor", None) - alibi_dynamic_scaling = self.alibi_scaling.get("train_seq_len", None) - if alibi_scaling_type is None or alibi_scaling_type != "linear": - raise ValueError( - f"`alibi_scaling`'s type field must be 'linear', " - f"got {alibi_scaling_type}" - ) - if ( - alibi_scaling_factor is not None - and not isinstance(alibi_scaling_factor, float) - or (alibi_scaling_factor is not None and alibi_scaling_factor <= 1.0) - ): - raise ValueError( - f"`alibi_scaling`'s factor field must be a float > 1.0, " - f"got {alibi_scaling_factor}" - ) - if ( - alibi_dynamic_scaling is not None - and not isinstance(alibi_dynamic_scaling, int) - or (alibi_dynamic_scaling is not None and alibi_dynamic_scaling <= 1) - ): - raise ValueError( - f"`alibi_scaling`'s `train_seq_len` field must be an " - f"integer > 1, got {alibi_dynamic_scaling}" - )