# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..._utils import pad_vocab_size
from ...functional import PositionEmbeddingType, Tensor, recv, send
from ...layers import (MLP, Attention, AttentionMaskType, ColumnLinear,
                       Embedding, LayerNorm)
from ...module import Module
from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
                              PretrainedConfig)


class MPTDecoderLayer(Module):
    """One MPT transformer block.

    Layout (pre-norm, residual after each sub-layer):
        x = x + attention(input_layernorm(x))
        x = x + mlp(post_layernorm(x))

    Both LayerNorms are created without a bias term, and attention uses a
    causal mask with ALiBi position embeddings.

    Args:
        config: model config; ``bias``, ``clip_qkv`` and ``alibi_bias_max``
            are expected to have been filled in (see
            ``MPTForCausalLM.check_config``).
        layer_idx: global index of this layer across all pipeline stages.
    """

    def __init__(self, config: PretrainedConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config

        hidden_size = config.hidden_size
        dtype = config.dtype
        tp_size = config.mapping.tp_size
        tp_rank = config.mapping.tp_rank
        tp_group = config.mapping.tp_group
        layernorm_epsilon = config.norm_epsilon

        self.input_layernorm = LayerNorm(normalized_shape=hidden_size,
                                         eps=layernorm_epsilon,
                                         bias=False,
                                         dtype=dtype)

        # The attention layer is indexed relative to the first layer owned by
        # this pipeline stage, not by its global position in the model.
        layers_range = config.mapping.pp_layers(config.num_hidden_layers)
        local_layer_idx = layer_idx - layers_range[0]
        self.attention = Attention(
            local_layer_idx=local_layer_idx,
            hidden_size=hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            attention_mask_type=AttentionMaskType.causal,
            dtype=dtype,
            tp_group=tp_group,
            tp_size=tp_size,
            tp_rank=tp_rank,
            bias=config.bias,
            position_embedding_type=PositionEmbeddingType.alibi,
            quant_mode=config.quant_mode,
            clip_qkv=config.clip_qkv,
            alibi_bias_max=config.alibi_bias_max)

        # NOTE(review): the MLP expansion ratio is hard-coded to 4; this
        # presumably matches MPT's default expansion ratio — confirm for
        # checkpoints that configure a different one.
        self.mlp = MLP(hidden_size=hidden_size,
                       ffn_hidden_size=hidden_size * 4,
                       hidden_act=config.hidden_act,
                       dtype=dtype,
                       bias=config.bias,
                       tp_group=tp_group,
                       tp_size=tp_size,
                       quant_mode=config.quant_mode)

        self.post_layernorm = LayerNorm(normalized_shape=hidden_size,
                                        eps=layernorm_epsilon,
                                        bias=False,
                                        dtype=dtype)

    def forward(self,
                hidden_states: Tensor,
                attention_mask=None,
                use_cache=False,
                kv_cache_params=None,
                attention_params=None):
        """Run one decoder block.

        Returns:
            ``hidden_states`` or, when ``use_cache`` is true, the tuple
            ``(hidden_states, presents)`` where ``presents`` is the second
            element returned by the attention layer.
        """
        assert isinstance(hidden_states, Tensor)

        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        attention_output = self.attention(hidden_states,
                                          attention_mask=attention_mask,
                                          use_cache=use_cache,
                                          kv_cache_params=kv_cache_params,
                                          attention_params=attention_params)

        # With use_cache the attention layer returns (output, presents).
        if use_cache:
            attention_output, presents = attention_output

        hidden_states = residual + attention_output

        residual = hidden_states
        hidden_states = self.post_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        if use_cache:
            return (hidden_states, presents)
        return hidden_states


class MPTModel(Module):
    """MPT decoder stack: token embedding, decoder layers, final LayerNorm.

    Under pipeline parallelism only the first stage owns ``vocab_embedding``
    and only the last stage owns ``ln_f``; other stages exchange activations
    with their neighbours in ``forward``.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config

        if config.mapping.is_first_pp_rank():
            self.vocab_embedding = Embedding(config.vocab_size,
                                             config.hidden_size,
                                             dtype=config.dtype)

        self.layers = DecoderLayerList(MPTDecoderLayer, config)

        if config.mapping.is_last_pp_rank():
            # Use the same epsilon as the per-layer LayerNorms instead of the
            # LayerNorm default, so a non-default norm_epsilon applies here too.
            self.ln_f = LayerNorm(normalized_shape=config.hidden_size,
                                  eps=config.norm_epsilon,
                                  bias=False,
                                  dtype=config.dtype)

    def forward(self,
                input_ids,
                position_ids,
                use_cache=False,
                attention_mask=None,
                kv_cache_params=None,
                attention_params=None,
                hidden_states=None):
        """Run the decoder stack for this pipeline stage.

        Args:
            input_ids: token ids; only consumed on the first pipeline stage.
            position_ids: accepted for interface compatibility; unused here
                (the attention layers use ALiBi position embeddings).
            hidden_states: placeholder tensor for activations received from
                the previous pipeline stage; ignored on the first stage.

        Returns:
            ``hidden_states`` or, when ``use_cache`` is true, the tuple
            ``(hidden_states, tuple(presents))``. On the last stage the hidden
            states are normalized by ``ln_f``; on earlier stages they are the
            result of sending them on to the next stage.
        """
        mapping = self.config.mapping

        # vocab_embedding only exists on the first stage (see __init__); other
        # stages receive their input activations from the previous stage.
        if mapping.is_first_pp_rank():
            hidden_states = self.vocab_embedding(input_ids)
        else:
            hidden_states = recv(hidden_states, mapping.prev_pp_rank())

        hidden_states = self.layers(hidden_states,
                                    use_cache=use_cache,
                                    attention_mask=attention_mask,
                                    kv_cache_params=kv_cache_params,
                                    attention_params=attention_params)

        if use_cache:
            hidden_states, presents = hidden_states

        # ln_f only exists on the last stage; earlier stages forward their
        # output to the next stage instead.
        if mapping.is_last_pp_rank():
            hidden_states = self.ln_f(hidden_states)
        else:
            hidden_states = send(hidden_states, mapping.next_pp_rank())

        if use_cache:
            return (hidden_states, tuple(presents))
        return hidden_states


class MPTForCausalLM(DecoderModelForCausalLM):
    """MPT decoder with a tensor-parallel LM head on the last pipeline stage."""

    def __init__(self, config: PretrainedConfig):
        # Fill in MPT-specific defaults before any sub-module reads them
        # (MPTDecoderLayer reads bias, clip_qkv and alibi_bias_max).
        self.check_config(config)
        transformer = MPTModel(config)
        # Presumably pads the vocabulary so it splits evenly across tp_size
        # ranks — see pad_vocab_size.
        vocab_size_padded = pad_vocab_size(config.vocab_size,
                                           config.mapping.tp_size)
        if config.mapping.is_last_pp_rank():
            lm_head = ColumnLinear(config.hidden_size,
                                   vocab_size_padded,
                                   bias=config.bias,
                                   dtype=config.dtype,
                                   tp_group=config.mapping.tp_group,
                                   tp_size=config.mapping.tp_size,
                                   gather_output=True)
        else:
            # Only the last pipeline stage produces logits.
            lm_head = None
        super().__init__(config, transformer, lm_head)

    def check_config(self, config):
        """Set MPT defaults on ``config`` for keys that are not already set.

        Defaults: ``bias=False``, ``clip_qkv=None``, ``alibi_bias_max=8``.
        """
        config.set_if_not_exist('bias', False)
        config.set_if_not_exist('clip_qkv', None)
        config.set_if_not_exist('alibi_bias_max', 8)