# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from pathlib import Path
from typing import Optional

from ..._utils import pad_vocab_size
from ...functional import Tensor, recv, send
from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear,
                       Embedding, GatedMLP, MoeConfig, PositionEmbeddingType,
                       RmsNorm)
from ...lora_manager import LoraBuildConfig, use_lora
from ...mapping import Mapping
from ...module import Module
from ...plugin import init_all_reduce_helper
from ...quantization import W8A8_SQ_PLUGIN_LIST, QuantAlgo
from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
                              PretrainedConfig, QuantConfig)


class LLaMADecoderLayer(Module):
    """One LLaMA decoder block.

    Computes ``x + Attention(RmsNorm(x))`` followed by
    ``h + MLP(RmsNorm(h))``. The MLP is a ``GatedMLP``, or an ``MOE`` layer
    when ``config.moe_num_experts > 1``.
    """

    def __init__(self, config: PretrainedConfig, layer_idx: int):
        """Build the block for global layer index ``layer_idx``.

        Args:
            config: Model config; reads hidden sizes, head counts, rotary,
                MoE, bias, dtype, quantization and mapping fields.
            layer_idx: Global index of this layer in the full model.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config

        self.input_layernorm = RmsNorm(normalized_shape=config.hidden_size,
                                       eps=config.norm_epsilon,
                                       dtype=config.dtype)

        # Attention receives the index relative to the first layer this
        # pipeline-parallel rank owns (presumably so per-rank KV-cache
        # bookkeeping starts at 0 — confirm against Attention).
        layers_range = config.mapping.pp_layers(config.num_hidden_layers)
        local_layer_idx = layer_idx - layers_range[0]
        self.attention = Attention(
            local_layer_idx=local_layer_idx,
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position_embeddings=config.max_position_embeddings,
            dtype=config.dtype,
            attention_mask_type=AttentionMaskType.causal,
            bias=config.attn_bias,
            position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
            rotary_embedding_base=config.rotary_base,
            rotary_embedding_scaling=config.rotary_scaling,
            tp_group=config.mapping.tp_group,
            tp_size=config.mapping.tp_size,
            tp_rank=config.mapping.tp_rank,
            quant_mode=config.quant_mode)

        # Fall back to the classic 4x expansion when no intermediate size
        # is configured.
        mlp_hidden_size = (config.hidden_size * 4
                           if config.intermediate_size is None else
                           config.intermediate_size)

        ClsMLP = GatedMLP
        mlp_kwargs = {}
        if config.moe_num_experts > 1:
            ClsMLP = MOE
            mlp_kwargs = {
                "moe_config":
                MoeConfig(
                    config.moe_num_experts,
                    config.moe_top_k,
                    config.moe_tp_mode,
                    config.moe_normalization_mode,
                ),
                "tp_rank":
                config.mapping.tp_rank,
            }

        self.mlp = ClsMLP(hidden_size=config.hidden_size,
                          ffn_hidden_size=mlp_hidden_size,
                          hidden_act=config.hidden_act,
                          dtype=config.dtype,
                          bias=config.mlp_bias,
                          tp_group=config.mapping.tp_group,
                          tp_size=config.mapping.tp_size,
                          quant_mode=config.quant_mode,
                          **mlp_kwargs)
        self.post_layernorm = RmsNorm(normalized_shape=config.hidden_size,
                                      eps=config.norm_epsilon,
                                      dtype=config.dtype)

    def forward(
            self,
            hidden_states,
            attention_mask=None,
            medusa_packed_mask=None,  # For Medusa support
            medusa_position_offsets=None,
            use_cache=False,
            kv_cache_params=None,
            attention_params=None,
            lora_layer_params=None):
        """Run the block.

        Returns:
            ``(hidden_states, presents)`` when ``use_cache`` is true
            (``presents`` comes from the attention layer), otherwise just
            ``hidden_states``.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        attention_output = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            medusa_packed_mask=medusa_packed_mask,  # For Medusa support
            medusa_position_offsets=medusa_position_offsets,
            use_cache=use_cache,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params,
            lora_layer_params=lora_layer_params)

        # With caching enabled the attention layer returns a pair.
        if use_cache:
            attention_output, presents = attention_output

        hidden_states = residual + attention_output

        residual = hidden_states
        hidden_states = self.post_layernorm(hidden_states)

        hidden_states = self.mlp(hidden_states,
                                 lora_layer_params=lora_layer_params)

        hidden_states = residual + hidden_states
        if use_cache:
            return (hidden_states, presents)
        return hidden_states


class LLaMAModel(Module):
    """LLaMA transformer trunk, split across pipeline-parallel ranks.

    Only the first pipeline rank owns the token embedding and only the last
    owns the final RmsNorm (``ln_f``). Intermediate ranks exchange hidden
    states with their neighbours via ``recv``/``send``.
    """

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__()
        init_all_reduce_helper()

        self.mapping = config.mapping
        if self.mapping.is_first_pp_rank():
            self.vocab_embedding = Embedding(config.vocab_size,
                                             config.hidden_size,
                                             dtype=config.dtype)

        self.layers = DecoderLayerList(LLaMADecoderLayer, config)

        if self.mapping.is_last_pp_rank():
            self.ln_f = RmsNorm(normalized_shape=config.hidden_size,
                                eps=config.norm_epsilon,
                                dtype=config.dtype)

    def forward(
            self,
            input_ids,
            position_ids=None,
            use_cache=False,
            attention_mask=None,
            medusa_position_offsets=None,  # For Medusa support
            medusa_packed_mask=None,  # For Medusa support
            kv_cache_params=None,
            attention_params=None,
            hidden_states=None,
            prompt_embedding_table: Optional[Tensor] = None,
            prompt_tasks: Optional[Tensor] = None,
            prompt_vocab_size: Optional[Tensor] = None,
            lora_params=None):
        """Run the local slice of the model.

        On the first pipeline rank ``input_ids`` are embedded; on other
        ranks ``hidden_states`` is used as the receive buffer for the
        previous rank's output.

        Returns:
            ``(hidden_states, presents_tuple)`` when ``use_cache`` is true,
            otherwise ``hidden_states``.
        """
        # Prompt-tuning arguments are forwarded to the embedding only when a
        # prompt embedding table was supplied.
        ptuning_args = [
            prompt_embedding_table, prompt_tasks, prompt_vocab_size
        ] if prompt_embedding_table is not None else []

        if self.mapping.is_first_pp_rank():
            hidden_states = self.vocab_embedding(input_ids, *ptuning_args)
        else:
            hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())

        hidden_states = self.layers.forward(
            hidden_states,
            use_cache=use_cache,
            attention_mask=attention_mask,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params,
            lora_params=lora_params,
            medusa_position_offsets=medusa_position_offsets,
            medusa_packed_mask=medusa_packed_mask)

        if use_cache:
            hidden_states, presents = hidden_states

        if self.mapping.is_last_pp_rank():
            hidden_states = self.ln_f(hidden_states)
        else:
            hidden_states = send(hidden_states, self.mapping.next_pp_rank())

        if use_cache:
            return (hidden_states, tuple(presents))
        return hidden_states


class LLaMAForCausalLM(DecoderModelForCausalLM):
    """LLaMA causal language model: ``LLaMAModel`` trunk plus an LM head.

    The LM head (a ``ColumnLinear`` over the TP-padded vocabulary with
    gathered output) exists only on the last pipeline rank.
    """

    def __init__(self, config: PretrainedConfig):
        self.check_config(config)
        transformer = LLaMAModel(config)
        vocab_size_padded = pad_vocab_size(config.vocab_size,
                                           config.mapping.tp_size)
        if config.mapping.is_last_pp_rank():
            lm_head = ColumnLinear(config.hidden_size,
                                   vocab_size_padded,
                                   bias=False,
                                   dtype=config.dtype,
                                   tp_group=config.mapping.tp_group,
                                   tp_size=config.mapping.tp_size,
                                   gather_output=True)
        else:
            lm_head = None
        self.quant_mode = config.quant_mode
        self.mapping = config.mapping
        super().__init__(config, transformer, lm_head)

    def check_config(self, config):
        """Populate defaults for optional LLaMA-specific config fields.

        Each field is passed to ``config.set_if_not_exist`` so that values
        already present in ``config`` are kept.
        """
        config.set_if_not_exist('mlp_bias', False)
        config.set_if_not_exist('attn_bias', False)
        config.set_if_not_exist('rotary_base', 10000.0)
        config.set_if_not_exist('rotary_scaling', None)
        config.set_if_not_exist('moe_num_experts', 0)
        config.set_if_not_exist('moe_top_k', 0)
        config.set_if_not_exist('moe_tp_mode',
                                MoeConfig.ParallelismMode.TENSOR_PARALLEL)
        config.set_if_not_exist(
            'moe_normalization_mode',
            MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE)

    @classmethod
    def from_hugging_face(cls,
                          hf_model_dir,
                          dtype='float16',
                          mapping: Optional[Mapping] = None,
                          **kwargs):
        """Build a model from a Hugging Face checkpoint via ``convert``.

        Args:
            hf_model_dir: Hugging Face model directory.
            dtype: Target dtype name.
            mapping: Parallel mapping; a default ``Mapping()`` when None.
            **kwargs: Recognized keys are ``quantization`` (default
                ``QuantConfig()``), ``load_by_shard``, ``load_model_on_cpu``,
                ``override_fields``, ``skip_loading_weights`` and
                ``preloaded_model``; other keys are ignored.
        """
        from . import convert
        if mapping is None:
            mapping = Mapping()
        llama = convert.from_hugging_face(
            cls,
            hf_model_dir,
            dtype,
            mapping=mapping,
            quantization=kwargs.get('quantization', QuantConfig()),
            load_by_shard=kwargs.get('load_by_shard', False),
            load_model_on_cpu=kwargs.get('load_model_on_cpu', False),
            override_fields=kwargs.get('override_fields', {}),
            skip_loading_weights=kwargs.get('skip_loading_weights', False),
            preloaded_model=kwargs.get('preloaded_model', None))
        return llama

    def default_plugin_config(self, **kwargs):
        """Extend the base plugin config for per-group INT4 weight-only.

        When the quant mode is INT4 weight-only per-group, the groupwise
        weight-only quant matmul plugin is enabled.
        """
        plugin_config = super().default_plugin_config(**kwargs)
        if self.quant_mode.is_int4_weight_only_per_group():
            plugin_config.set_weight_only_groupwise_quant_matmul_plugin()
        return plugin_config

    @classmethod
    def from_meta_ckpt(cls,
                       meta_ckpt_dir,
                       dtype,
                       mapping,
                       use_parallel_embedding: Optional[bool] = False,
                       embedding_sharding_dim: Optional[int] = 0):
        """Build a model from a Meta-format checkpoint directory.

        The architecture is read from ``<meta_ckpt_dir>/params.json``.
        Weights are loaded with ``load_from_meta_llama``.

        Args:
            meta_ckpt_dir: Directory containing ``params.json`` and weights.
            dtype: Target dtype name.
            mapping: Parallel mapping; its tp/pp sizes and rank are used.
            use_parallel_embedding: Stored on the built config.
            embedding_sharding_dim: Stored on the built config.

        Raises:
            KeyError: If ``params.json`` lacks ``dim``, ``n_heads``,
                ``n_layers`` or ``norm_eps``.
        """
        with open(Path(meta_ckpt_dir, "params.json"),
                  encoding="utf-8") as fp:
            meta_config: dict = json.load(fp)
        assert meta_config is not None

        config = {}
        n_embd = meta_config["dim"]
        n_head = meta_config["n_heads"]
        n_kv_head = meta_config.get("n_kv_heads")
        # An absent or explicit-null n_kv_heads means plain multi-head
        # attention.
        if n_kv_head is None:
            n_kv_head = n_head
        if "hidden_dim" in meta_config:
            inter_size = meta_config["hidden_dim"]
        else:
            multiple_of = meta_config.get("multiple_of", 1)
            n_embd_ = int(4 * n_embd * 2 / 3)
            ffn_dim_multiplier = meta_config.get("ffn_dim_multiplier")
            # params.json may omit the multiplier or store null; both mean
            # "no extra scaling".
            if ffn_dim_multiplier is None:
                ffn_dim_multiplier = 1
            # Round the scaled FFN width up to a multiple of `multiple_of`.
            inter_size = multiple_of * (
                (int(n_embd_ * ffn_dim_multiplier) + multiple_of - 1) //
                multiple_of)

        # Prefer the values recorded in params.json; fall back to the HF
        # LLaMA defaults when they are missing, null, or (for vocab_size) a
        # non-positive placeholder such as -1.
        vocab_size = meta_config.get("vocab_size")
        if vocab_size is None or vocab_size <= 0:
            vocab_size = 32000
        rotary_base = meta_config.get("rope_theta")
        if rotary_base is None:
            rotary_base = 10000.0

        # Meta checkpoints don't specify hidden_act, so use the same default
        # value as HF.
        config.update({
            'architecture': "LlamaForCausalLM",
            'dtype': dtype,
            'logits_dtype': 'float32',
            'num_hidden_layers': meta_config["n_layers"],
            'num_attention_heads': n_head,
            'hidden_size': n_embd,
            'intermediate_size': inter_size,
            'num_key_value_heads': n_kv_head,
            'vocab_size': vocab_size,
            'position_embedding_type': 'rope_gpt_neox',
            'max_position_embeddings': 2048,
            'hidden_act': 'silu',
            'rotary_base': float(rotary_base),
            'norm_epsilon': meta_config["norm_eps"],
            'mapping': {
                'world_size': mapping.tp_size * mapping.pp_size,
                'tp_size': mapping.tp_size,
                'pp_size': mapping.pp_size,
            },
        })
        pretrained_config = PretrainedConfig.from_dict(config)
        pretrained_config.use_parallel_embedding = use_parallel_embedding
        pretrained_config.embedding_sharding_dim = embedding_sharding_dim
        pretrained_config.set_rank(mapping.rank)

        llama = cls(pretrained_config)
        from .weight import load_from_meta_llama
        weights = load_from_meta_llama(meta_ckpt_dir, mapping,
                                       pretrained_config)
        llama.load(weights)
        return llama

    @classmethod
    def quantize(
        cls,
        hf_model_dir,
        output_dir,
        quant_config: QuantConfig,
        *,
        dtype='float16',
        mapping: Optional[Mapping] = None,
        calib_batches=512,
        calib_batch_size=1,
        random_seed=1234,
        tokenizer_max_seq_length=2048,
        **kwargs,
    ):
        """Quantize a Hugging Face checkpoint into ``output_dir``.

        W4A16_AWQ, FP8, W8A8_SQ_PER_CHANNEL and W4A8_AWQ go through the base
        class (AMMO) flow. Everything else goes through the legacy native
        ``convert.quantize`` flow, which accepts W4A16, W8A16, no weight
        quantization, or an entry of ``W8A8_SQ_PLUGIN_LIST``, combined with
        an INT8 or absent KV-cache quantization.

        Raises:
            AssertionError: On the native path, if neither ``quant_algo``
                nor ``kv_cache_quant_algo`` is set, or the combination is
                unsupported natively.
        """
        DEFAULT_AMMO_FLOW = [
            QuantAlgo.W4A16_AWQ, QuantAlgo.FP8, QuantAlgo.W8A8_SQ_PER_CHANNEL,
            QuantAlgo.W4A8_AWQ
        ]
        use_ammo_quantization = quant_config.quant_algo in DEFAULT_AMMO_FLOW
        if use_ammo_quantization:
            super().quantize(hf_model_dir,
                             output_dir,
                             quant_config,
                             dtype=dtype,
                             mapping=mapping,
                             calib_batches=calib_batches,
                             calib_batch_size=calib_batch_size,
                             random_seed=random_seed,
                             tokenizer_max_seq_length=tokenizer_max_seq_length)
        else:
            # non-ammo, the legacy TRT-LLM native quantization algorithm:
            # sq, int4/int8 weights only, int8 kv cache
            NATIVE_QUANT_FLOW = [QuantAlgo.W4A16, QuantAlgo.W8A16, None
                                 ] + W8A8_SQ_PLUGIN_LIST
            is_valid_native_quant = (quant_config.quant_algo in NATIVE_QUANT_FLOW) and \
                (quant_config.kv_cache_quant_algo in [QuantAlgo.INT8, None])
            assert quant_config.quant_algo is not None or quant_config.kv_cache_quant_algo is not None, \
                "There is no point to call the quantize function if both quant_algo and kv_cache_quant_algo is None"
            assert is_valid_native_quant, f"Internal error: shall call AMMO for this quantization {quant_config}"

            from . import convert
            convert.quantize(
                dtype,
                hf_model_dir,
                output_dir,
                mapping,
                quant_config,
                override_fields=kwargs.get('override_fields', {}),
                dataset_cache_dir=kwargs.get('dataset_cache_dir', None),
            )

    def use_lora(self, lora_config: LoraBuildConfig):
        """Apply a LoRA build config to this model.

        Delegates to the module-level ``use_lora`` imported from
        ``lora_manager``; inside the method body that name resolves to the
        global, not to this method.
        """
        use_lora(self, lora_config)