# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
from pathlib import Path
from typing import Optional

from transformers import AutoConfig

from tensorrt_llm import profiler
from tensorrt_llm._utils import pad_vocab_size
from tensorrt_llm.functional import RotaryScalingType, Tensor, recv, send
from tensorrt_llm.layers import (MOE, Attention, AttentionMaskType,
                                 ColumnLinear, Embedding, FusedGatedMLP,
                                 GatedMLP, MoeConfig, PositionEmbeddingType,
                                 PromptTuningEmbedding, RmsNorm)
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import (DecoderLayerList,
                                                DecoderModelForCausalLM)
from tensorrt_llm.module import Module
from tensorrt_llm.plugin import init_all_reduce_helper
from tensorrt_llm.quantization import QuantMode
from tensorrt_llm.runtime.lora_manager import LoraConfig
from tensorrt_llm.top_model_mixin import TopModelMixin

# load_from_awq_llama is needed by _quantize below for the int4_awq path
from .weight import (load_from_awq_llama, load_from_fp8_llama,
                     load_from_hf_llama)


class GemmaDecoderLayer(Module):

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config

        self.input_layernorm = RmsNorm(normalized_shape=config.hidden_size,
                                       eps=config.norm_epsilon,
                                       dtype=config.dtype)

        self.attention = Attention(
            layer_idx=layer_idx,
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            attention_head_size=config.head_size,
            max_position_embeddings=config.max_position_embeddings,
            dtype=config.dtype,
            attention_mask_type=AttentionMaskType.causal,
            bias=config.attn_bias,
            position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
            rotary_embedding_base=config.rotary_base,
            rotary_embedding_scaling=config.rotary_scaling,
            tp_group=config.mapping.tp_group,
            tp_size=config.mapping.tp_size,
            quant_mode=config.quant_mode,
            enable_pos_shift=config.enable_pos_shift,
            dense_context_fmha=config.dense_context_fmha,
            # max_lora_rank=config.max_lora_rank,
        )

        mlp_hidden_size = config.hidden_size * 4 if config.intermediate_size is None else config.intermediate_size

        ClsMLP = GatedMLP
        mlp_kwargs = {}
        if config.moe_num_experts > 1:
            ClsMLP = MOE
            mlp_kwargs = {
                "moe_config":
                MoeConfig(
                    config.moe_num_experts,
                    config.moe_top_k,
                    config.moe_tp_mode,
                    config.moe_normalization_mode,
                ),
                "tp_rank": config.mapping.tp_rank,
            }
        elif config.use_fused_mlp:
            ClsMLP = FusedGatedMLP

        self.mlp = ClsMLP(
            hidden_size=config.hidden_size,
            ffn_hidden_size=mlp_hidden_size,
            hidden_act=config.hidden_act,
            dtype=config.dtype,
            bias=config.mlp_bias,
            tp_group=config.mapping.tp_group,
            tp_size=config.mapping.tp_size,
            quant_mode=config.quant_mode,
            # max_lora_rank=config.max_lora_rank,
            **mlp_kwargs)

        self.post_layernorm = RmsNorm(normalized_shape=config.hidden_size,
                                      eps=config.norm_epsilon,
                                      dtype=config.dtype)
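    # A sketch of how the MLP class and width resolve above, for a
    # hypothetical Gemma-2B-like config (values are illustrative, not from a
    # real checkpoint): with hidden_size=2048, intermediate_size=16384,
    # moe_num_experts=0 and use_fused_mlp=False, the dispatch picks
    # ClsMLP = GatedMLP with ffn_hidden_size=16384; if intermediate_size
    # were None, the fallback width would be 4 * 2048 = 8192.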
    def forward(self,
                hidden_states,
                attention_mask=None,
                medusa_packed_mask=None,  # For Medusa support
                medusa_position_offsets=None,
                use_cache=False,
                kv_cache_params=None,
                attention_params=None,
                lora_layer_params=None):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        attention_output = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            medusa_packed_mask=medusa_packed_mask,  # For Medusa support
            medusa_position_offsets=medusa_position_offsets,
            use_cache=use_cache,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params,
            lora_layer_params=lora_layer_params)

        if use_cache:
            attention_output, presents = attention_output

        hidden_states = residual + attention_output

        residual = hidden_states
        hidden_states = self.post_layernorm(hidden_states)

        hidden_states = self.mlp(hidden_states,
                                 lora_layer_params=lora_layer_params)

        hidden_states = residual + hidden_states
        if use_cache:
            return (hidden_states, presents)
        return hidden_states


class GemmaModel(Module):

    def __init__(self, config) -> None:
        super().__init__()
        init_all_reduce_helper()

        self.mapping = config.mapping
        self.use_prompt_tuning = config.use_prompt_tuning
        EmbeddingCls = PromptTuningEmbedding if config.use_prompt_tuning else Embedding
        if self.mapping.is_first_pp_rank():
            self.vocab_embedding = EmbeddingCls(
                num_embeddings=config.vocab_size,
                embedding_dim=config.hidden_size,
                dtype=config.dtype,
                tp_size=self.mapping.tp_size
                if config.use_parallel_embedding else 1,
                tp_group=self.mapping.tp_group
                if config.use_parallel_embedding else None,
                sharding_dim=config.embedding_sharding_dim,
                tp_rank=self.mapping.tp_rank,
            )

        self.layers = DecoderLayerList(GemmaDecoderLayer, config)

        if self.mapping.is_last_pp_rank():
            self.ln_f = RmsNorm(normalized_shape=config.hidden_size,
                                eps=config.norm_epsilon,
                                dtype=config.dtype)

    def forward(
            self,
            input_ids,
            position_ids=None,
            use_cache=False,
            attention_mask=None,
            medusa_position_offsets=None,  # For Medusa support
            medusa_packed_mask=None,  # For Medusa support
            kv_cache_params=None,
            attention_params=None,
            hidden_states=None,
            prompt_embedding_table: Optional[Tensor] = None,
            prompt_tasks: Optional[Tensor] = None,
            prompt_vocab_size: Optional[Tensor] = None,
            lora_params=None):

        kv_cache_params.fill_none_tensor_list(len(self.layers))

        if use_cache:
            presents = []

        ptuning_args = []
        # if self.use_prompt_tuning:
        #     ptuning_args = [
        #         prompt_embedding_table, prompt_tasks, prompt_vocab_size
        #     ]

        if self.mapping.is_first_pp_rank():
            hidden_states = self.vocab_embedding(input_ids, *ptuning_args)
        else:
            hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())

        hidden_states = self.layers.forward(
            hidden_states,
            use_cache=use_cache,
            attention_mask=attention_mask,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params,
            # all_reduce_workspace=all_reduce_workspace,
            lora_params=lora_params,
            # medusa_position_offsets=medusa_position_offsets,
            # medusa_packed_mask=medusa_packed_mask,
        )

        if use_cache:
            hidden_states, presents = hidden_states

        if self.mapping.is_last_pp_rank():
            hidden_states = self.ln_f(hidden_states)
        else:
            hidden_states = send(hidden_states, self.mapping.next_pp_rank())

        if use_cache:
            return (hidden_states, tuple(presents))
        return hidden_states
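# pad_vocab_size rounds the vocabulary up to a multiple of tp_size so the
# lm_head ColumnLinear below splits evenly across tensor-parallel ranks.
# A quick illustration, assuming the usual ceil-to-multiple implementation
# (values are hypothetical):
#   pad_vocab_size(32000, 1) -> 32000
#   pad_vocab_size(32003, 2) -> 32004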
class GemmaForCausalLM(DecoderModelForCausalLM, TopModelMixin):

    def __init__(self, config):
        self.check_config(config)
        transformer = GemmaModel(config)
        vocab_size_padded = pad_vocab_size(config.vocab_size,
                                           config.mapping.tp_size)

        if config.mapping.is_last_pp_rank():
            lm_head = ColumnLinear(config.hidden_size,
                                   vocab_size_padded,
                                   bias=False,
                                   dtype=config.dtype,
                                   tp_group=config.mapping.tp_group,
                                   tp_size=config.mapping.tp_size,
                                   gather_output=True)
        else:
            lm_head = None
        self.quant_mode = config.quant_mode
        self.mapping = config.mapping
        super().__init__(config, transformer, lm_head)

    @classmethod
    def from_hugging_face(cls,
                          hf_model_dir,
                          dtype='float16',
                          mapping: Optional[Mapping] = None,
                          quant_mode: Optional[QuantMode] = None,
                          **kwargs):
        import transformers
        from transformers import LlamaConfig

        from ...models.modeling_utils import PretrainedConfig

        cfg = LlamaConfig.from_pretrained(hf_model_dir)

        # Fall back to MHA when the checkpoint does not define KV heads
        num_kv_heads = cfg.num_key_value_heads if hasattr(cfg, "num_key_value_heads") \
            else cfg.num_attention_heads
        if mapping is None:
            mapping = Mapping()
        if quant_mode is None:
            quant_mode = QuantMode(0)

        cfg.mapping = mapping
        cfg.dtype = dtype
        cfg.quant_mode = quant_mode

        moe_config = kwargs.get("moe_config", MoeConfig())

        cfg.norm_epsilon = cfg.rms_norm_eps

        config = {
            'architecture': cfg.architectures[0],
            'dtype': cfg.dtype,
            'logits_dtype': 'float32',
            'num_hidden_layers': cfg.num_hidden_layers,
            'num_attention_heads': cfg.num_attention_heads,
            'hidden_size': cfg.hidden_size,
            'intermediate_size': cfg.intermediate_size,
            'num_key_value_heads': num_kv_heads,
            'vocab_size': cfg.vocab_size,
            'position_embedding_type': 'rope_gpt_neox',
            'max_position_embeddings': cfg.max_position_embeddings,
            'hidden_act': cfg.hidden_act,
            'rotary_base': getattr(cfg, 'rotary_base', 10000.0),
            'rotary_scaling': getattr(cfg, 'rotary_scaling', None),
            'norm_epsilon': cfg.rms_norm_eps,
            'quantization': quant_mode.to_dict(),
            'mapping': {
                'world_size': mapping.world_size,
                'tp_size': mapping.tp_size,
            },
            'use_parallel_embedding': kwargs.get("use_parallel_embedding",
                                                 False),
            'embedding_sharding_dim': kwargs.get("embedding_sharding_dim", 0),
            'use_prompt_tuning': kwargs.get("use_prompt_tuning", False),
            'moe_num_experts': moe_config.num_experts,
            'moe_top_k': moe_config.top_k,
            'moe_tp_mode': moe_config.tp_mode,
            'moe_normalization_mode': moe_config.normalization_mode,
            'use_fused_mlp': kwargs.get("use_fused_mlp", False),
            'enable_pos_shift': kwargs.get("enable_pos_shift", False),
            'dense_context_fmha': kwargs.get("dense_context_fmha", False),
        }
        if quant_mode.is_int4_weight_only_per_group():
            config['quantization'].update({
                'zero': False,
                'pre_quant_scale': True,
                'exclude_modules': [],
            })
        tllm_llama = GemmaForCausalLM(PretrainedConfig.from_dict(config))
        q_weights = {}
        if quant_mode.has_any_quant():
            q_weights = tllm_llama._quantize(hf_model_dir, dtype, cfg,
                                             **kwargs)

        # For debugging purposes, skip weight loading to be faster
        if kwargs.get("skip_loading_weights", False):
            return tllm_llama

        # TODO: support mixtral

        # Weights are already loaded in _quantize for int4 weight only
        if not quant_mode.is_int4_weight_only_per_group():
            hf_model = transformers.LlamaForCausalLM
            profiler.start("Loading weights from HF")
            hf_llama = hf_model.from_pretrained(
                hf_model_dir,
                device_map={
                    "model": "cpu",
                    "lm_head": "cpu",
                    "embed_tokens": "cpu",
                    "layers": "cpu",
                    "norm": "cpu",
                },  # Load to CPU memory
                torch_dtype='auto',
            )

            weights = load_from_hf_llama(
                tllm_llama,
                hf_llama,
                mapping=mapping,
                dtype=dtype,
                # TODO: these shall be outside from_hugging_face too.
                use_gemm_woq_plugin=kwargs.get("use_gemm_woq_plugin", False),
                lora_config=kwargs.get("lora_config", LoraConfig()),
            )
            profiler.stop("Loading weights from HF")
            del hf_llama
            weights.update(q_weights)
            tllm_llama.load(weights)
        else:
            tllm_llama.load(q_weights)
        return tllm_llama
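    # Typical usage of the factory above (a sketch; the checkpoint path and
    # Mapping values are illustrative):
    #   model = GemmaForCausalLM.from_hugging_face(
    #       "/path/to/hf/gemma", dtype="float16",
    #       mapping=Mapping(world_size=2, rank=0, tp_size=2))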
    def _quantize(self, hf_model_dir, dtype, cfg, **kwargs):
        '''Given the quant_mode set in the Module object, read the given HF
        model, call AMMO to generate the quantization scales, and set the
        scales back on the module parameters.
        '''
        # Use a self-destructing temporary path if
        # kwargs["quantization_cache_dir"] is not specified; sometimes the
        # quantization checkpoint path needs to be kept for debugging.
        quantized_temp_dir = tempfile.TemporaryDirectory("llama-quantized")
        quantized_checkpoint_path = kwargs.get("quantization_cache_dir",
                                               quantized_temp_dir.name)
        quantize_lm_head = kwargs.get("quantize_lm_head", False)
        quant_mode = cfg.quant_mode
        ammo_qformat = None
        calib_size = None
        if quant_mode.has_fp8_qdq() or quant_mode.has_fp8_kv_cache():
            ammo_qformat = 'fp8'
            calib_size = 512
        # TODO: how to distinguish from quant_mode about int4_awq or int4_gptq?
        elif quant_mode.is_int4_weight_only_per_group():
            ammo_qformat = 'int4_awq'
            calib_size = 32
        assert ammo_qformat is not None

        # Local import to avoid a pytest issue when importing the AMMO and
        # transformers libs
        from .quantize import quantize_llama_and_export
        quantize_llama_and_export(hf_model_dir,
                                  quantized_checkpoint_path,
                                  ammo_qformat,
                                  dtype,
                                  calib_size=calib_size,
                                  quantize_lm_head=quantize_lm_head)

        ckpt = Path(quantized_checkpoint_path) / "llama_tp1_rank0.npz"
        assert ckpt.exists(), f"The expected checkpoint path {ckpt} does not exist; " \
            "quantization likely failed. Please check the error logs."
        hf_config = AutoConfig.from_pretrained(hf_model_dir,
                                               trust_remote_code=True)

        if ammo_qformat == 'fp8':
            return load_from_fp8_llama(
                str(ckpt),
                hf_config,
                cfg.mapping,
                fp8_kv_cache=quant_mode.has_fp8_kv_cache())
        else:
            return load_from_awq_llama(str(ckpt),
                                       hf_config,
                                       cfg.mapping,
                                       dtype=dtype)

    # Llama-specific setters. The user still has a chance to change module
    # attributes after the from_hugging_face factory method has created the
    # model, for attributes that are not included in the Hugging Face
    # checkpoint.
    def rotary_base(self, val):
        for decoder in self.layers:
            decoder.attention.rotary_embedding_base = val
        return self

    def rotary_scaling(self, scaling_type, factor):
        # TODO: what if there are other behaviors triggered by these changes?
        # These assignments should be implemented as setters of the
        # Attention module.
        assert scaling_type in ("linear", "dynamic"), f"Got {scaling_type}"
        assert factor > 1.0, f"Got {factor}"
        for decoder in self.layers:
            decoder.attention.rotary_embedding_scale_type = RotaryScalingType.linear \
                if scaling_type == "linear" else RotaryScalingType.dynamic
            decoder.attention.rotary_embedding_scale = factor
        return self

    def default_plugin_config(self, **kwargs):
        plugin_config = super().default_plugin_config(**kwargs)
        if self.quant_mode.is_int4_weight_only_per_group():
            plugin_config.set_weight_only_groupwise_quant_matmul_plugin()
        return plugin_config

    def check_config(self, config):
        config.set_if_not_exist('use_parallel_embedding', False)
        config.set_if_not_exist('embedding_sharding_dim', 0)
        config.set_if_not_exist('mlp_bias', False)
        config.set_if_not_exist('attn_bias', False)
        config.set_if_not_exist('rotary_base', 10000.0)
        config.set_if_not_exist('rotary_scaling', None)
        config.set_if_not_exist('enable_pos_shift', False)
        config.set_if_not_exist('dense_context_fmha', False)
        config.set_if_not_exist('use_fused_mlp', False)
        config.set_if_not_exist('moe_num_experts', 0)
        config.set_if_not_exist('moe_top_k', 0)
        config.set_if_not_exist('moe_tp_mode',
                                MoeConfig.ParallelismMode.TENSOR_PARALLEL)
        config.set_if_not_exist(
            'moe_normalization_mode',
            MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE)
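# The rotary setters above return self, so they chain fluently. A usage
# sketch (the checkpoint path and values are illustrative):
#   model = GemmaForCausalLM.from_hugging_face("/path/to/hf/gemma")
#   model.rotary_base(1000000.0).rotary_scaling("linear", 2.0)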