# Source code for tensorrt_llm.models.llama.model

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
from pathlib import Path
from typing import Optional

from transformers import AutoConfig

from tensorrt_llm.models.llama.weight import (load_from_awq_llama,
                                              load_from_fp8_llama)

from ... import profiler
from ..._utils import pad_vocab_size
from ...functional import RotaryScalingType, Tensor, recv, send
from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear,
                       Embedding, GatedMLP, MoeConfig, PositionEmbeddingType,
                       PromptTuningEmbedding, RmsNorm)
from ...mapping import Mapping
from ...module import Module
from ...plugin import init_all_reduce_helper
from ...quantization import QuantMode
from ...runtime.lora_manager import LoraConfig
from ...top_model_mixin import TopModelMixin
from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
                              PretrainedConfig)
from .weight import load_from_hf_llama


class LLaMADecoderLayer(Module):
    """Single LLaMA decoder block.

    The block applies ``input_layernorm`` -> attention -> residual add, then
    ``post_layernorm`` -> MLP -> residual add. The MLP is a ``GatedMLP``
    unless ``config.moe_num_experts`` is greater than 1, in which case an
    ``MOE`` layer is built instead.
    """

    def __init__(self, config: PretrainedConfig, layer_idx: int):
        """Build the norms, attention and MLP sub-modules from ``config``.

        Args:
            config: Model configuration; the attributes read here include the
                hidden/intermediate sizes, head counts, rotary settings,
                quantization mode, MoE settings and the parallel ``mapping``.
            layer_idx: Index of this layer; stored on the instance.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config

        mapping = config.mapping
        # Both RMS norms of the block are constructed with identical arguments.
        norm_kwargs = dict(normalized_shape=config.hidden_size,
                           eps=config.norm_epsilon,
                           dtype=config.dtype)

        self.input_layernorm = RmsNorm(**norm_kwargs)

        self.attention = Attention(
            config.hidden_size,
            config.num_attention_heads,
            config.num_key_value_heads,
            max_position_embeddings=config.max_position_embeddings,
            dtype=config.dtype,
            attention_mask_type=AttentionMaskType.causal,
            bias=config.attn_bias,
            position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
            rotary_embedding_base=config.rotary_base,
            rotary_embedding_scaling=config.rotary_scaling,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            tp_rank=mapping.tp_rank,
            quant_mode=config.quant_mode,
            enable_pos_shift=config.enable_pos_shift,
            dense_context_fmha=config.dense_context_fmha,
            max_lora_rank=config.max_lora_rank)

        # Fall back to 4x the hidden size when no intermediate size is given.
        if config.intermediate_size is None:
            ffn_hidden_size = config.hidden_size * 4
        else:
            ffn_hidden_size = config.intermediate_size

        mlp_args = dict(hidden_size=config.hidden_size,
                        ffn_hidden_size=ffn_hidden_size,
                        hidden_act=config.hidden_act,
                        dtype=config.dtype,
                        bias=config.mlp_bias,
                        tp_group=mapping.tp_group,
                        tp_size=mapping.tp_size,
                        quant_mode=config.quant_mode,
                        max_lora_rank=config.max_lora_rank)

        if config.moe_num_experts > 1:
            # Mixture-of-experts variant needs its own config and the TP rank.
            mlp_cls = MOE
            mlp_args["moe_config"] = MoeConfig(
                config.moe_num_experts,
                config.moe_top_k,
                config.moe_tp_mode,
                config.moe_normalization_mode,
            )
            mlp_args["tp_rank"] = mapping.tp_rank
        else:
            mlp_cls = GatedMLP

        self.mlp = mlp_cls(**mlp_args)
        self.post_layernorm = RmsNorm(**norm_kwargs)

    def forward(
            self,
            hidden_states,
            attention_mask=None,
            medusa_packed_mask=None,  # For Medusa support
            medusa_position_offsets=None,
            use_cache=False,
            kv_cache_params=None,
            attention_params=None,
            lora_layer_params=None):
        """Run the block on ``hidden_states``.

        Returns:
            The block output, or a ``(output, presents)`` tuple when
            ``use_cache`` is truthy, where ``presents`` is the second element
            returned by the attention layer.
        """
        normed_input = self.input_layernorm(hidden_states)

        attention_result = self.attention(
            normed_input,
            attention_mask=attention_mask,
            medusa_packed_mask=medusa_packed_mask,  # For Medusa support
            medusa_position_offsets=medusa_position_offsets,
            use_cache=use_cache,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params,
            lora_layer_params=lora_layer_params)

        # With caching enabled the attention layer returns a pair.
        if use_cache:
            attention_output, presents = attention_result
        else:
            attention_output = attention_result

        # First residual connection (around attention).
        post_attention = hidden_states + attention_output

        mlp_output = self.mlp(self.post_layernorm(post_attention),
                              lora_layer_params=lora_layer_params)

        # Second residual connection (around the MLP).
        block_output = post_attention + mlp_output
        if use_cache:
            return (block_output, presents)
        return block_output

class LLaMAModel(Module):
    """LLaMA decoder stack: optional embedding, decoder layers, optional final norm.

    The vocabulary embedding is only created when ``mapping.is_first_pp_rank()``
    is true and the final ``ln_f`` norm only when ``mapping.is_last_pp_rank()``
    is true; the other ranks exchange hidden states through ``recv``/``send``.
    """

    def __init__(self, config: PretrainedConfig) -> None:
        """Build the sub-modules this rank needs from ``config``."""
        super().__init__()
        init_all_reduce_helper()

        self.mapping = config.mapping
        self.use_prompt_tuning = config.use_prompt_tuning
        EmbeddingCls = PromptTuningEmbedding if config.use_prompt_tuning else Embedding
        if self.mapping.is_first_pp_rank():
            # The embedding is only sharded when parallel embedding is
            # requested; otherwise it is built with tp_size=1 and no group.
            self.vocab_embedding = EmbeddingCls(
                num_embeddings=config.vocab_size,
                embedding_dim=config.hidden_size,
                dtype=config.dtype,
                tp_size=self.mapping.tp_size
                if config.use_parallel_embedding else 1,
                tp_group=self.mapping.tp_group
                if config.use_parallel_embedding else None,
                sharding_dim=config.embedding_sharding_dim,
                tp_rank=self.mapping.tp_rank,
            )

        self.layers = DecoderLayerList(LLaMADecoderLayer, config)

        if self.mapping.is_last_pp_rank():
            self.ln_f = RmsNorm(normalized_shape=config.hidden_size,
                                eps=config.norm_epsilon,
                                dtype=config.dtype)

    def forward(
            self,
            input_ids,
            position_ids=None,
            use_cache=False,
            attention_mask=None,
            medusa_position_offsets=None,  # For Medusa support
            medusa_packed_mask=None,  # For Medusa support
            kv_cache_params=None,
            attention_params=None,
            hidden_states=None,
            prompt_embedding_table: Optional[Tensor] = None,
            prompt_tasks: Optional[Tensor] = None,
            prompt_vocab_size: Optional[Tensor] = None,
            lora_params=None):
        """Run the decoder stack.

        Args:
            input_ids: Token ids, embedded on the first pipeline rank.
            position_ids: Accepted for interface compatibility; not read here.
            use_cache: When truthy, the result is a ``(hidden_states,
                presents)`` tuple instead of just the hidden states.
            kv_cache_params: Must not be ``None`` despite the default; it is
                dereferenced unconditionally below.
            hidden_states: On non-first pipeline ranks, the tensor handed to
                ``recv`` together with the previous rank id.
            prompt_embedding_table, prompt_tasks, prompt_vocab_size: Forwarded
                to the embedding only when prompt tuning is enabled.
            attention_mask, medusa_position_offsets, medusa_packed_mask,
            attention_params, lora_params: Forwarded to the decoder layers.

        Returns:
            ``hidden_states``, or ``(hidden_states, tuple(presents))`` when
            ``use_cache`` is truthy.
        """
        kv_cache_params.fill_none_tensor_list(len(self.layers))

        ptuning_args = [
            prompt_embedding_table, prompt_tasks, prompt_vocab_size
        ] if self.use_prompt_tuning else []

        if self.mapping.is_first_pp_rank():
            hidden_states = self.vocab_embedding(input_ids, *ptuning_args)
        else:
            hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())

        hidden_states = self.layers.forward(
            hidden_states,
            use_cache=use_cache,
            attention_mask=attention_mask,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params,
            lora_params=lora_params,
            medusa_position_offsets=medusa_position_offsets,
            medusa_packed_mask=medusa_packed_mask)

        # With caching enabled the layer list returns a pair.
        if use_cache:
            hidden_states, presents = hidden_states

        if self.mapping.is_last_pp_rank():
            hidden_states = self.ln_f(hidden_states)
        else:
            hidden_states = send(hidden_states, self.mapping.next_pp_rank())

        if use_cache:
            return (hidden_states, tuple(presents))
        return hidden_states
class LLaMAForCausalLM(DecoderModelForCausalLM, TopModelMixin):
    """LLaMA model with a language-modeling head.

    Wraps a ``LLaMAModel`` and, on the last pipeline rank, a ``ColumnLinear``
    ``lm_head`` whose output size is the vocabulary size padded via
    ``pad_vocab_size``.
    """

    def __init__(self, config: PretrainedConfig):
        """Fill in config defaults, then build the transformer and LM head."""
        self.check_config(config)
        transformer = LLaMAModel(config)
        vocab_size_padded = pad_vocab_size(config.vocab_size,
                                           config.mapping.tp_size)
        if config.mapping.is_last_pp_rank():
            lm_head = ColumnLinear(config.hidden_size,
                                   vocab_size_padded,
                                   bias=False,
                                   dtype=config.dtype,
                                   tp_group=config.mapping.tp_group,
                                   tp_size=config.mapping.tp_size,
                                   gather_output=True)
        else:
            lm_head = None
        self.quant_mode = config.quant_mode
        self.mapping = config.mapping
        super().__init__(config, transformer, lm_head)

    def check_config(self, config):
        """Set LLaMA-specific defaults on ``config`` for any missing fields."""
        config.set_if_not_exist('mlp_bias', False)
        config.set_if_not_exist('attn_bias', False)
        config.set_if_not_exist('rotary_base', 10000.0)
        config.set_if_not_exist('rotary_scaling', None)
        config.set_if_not_exist('enable_pos_shift', False)
        config.set_if_not_exist('dense_context_fmha', False)
        config.set_if_not_exist('moe_num_experts', 0)
        config.set_if_not_exist('moe_top_k', 0)
        config.set_if_not_exist('moe_tp_mode',
                                MoeConfig.ParallelismMode.TENSOR_PARALLEL)
        config.set_if_not_exist(
            'moe_normalization_mode',
            MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE)

    @classmethod
    def from_hugging_face(cls,
                          hf_model_dir,
                          dtype='float16',
                          mapping: Optional[Mapping] = None,
                          quant_mode: Optional[QuantMode] = None,
                          **kwargs):
        """Create a ``LLaMAForCausalLM`` from a Hugging Face checkpoint directory.

        Args:
            hf_model_dir: Directory passed to ``LlamaConfig.from_pretrained``
                and ``LlamaForCausalLM.from_pretrained``.
            dtype: Data type recorded in the model config and used when
                converting weights.
            mapping: Parallel mapping; defaults to ``Mapping()``.
            quant_mode: Quantization mode; defaults to ``QuantMode(0)``.
            **kwargs: Optional settings read here: ``moe_config``,
                ``use_parallel_embedding``, ``embedding_sharding_dim``,
                ``use_prompt_tuning``, ``use_fused_mlp``, ``enable_pos_shift``,
                ``dense_context_fmha``, ``skip_loading_weights``,
                ``use_gemm_woq_plugin`` and ``lora_config``. All kwargs are
                also forwarded to ``_quantize``.

        Returns:
            The constructed model, with weights loaded unless
            ``skip_loading_weights`` is truthy.

        Raises:
            ValueError: If ``quant_mode`` is neither int4 weight-only
                per-group, FP8 (qdq and kv cache), nor ``QuantMode(0)``.
        """
        import transformers
        from transformers import LlamaConfig

        from ...models.modeling_utils import PretrainedConfig

        cfg = LlamaConfig.from_pretrained(hf_model_dir)

        # Configs without ``num_key_value_heads`` fall back to the number of
        # attention heads.
        num_kv_heads = cfg.num_key_value_heads if hasattr(cfg, "num_key_value_heads") \
            else cfg.num_attention_heads
        if mapping is None:
            mapping = Mapping()
        if quant_mode is None:
            quant_mode = QuantMode(0)

        # ``cfg`` is later handed to ``_quantize``, which reads these fields.
        cfg.mapping = mapping
        cfg.dtype = dtype
        cfg.quant_mode = quant_mode
        moe_config = kwargs.get("moe_config", MoeConfig())
        cfg.norm_epsilon = cfg.rms_norm_eps

        config = {
            'architecture': cfg.architectures[0],
            'dtype': cfg.dtype,
            'logits_dtype': 'float32',
            'num_hidden_layers': cfg.num_hidden_layers,
            'num_attention_heads': cfg.num_attention_heads,
            'hidden_size': cfg.hidden_size,
            'intermediate_size': cfg.intermediate_size,
            # Use the value computed above so the fallback actually applies.
            'num_key_value_heads': num_kv_heads,
            'vocab_size': cfg.vocab_size,
            'position_embedding_type': 'rope_gpt_neox',
            'max_position_embeddings': cfg.max_position_embeddings,
            'hidden_act': cfg.hidden_act,
            'rotary_base': getattr(cfg, 'rotary_base', 10000.0),
            'rotary_scaling': getattr(cfg, 'rotary_scaling', None),
            'norm_epsilon': cfg.rms_norm_eps,
            'quantization': {
                'group_size': 128,
            },
            # NOTE(review): tp_size is set to the world size and no pp_size is
            # passed, which looks wrong for pipeline-parallel mappings --
            # confirm against PretrainedConfig.from_dict before changing.
            'mapping': {
                'world_size': mapping.world_size,
                'tp_size': mapping.world_size,
            },
            'use_parallel_embedding': kwargs.get("use_parallel_embedding",
                                                 False),
            'embedding_sharding_dim': kwargs.get("embedding_sharding_dim", 0),
            'use_prompt_tuning': kwargs.get("use_prompt_tuning", False),
            'moe_num_experts': moe_config.num_experts,
            'moe_top_k': moe_config.top_k,
            'moe_tp_mode': moe_config.tp_mode,
            'moe_normalization_mode': moe_config.normalization_mode,
            'use_fused_mlp': kwargs.get("use_fused_mlp", False),
            'enable_pos_shift': kwargs.get("enable_pos_shift", False),
            'dense_context_fmha': kwargs.get("dense_context_fmha", False),
        }
        if quant_mode.is_int4_weight_only_per_group():
            # Weight-only int4 AWQ keeps 16-bit activations, hence W4A16.
            config['quantization'].update({
                'quant_algo': 'W4A16_AWQ',
                'has_zero_point': False,
                'pre_quant_scale': True,
                'exclude_modules': [],
            })
        elif quant_mode.has_fp8_qdq() and quant_mode.has_fp8_kv_cache():
            config['quantization'].update({
                'quant_algo': 'FP8',
                'kv_cache_quant_algo': 'FP8'
            })
        elif quant_mode != QuantMode(0):
            raise ValueError(f"Unsupported quantization mode: {quant_mode}")

        tllm_llama = LLaMAForCausalLM(PretrainedConfig.from_dict(config))
        q_weights = {}
        if quant_mode.has_any_quant():
            q_weights = tllm_llama._quantize(hf_model_dir, dtype, cfg, **kwargs)

        # For debug purpose, skip weights loading to be faster
        if kwargs.get("skip_loading_weights", False):
            return tllm_llama

        # TODO: support mixtral

        # weights already loaded in _quantize for int4 weight only
        if not quant_mode.is_int4_weight_only_per_group():
            hf_model = transformers.LlamaForCausalLM
            profiler.start("Loading weights from HF")
            hf_llama = hf_model.from_pretrained(
                hf_model_dir,
                device_map={
                    "model": "cpu",
                    "lm_head": "cpu",
                    "embed_tokens": "cpu",
                    "layers": "cpu",
                    "norm": "cpu",
                },  # Load to CPU memory
                torch_dtype='auto',
            )

            weights = load_from_hf_llama(
                tllm_llama,
                hf_llama,
                mapping=mapping,
                dtype=dtype,
                # TODO: these shall be outside from_hugging_face too.
                use_gemm_woq_plugin=kwargs.get("use_gemm_woq_plugin", False),
                lora_config=kwargs.get("lora_config", LoraConfig()),
            )
            profiler.stop("Loading weights from HF")
            del hf_llama
            # Quantization scales are merged over the converted HF weights.
            weights.update(q_weights)
            tllm_llama.load(weights)
        else:
            tllm_llama.load(q_weights)
        return tllm_llama

    def _quantize(self, hf_model_dir, dtype, cfg, **kwargs):
        """Quantize the HF model and return the resulting weights.

        Calls ``quantize_llama_and_export`` on ``hf_model_dir`` and loads the
        exported ``llama_tp1_rank0.npz`` checkpoint with the loader matching
        the quantization format.

        Args:
            hf_model_dir: Hugging Face model directory to quantize.
            dtype: Data type forwarded to the exporter and the AWQ loader.
            cfg: Config object providing ``quant_mode`` and ``mapping``.
            **kwargs: ``quantization_cache_dir`` selects where the quantized
                checkpoint is written (kept afterwards, e.g. for debugging);
                when absent or ``None`` a self-destructing temporary directory
                is used. ``quantize_lm_head`` is forwarded to the exporter.

        Returns:
            The weights returned by ``load_from_fp8_llama`` or
            ``load_from_awq_llama``.
        """
        quantize_lm_head = kwargs.get("quantize_lm_head", False)
        quant_mode = cfg.quant_mode

        ammo_qformat = None
        calib_size = None
        if quant_mode.has_fp8_qdq() or quant_mode.has_fp8_kv_cache():
            ammo_qformat = 'fp8'
            calib_size = 512
        # TODO: how to distinguish from quant_mode about int4_awq or int4_gptq?
        elif quant_mode.is_int4_weight_only_per_group():
            ammo_qformat = 'int4_awq'
            calib_size = 32
        assert ammo_qformat is not None

        # local import to avoid pytest issue when importing AMMO and transformers lib
        from .quantize import quantize_llama_and_export

        # Only create a temporary directory when no cache directory is given,
        # and clean it up explicitly rather than relying on the finalizer.
        quantized_temp_dir = None
        quantized_checkpoint_path = kwargs.get("quantization_cache_dir")
        if quantized_checkpoint_path is None:
            quantized_temp_dir = tempfile.TemporaryDirectory(
                suffix="llama-quantized")
            quantized_checkpoint_path = quantized_temp_dir.name

        try:
            quantize_llama_and_export(hf_model_dir,
                                      quantized_checkpoint_path,
                                      ammo_qformat,
                                      dtype,
                                      calib_size=calib_size,
                                      quantize_lm_head=quantize_lm_head)

            ckpt = Path(quantized_checkpoint_path) / "llama_tp1_rank0.npz"
            assert ckpt.exists(), (
                f"The expecting checkpoint path {ckpt} does not exist, "
                "it's likely quantization failed, pls check error logs")

            hf_config = AutoConfig.from_pretrained(hf_model_dir,
                                                   trust_remote_code=True)
            if ammo_qformat == 'fp8':
                return load_from_fp8_llama(
                    str(ckpt),
                    hf_config.num_hidden_layers,
                    cfg.mapping,
                    fp8_kv_cache=quant_mode.has_fp8_kv_cache())
            return load_from_awq_llama(str(ckpt),
                                       hf_config.num_hidden_layers,
                                       hf_config.vocab_size,
                                       cfg.mapping,
                                       dtype=dtype)
        finally:
            if quantized_temp_dir is not None:
                quantized_temp_dir.cleanup()

    # LLaMA-specific setters: they give the user a chance to change module
    # attributes after the from_hugging_face factory method has created the
    # model, for attributes that are not included in the Hugging Face checkpoint.

    def rotary_base(self, val):
        """Set ``rotary_embedding_base`` on every decoder's attention; returns ``self``."""
        # NOTE(review): iterates ``self.layers``; the layer list visible in
        # this file lives on the transformer passed to the base class --
        # confirm the base class exposes ``layers`` on this object.
        for decoder in self.layers:
            decoder.attention.rotary_embedding_base = val
        return self

    def rotary_scaling(self, scaling_type, factor):
        """Set the rotary scaling type and factor on every decoder's attention.

        Args:
            scaling_type: ``"linear"`` or ``"dynamic"``.
            factor: Scaling factor; must be greater than 1.0.

        Returns:
            ``self``, so calls can be chained.
        """
        # TODO: what if there are some other behaviors triggered by the these changes?
        # should implement these assignment as setters of the Attention Module
        assert scaling_type in ("linear", "dynamic"), f"Got {scaling_type}"
        assert factor > 1.0, f"Got {factor}"
        # The enum value does not depend on the layer, so resolve it once.
        scale_type = RotaryScalingType.linear if scaling_type == "linear" else RotaryScalingType.dynamic
        for decoder in self.layers:
            decoder.attention.rotary_embedding_scale_type = scale_type
            decoder.attention.rotary_embedding_scale = factor
        return self

    def default_plugin_config(self, **kwargs):
        """Return the base plugin config, enabling the groupwise weight-only
        matmul plugin for int4 weight-only per-group quantization."""
        plugin_config = super().default_plugin_config(**kwargs)
        if self.quant_mode.is_int4_weight_only_per_group():
            plugin_config.set_weight_only_groupwise_quant_matmul_plugin()
        return plugin_config