import argparse import copy import dataclasses import json import os from enum import IntFlag, auto from functools import cached_property from pathlib import Path from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Union import numpy as np import safetensors import torch from .._common import default_net from .._utils import (get_init_params, numpy_to_torch, release_gc, str_dtype_to_torch, str_dtype_to_trt, trt_dtype_to_torch) from ..bindings import KVCacheType from ..functional import (PositionEmbeddingType, Tensor, gather_last_token_logits, tanh) from ..layers import (MLP, AttentionParams, Embedding, FusedGatedMLP, FusedRgLru, GatedMLP, KeyValueCacheParams, LoraParams, PromptTuningEmbedding, RgLru) from ..layers.attention import Attention, BertAttention from ..layers.linear import ColumnLinear, Linear, RowLinear from ..layers.lora import Lora from ..layers.moe import MOE, MoeOOTB from ..logger import logger from ..mapping import Mapping from ..module import Module, ModuleList from ..parameter import Parameter from ..plugin import init_all_reduce_helper from ..quantization import QuantMode from ..quantization.layers import (WeightOnlyGroupwiseQuantLinear, WeightOnlyGroupwiseQuantRowLinear, WeightOnlyQuantLinear, WeightOnlyQuantRowLinear) from ..quantization.mode import (KV_CACHE_QUANT_ALGO_LIST, QUANT_ALGO_LIST, W8A8_SQ_PLUGIN_LIST, QuantAlgo) from ..top_model_mixin import TopModelMixin from .convert_utils import weight_only_quantize_dict from .generation_mixin import GenerationMixin @dataclasses.dataclass(kw_only=True, frozen=True) class Gemma2ConfigGroup: query_pre_attn_scalar: int final_logit_softcapping: Optional[float] attn_logit_softcapping: Optional[float] @classmethod def keys(cls): return {f.name for f in dataclasses.fields(cls)} if TYPE_CHECKING: from typing import Type, TypeVar from typing_extensions import Self ConfigGroups = Union[Gemma2ConfigGroup] """Groupings of config where, if one of said properties exists, we assume all of the properties exist (even if they are `None`)""" CG = TypeVar("CG", bound=ConfigGroups) class SpeculativeDecodingMode(IntFlag): # [WARNING] KEEP BELOW DEFINITION IN SYNC WITH cpp/tensorrt_llm/runtime/speculativeDecodingMode.h NONE = auto() DRAFT_TOKENS_EXTERNAL = auto() MEDUSA = auto() LOOKAHEAD_DECODING = auto() EXPLICIT_DRAFT_TOKENS = auto() @staticmethod def from_arguments(args: argparse.Namespace): if args.speculative_decoding_mode is None: return SpeculativeDecodingMode.NONE elif args.speculative_decoding_mode == "draft_tokens_external": return SpeculativeDecodingMode.DRAFT_TOKENS_EXTERNAL elif args.speculative_decoding_mode == "medusa": return SpeculativeDecodingMode.MEDUSA elif args.speculative_decoding_mode == "lookahead_decoding": return SpeculativeDecodingMode.LOOKAHEAD_DECODING elif args.speculative_decoding_mode == "explicit_draft_tokens": return SpeculativeDecodingMode.EXPLICIT_DRAFT_TOKENS else: assert False, "Unknown speculative_decoding_mode " + args.speculative_decoding_mode @dataclasses.dataclass class QuantConfig: '''Serializable quantization configuration class, part of the PretrainedConfig ''' quant_algo: Optional[QuantAlgo] = None kv_cache_quant_algo: Optional[QuantAlgo] = None group_size: Optional[int] = 128 smoothquant_val: float = 0.5 clamp_val: Optional[List[float]] = None has_zero_point: Optional[bool] = False pre_quant_scale: Optional[bool] = False exclude_modules: Optional[List[str]] = None @property def use_plugin_sq(self): return self.quant_algo in W8A8_SQ_PLUGIN_LIST @cached_property def quant_mode(self) -> QuantMode: return QuantMode.from_quant_algo( self.quant_algo, self.kv_cache_quant_algo, ) @property def requires_calibration(self): return self.quant_algo in (set(QUANT_ALGO_LIST) - { QuantAlgo.W8A16, QuantAlgo.W4A16, QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN }) or self.kv_cache_quant_algo in KV_CACHE_QUANT_ALGO_LIST @property def requires_modelopt_quantization(self): if self.quant_algo in [ QuantAlgo.W4A16_AWQ, QuantAlgo.FP8, QuantAlgo.W8A8_SQ_PER_CHANNEL, QuantAlgo.W4A8_AWQ ]: return True elif self.quant_algo is None and self.kv_cache_quant_algo == QuantAlgo.FP8: return True else: return False def get_modelopt_qformat(self): algo_to_modelopt_map = { QuantAlgo.W8A16: "int8_wo", QuantAlgo.W4A16: "int4_wo", QuantAlgo.W4A16_AWQ: "int4_awq", QuantAlgo.W4A8_AWQ: 'w4a8_awq', QuantAlgo.FP8: 'fp8', QuantAlgo.W8A8_SQ_PER_CHANNEL: 'int8_sq', } if self.quant_algo is not None: assert self.quant_algo in algo_to_modelopt_map, f"We don't use Modelopt for quantization algorithm {self.quant_algo}, you probably shall not call this" return algo_to_modelopt_map[self.quant_algo] else: return 'full_prec' def get_modelopt_kv_cache_dtype(self): algo_to_modelopt_map = { QuantAlgo.FP8: 'fp8', QuantAlgo.INT8: 'int8', } if self.kv_cache_quant_algo is not None: assert self.kv_cache_quant_algo in algo_to_modelopt_map, f"We don't use Modelopt for quantization algorithm {self.kv_cache_quant_algo}, you probably shall not call this" return algo_to_modelopt_map[self.kv_cache_quant_algo] else: return None @classmethod def from_dict(cls, config: dict): return cls(**config) def to_dict(self): return dataclasses.asdict(self) class PretrainedConfig: def __init__(self, *, architecture: str, dtype: str, hidden_size: int, num_hidden_layers: int, num_attention_heads: int, vocab_size: Optional[int] = None, hidden_act: str = 'gelu', logits_dtype: str = 'float32', norm_epsilon: float = 1e-5, position_embedding_type: Union[ PositionEmbeddingType, str] = PositionEmbeddingType.learned_absolute, max_position_embeddings: Optional[int] = None, num_key_value_heads: Optional[int] = None, intermediate_size: Optional[int] = None, mapping: Optional[Union[Mapping, dict]] = None, quantization: Optional[Union[QuantConfig, dict]] = None, use_parallel_embedding: bool = False, embedding_sharding_dim: int = 0, share_embedding_table: bool = False, head_size: Optional[int] = None, qk_layernorm: bool = False, **kwargs): self.architecture = architecture self.dtype = dtype self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.hidden_act = hidden_act self.logits_dtype = logits_dtype self.norm_epsilon = norm_epsilon if isinstance(position_embedding_type, str): position_embedding_type = PositionEmbeddingType.from_string( position_embedding_type) assert isinstance(position_embedding_type, PositionEmbeddingType) self.position_embedding_type = position_embedding_type self.max_position_embeddings = max_position_embeddings if num_key_value_heads is None: num_key_value_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads if intermediate_size is None: intermediate_size = hidden_size * 4 self.intermediate_size = intermediate_size if mapping is None: mapping = Mapping() elif isinstance(mapping, dict): mapping = Mapping.from_dict(mapping) assert isinstance(mapping, Mapping) self.mapping = mapping if quantization is None: quantization = QuantConfig() elif isinstance(quantization, dict): quantization = QuantConfig.from_dict(quantization) assert isinstance(quantization, QuantConfig) self.quantization = quantization self.use_parallel_embedding = use_parallel_embedding self.embedding_sharding_dim = embedding_sharding_dim self.share_embedding_table = share_embedding_table if share_embedding_table and mapping.tp_size > 1: if (not use_parallel_embedding) or (use_parallel_embedding and embedding_sharding_dim == 1): raise NotImplementedError( "For tensor parallelism, sharing the embedding table must set" \ "use_parallel_embedding=True and embedding_sharding_dim=0" ) if share_embedding_table and mapping.pp_size > 1: raise NotImplementedError( "Embedding table cannot be shared for pipeline parallelism") if share_embedding_table and mapping.cp_size > 1: raise NotImplementedError( "Embedding table cannot be shared for context parallelism") if head_size is None: head_size = hidden_size // num_attention_heads self.head_size = head_size self.qk_layernorm = qk_layernorm for key, value in kwargs.items(): try: setattr(self, key, value) logger.warning( f"Implicitly setting {self.__class__.__name__}.{key} = {value}" ) except AttributeError as err: raise err @property def kv_dtype(self): if self.quant_mode.has_int8_kv_cache(): return 'int8' elif self.quant_mode.has_fp8_kv_cache(): return 'fp8' else: return self.dtype def set_if_not_exist(self, key, value): if not hasattr(self, key): setattr(self, key, value) @classmethod def from_dict(cls, config: dict): # Maybe we need AutoConfig for this from . import MODEL_MAP model_cls = MODEL_MAP[config['architecture']] config_cls = getattr(model_cls, 'config_class', cls) return config_cls(**config) def to_dict(self): output = copy.deepcopy(self.__dict__) output['position_embedding_type'] = str(self.position_embedding_type) output['mapping'] = self.mapping.to_dict() output['mapping'].pop('rank') output['quantization'] = self.quantization.to_dict() return output @classmethod def from_json_file(cls, config_file: str): with open(config_file) as f: config = json.load(f) return cls.from_dict(config) @classmethod def from_checkpoint(cls, ckpt_dir: str): return cls.from_json_file(os.path.join(ckpt_dir, 'config.json')) def to_json_file(self, config_file: str): with open(config_file, 'w') as f: json.dump(self.to_dict(), f, indent=4) @property def quant_mode(self): return self.quantization.quant_mode def set_rank(self, rank): self.mapping = Mapping(self.mapping.world_size, rank=rank, cp_size=self.mapping.cp_size, tp_size=self.mapping.tp_size, pp_size=self.mapping.pp_size, moe_tp_size=self.mapping.moe_tp_size, moe_ep_size=self.mapping.moe_ep_size, gpus_per_node=self.mapping.gpus_per_node) def get_config_group(self, group_cls: "Type[CG]") -> "CG": cfg = {k: v for k, v in self.to_dict().items() if k in group_cls.keys()} return group_cls(**cfg) def has_config_group(self, group_cls: "Type[CG]") -> "bool": return all(hasattr(self, key) for key in group_cls.keys()) def for_each_rank(self) -> "Generator[Self, None, None]": for rank in range(self.mapping.world_size): config_copy = copy.deepcopy(self) config_copy.set_rank(rank) yield config_copy class DecoderLayerList(ModuleList): def __init__(self, cls, config): self.num_hidden_layers = config.num_hidden_layers self.layer_list = config.mapping.pp_layers(config.num_hidden_layers) super().__init__([cls(config, idx) for idx in self.layer_list]) def forward(self, hidden_states, use_cache=False, attention_mask=None, kv_cache_params=None, attention_params=None, position_ids=None, lora_params=None, spec_decoding_params=None, vision_token_mask=None): kv_cache_params.fill_none_tensor_list(len(self.layer_list)) if use_cache: presents = [] for layer_idx, (layer, past) in enumerate( zip(self, kv_cache_params.past_key_value)): lora_layer_params = None if lora_params is not None and lora_params.lora_ranks is not None: lora_layer_params = lora_params.get_layer_params(layer_idx) kwargs = {} if position_ids is not None: kwargs['position_ids'] = position_ids if vision_token_mask is not None: kwargs['vision_token_mask'] = vision_token_mask if lora_layer_params is not None: kwargs['lora_layer_params'] = lora_layer_params if spec_decoding_params is not None: kwargs['spec_decoding_params'] = spec_decoding_params if default_net().plugin_config.reduce_fusion: if layer_idx < self.layer_list[-1]: kwargs['next_layer_input_layernorm_args'] = ( self[layer_idx + 1].input_layernorm.weight.value, self[layer_idx + 1].input_layernorm.eps) else: kwargs['next_layer_input_layernorm_args'] = None hidden_states = layer( hidden_states, use_cache=use_cache, attention_mask=attention_mask, kv_cache_params=KeyValueCacheParams( past_key_value=[past], host_past_key_value_lengths=kv_cache_params. host_past_key_value_lengths, host_max_attention_window_sizes=kv_cache_params. host_max_attention_window_sizes, host_sink_token_length=kv_cache_params. host_sink_token_length, kv_cache_block_offsets=kv_cache_params. kv_cache_block_offsets, host_kv_cache_block_offsets=kv_cache_params. host_kv_cache_block_offsets, host_kv_cache_pool_pointers=kv_cache_params. host_kv_cache_pool_pointers, cache_indirection=kv_cache_params.cache_indirection), attention_params=attention_params, **kwargs) if use_cache: presents.append(hidden_states[1]) hidden_states = hidden_states[0] if use_cache: return hidden_states, presents return hidden_states class PostInitCaller(type): def __call__(cls, *args, **kwargs): obj = type.__call__(cls, *args, **kwargs) obj.__post_init__() return obj class PretrainedModel(Module, GenerationMixin, TopModelMixin, metaclass=PostInitCaller): def __init__(self, config: PretrainedConfig): super().__init__() init_all_reduce_helper() self.config = config def __post_init__(self): from ..quantization.quantize import quantize quantize(self, self.config.quantization) # Currently, use_parallel_embedding and share_embedding_table must be enabled before weight loading; # otherwise, the model will be inconsistent with the weights loaded from checkpoint. optimize_model( self, use_parallel_embedding=self.config.use_parallel_embedding, share_embedding_table=self.config.share_embedding_table, ) def release(self): release_gc() def __del__(self): self.release() def check_config(self, config): raise NotImplementedError( f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." ) @classmethod def from_config(cls, config: PretrainedConfig): return cls(config) @classmethod def from_checkpoint(cls, ckpt_dir: str, rank: Optional[int] = None, config: Optional[PretrainedConfig] = None): if config is None: config = PretrainedConfig.from_json_file( os.path.join(ckpt_dir, 'config.json')) if rank is not None: config.set_rank(rank) rank = config.mapping.rank weights_path = os.path.join(ckpt_dir, f'rank{rank}.safetensors') assert os.path.isfile(weights_path) weights = safetensors.torch.load_file(weights_path) is_checkpoint_pruned = getattr(config, 'is_pruned', False) preprocess_weights(weights, config, from_pruned=is_checkpoint_pruned) model = cls(config) model.load(weights, from_pruned=is_checkpoint_pruned) return model def load(self, weights, from_pruned=False): expected_names = set() required_names = set() for name, param in self.named_parameters(): expected_names.add(name) if not param.is_inited(): required_names.add(name) provided_names = set(weights.keys()) if not required_names.issubset(provided_names): raise RuntimeError( f"Required but not provided tensors:{required_names.difference(provided_names)}" ) if not provided_names.issubset(expected_names): logger.warning( f"Provided but not expected tensors: {provided_names.difference(expected_names)}" ) for name, param in self.named_parameters(): if name in provided_names: if not from_pruned: try: param.value = weights[name] except Exception as e: raise RuntimeError( f"Encounter error '{e}' for parameter '{name}'") else: param.set_value_or_dummy(weights[name]) def save_checkpoint(self, output_dir, save_config=True): # multiple ranks could share same config.json, so adding a save_config parameter to let user avoiding writing config.json in all ranks rank = self.config.mapping.rank weights = { name: numpy_to_torch(param.raw_value) for name, param in self.named_parameters() } safetensors.torch.save_file( weights, os.path.join(output_dir, f'rank{rank}.safetensors')) if save_config: self.config.to_json_file(os.path.join(output_dir, 'config.json')) def prepare_inputs( self, max_batch_size, max_input_len, max_seq_len, max_num_tokens, use_cache, max_beam_width: int = 1, opt_num_tokens: int = None, prompt_embedding_table_size: int = 0, position_encoding_2d: bool = False, max_draft_len: int = 0, speculative_decoding_draft_tokens_external: bool = False, spec_decoding_is_generation_length_variable: bool = False, gather_context_logits: bool = False, gather_generation_logits: bool = False, lora_target_modules: List[str] = None, opt_batch_size: int = 0): '''@brief: Prepare inputs Tensors for the model, the given sizes are used to determine the ranges of the dimensions of when using TRT dynamic shapes. @return: a list contains values which can be fed into the self.forward() ''' # Prepare inputs remove_input_padding = default_net().plugin_config.remove_input_padding use_gpt_attention_plugin = default_net( ).plugin_config.gpt_attention_plugin use_gemm_plugin = default_net().plugin_config.gemm_plugin paged_kv_cache = default_net().plugin_config.paged_kv_cache tokens_per_block = default_net().plugin_config.tokens_per_block use_lora_plugin = default_net().plugin_config.lora_plugin multiple_profiles = default_net().plugin_config.multiple_profiles streamingllm = default_net().plugin_config.streamingllm kv_cache_type = None if not use_cache: kv_cache_type = KVCacheType.DISABLED else: if paged_kv_cache: kv_cache_type = KVCacheType.PAGED else: kv_cache_type = KVCacheType.CONTINUOUS model_inputs = self.prepare_basic_inputs( max_batch_size=max_batch_size, max_beam_width=max_beam_width, max_input_len=max_input_len, max_seq_len=max_seq_len, hidden_size=self.config.hidden_size, num_kv_heads=self.config.num_key_value_heads, head_size=self.config.head_size, num_layers=self.config.num_hidden_layers, kv_dtype=str_dtype_to_trt(self.config.kv_dtype), remove_input_padding=remove_input_padding, use_gpt_attention_plugin=use_gpt_attention_plugin, use_gemm_plugin=use_gemm_plugin, kv_cache_type=kv_cache_type, tokens_per_block=tokens_per_block, num_heads=self.config.num_attention_heads, max_num_tokens=max_num_tokens, opt_num_tokens=opt_num_tokens, dtype=str_dtype_to_trt(self.config.dtype), prompt_embedding_table_size=prompt_embedding_table_size, position_encoding_2d=position_encoding_2d, mapping=self.config.mapping, gather_context_logits=gather_context_logits, gather_generation_logits=gather_generation_logits, use_lora_plugin=use_lora_plugin, max_draft_len=max_draft_len, speculative_decoding_draft_tokens_external= speculative_decoding_draft_tokens_external, spec_decoding_is_generation_length_variable= spec_decoding_is_generation_length_variable, lora_target_modules=lora_target_modules, multiple_profiles=multiple_profiles, streamingllm=streamingllm, opt_batch_size=opt_batch_size) result = { 'input_ids': model_inputs['input_ids'], 'position_ids': model_inputs['position_ids'], 'use_cache': kv_cache_type != KVCacheType.DISABLED, 'last_token_ids': model_inputs['last_token_ids'], 'attention_mask': model_inputs['attention_mask'], 'kv_cache_params': KeyValueCacheParams( past_key_value=model_inputs['past_key_value'], host_past_key_value_lengths=model_inputs[ 'host_past_key_value_lengths'], host_max_attention_window_sizes=model_inputs[ 'host_max_attention_window_sizes'], host_sink_token_length=model_inputs['host_sink_token_length'], kv_cache_block_offsets=model_inputs['kv_cache_block_offsets'], host_kv_cache_block_offsets=model_inputs[ 'host_kv_cache_block_offsets'], host_kv_cache_pool_pointers=model_inputs[ 'host_kv_cache_pool_pointers'], cache_indirection=model_inputs['cache_indirection'], ), 'attention_params': AttentionParams( sequence_length=model_inputs['sequence_length'], context_lengths=model_inputs['context_lengths'], host_context_lengths=model_inputs['host_context_lengths'], max_context_length=max_input_len, host_request_types=model_inputs['host_request_types'], host_runtime_perf_knobs=model_inputs['host_runtime_perf_knobs']) } if prompt_embedding_table_size > 0: result['prompt_embedding_table'] = model_inputs[ 'prompt_embedding_table'] result['prompt_tasks'] = model_inputs['tasks'] result['prompt_vocab_size'] = model_inputs['prompt_vocab_size'] if model_inputs['hidden_states_input'] is not None: result['hidden_states'] = model_inputs['hidden_states_input'] if use_lora_plugin: result['lora_params'] = LoraParams( model_inputs['lora_ranks'], model_inputs['lora_weights_pointers'], host_context_lengths=model_inputs['host_context_lengths'], host_request_types=model_inputs['host_request_types']) if model_inputs['spec_decoding_params'] is not None: result['spec_decoding_params'] = model_inputs[ 'spec_decoding_params'] return result @classmethod def quantize( cls, hf_model_dir: str, output_dir: str, dtype: str = 'auto', mapping: Optional[Mapping] = None, quant_config: Optional[QuantConfig] = None, *, device: str = 'cuda', calib_dataset: str = 'cnn_dailymail', calib_batches: int = 512, calib_batch_size: int = 1, calib_max_seq_length: int = 512, random_seed: int = 1234, tokenizer_max_seq_length: int = 2048, **kwargs, ): config_cls = getattr(cls, 'config_class', None) if config_cls is None: raise NotImplementedError( f"{cls.__name__} has not implemented corresponding config class, which is needed for correct config parsing." ) config: PretrainedConfig = config_cls.from_hugging_face( hf_model_dir, dtype=dtype, mapping=mapping, quant_config=quant_config, **kwargs) if config.mapping.moe_ep_size > 1: raise NotImplementedError( "Quantization for expert parallelism is not supported") if not config.quantization.requires_modelopt_quantization: raise ValueError( f"The quant_config ({quant_config}) should not call modelopt quantization" ) from ..quantization import quantize_and_export quantize_and_export( model_dir=str(hf_model_dir), device=device, calib_dataset=calib_dataset, dtype=config.dtype, qformat=config.quantization.get_modelopt_qformat(), kv_cache_dtype=config.quantization.get_modelopt_kv_cache_dtype(), calib_size=calib_batches, batch_size=calib_batch_size, calib_max_seq_length=calib_max_seq_length, awq_block_size=config.quantization.group_size, output_dir=output_dir, tp_size=config.mapping.tp_size, pp_size=config.mapping.pp_size, seed=random_seed, tokenizer_max_seq_length=tokenizer_max_seq_length, ) class DecoderModelForCausalLM(PretrainedModel): def __init__(self, config: PretrainedConfig, transformer, lm_head): super().__init__(config) self.transformer = transformer self.lm_head = lm_head self.mup_width_multiplier = getattr(config, 'mup_width_multiplier', None) # Create constant attention parameters to be reused by all layers. Attention.create_attention_const_params(self, config) self.position_embedding_type = config.position_embedding_type def forward(self, input_ids: Tensor, position_ids=None, use_cache=False, last_token_ids=None, attention_mask=None, kv_cache_params=None, attention_params=None, hidden_states=None, prompt_embedding_table: Optional[Tensor] = None, prompt_tasks: Optional[Tensor] = None, prompt_vocab_size: Optional[Tensor] = None, lora_params=None, spec_decoding_params=None): # fill attention params. attention_params = Attention.fill_attention_params( self, attention_params) kwargs = { 'input_ids': input_ids, 'position_ids': position_ids, 'use_cache': use_cache, 'attention_mask': attention_mask, 'kv_cache_params': kv_cache_params, 'attention_params': attention_params, } if lora_params is not None: kwargs['lora_params'] = lora_params if hidden_states is not None: kwargs['hidden_states'] = hidden_states if prompt_embedding_table is not None: kwargs['prompt_embedding_table'] = prompt_embedding_table if prompt_tasks is not None: kwargs['prompt_tasks'] = prompt_tasks if prompt_vocab_size is not None: kwargs['prompt_vocab_size'] = prompt_vocab_size if spec_decoding_params is not None: kwargs['spec_decoding_params'] = spec_decoding_params hidden_states = self.transformer.forward(**kwargs) if use_cache: hidden_states, presents = hidden_states if self.config.mapping.is_last_pp_rank(): hidden_states = gather_last_token_logits( hidden_states, last_token_ids, default_net().plugin_config.remove_input_padding) # [batch_size, hidden_size] -> [batch_size, vocab_size] lm_logits = self.lm_head(hidden_states) if hasattr(self.config, 'output_multiplier_scale'): lm_logits *= getattr(self.config, 'output_multiplier_scale', 1) if self.mup_width_multiplier is not None: lm_logits = lm_logits / self.mup_width_multiplier if self.config.has_config_group(Gemma2ConfigGroup): softcap = self.config.get_config_group( Gemma2ConfigGroup).final_logit_softcapping if softcap: lm_logits = lm_logits * float(1 / softcap) lm_logits = tanh(lm_logits) * float(softcap) lm_logits.mark_output('logits', self.config.logits_dtype) else: hidden_states.mark_output('hidden_states_output', self.config.dtype) if use_cache and not default_net().plugin_config.paged_kv_cache: for i, present in zip( self.config.mapping.pp_layers( self.config.num_hidden_layers), presents): present.mark_output(f'present_key_value_{i}', self.config.kv_dtype) if self.config.mapping.is_last_pp_rank(): return (lm_logits, presents, hidden_states) return (hidden_states, presents) else: if self.config.mapping.is_last_pp_rank(): return lm_logits, hidden_states return hidden_states def fuse_gate_mlp( model: PretrainedModel, gemm_swiglu_plugin_dtype: Optional[str] = None, ) -> PretrainedModel: from ..quantization.quantize import fp8_quantize quant_algo = model.config.quantization.quant_algo if quant_algo != QuantAlgo.FP8 and quant_algo is not None: logger.warning("fuse_gate_mlp cannot be done for this model. Skipping.") return model for name, mlp, layer in model.named_modules_with_parent(): if isinstance(mlp, GatedMLP): init_params = get_init_params(mlp) hidden_act = init_params["hidden_act"] if hidden_act not in ["silu", "gelu"]: logger.warning( f"fuse_gate_mlp cannot be done for {name} due to unsupported activation {hidden_act}. Skipping." ) continue init_params["inner_layernorm"] = mlp.inner_layernorm is not None fused_layer = FusedGatedMLP(**init_params) if quant_algo == QuantAlgo.FP8: fused_layer = fp8_quantize(fused_layer, model.config.quantization) if isinstance(mlp.dtype, str): dtype = str_dtype_to_torch(mlp.dtype) else: dtype = trt_dtype_to_torch(mlp.dtype) gate_weight = numpy_to_torch(mlp.gate.weight.raw_value) fc_weight = numpy_to_torch(mlp.fc.weight.raw_value) assert gate_weight.dtype == fc_weight.dtype need_qdq = gate_weight.dtype == torch.float8_e4m3fn gate_weight = gate_weight.to(dtype) fc_weight = fc_weight.to(dtype) # dequantize if needed if need_qdq: gate_weight = gate_weight.to(dtype) * numpy_to_torch( mlp.gate.weights_scaling_factor.raw_value) fc_weight = fc_weight.to(dtype) * numpy_to_torch( mlp.fc.weights_scaling_factor.raw_value) # concat fused_weight = torch.cat([gate_weight, fc_weight], dim=0) fused_weight_scaling_factor = numpy_to_torch( max( mlp.gate.weights_scaling_factor.raw_value, mlp.fc.weights_scaling_factor.raw_value, )) # quantize if needed if need_qdq: fused_weight = (fused_weight / fused_weight_scaling_factor).to( torch.float8_e4m3fn) if gemm_swiglu_plugin_dtype == 'fp8': # gemm_swiglu_plugin needs (k, n) weights # but weights should still be k-major for fp8 fused_layer.fused_fc.weight = Parameter( shape=(fused_layer.fused_fc.in_features, fused_layer.fused_fc.out_features), dtype='fp8') fused_layer.fused_fc.weight.value = fused_weight.view( fused_layer.fused_fc.in_features, fused_layer.fused_fc.out_features) else: fused_layer.fused_fc.weight.value = fused_weight fused_layer.fused_fc.weights_scaling_factor.value = fused_weight_scaling_factor fused_layer.fused_fc.activation_scaling_factor.value = max( mlp.gate.activation_scaling_factor.raw_value, mlp.fc.activation_scaling_factor.raw_value, ) elif quant_algo is None: fused_layer.fused_fc.weight.value = np.concatenate( [ mlp.gate.weight.raw_value, mlp.fc.weight.raw_value, ], axis=0, ) if mlp.bias: fused_layer.fused_fc.bias.value = np.concatenate( [mlp.gate.bias.raw_value, mlp.fc.bias.raw_value], axis=0) else: raise ValueError(f'Unsupported quant algo: {quant_algo}') fused_layer.proj = mlp.proj fused_layer.inner_layernorm = mlp.inner_layernorm mlp_name = name.rsplit('.', 1)[-1] setattr(layer, mlp_name, fused_layer) return model def unfuse_qkv_gemm(model: PretrainedModel) -> PretrainedModel: '''Split all the models' Attention layer's QKV GEMM into 3 GEMMs layer.q layer.k, layer.v and return the changed model ''' from ..quantization.quantize import quantize for name, layer in model.named_modules(): if isinstance(layer, Attention) and not layer.cross_attention: assert layer.tp_size == 1, "please disable manual tp when enable auto parallel" if layer.qkv is None: continue qkv_params = get_init_params(layer.qkv, ColumnLinear) qkv_params["bias"] = qkv_params["bias"] is not None qkv_params["strict_dtype"] = qkv_params.get( "strict_dtype") is not None q = ColumnLinear( **{ **qkv_params, "out_features": layer.tp_size * layer.num_attention_heads * layer.attention_head_size, }) k = ColumnLinear( **{ **qkv_params, "out_features": layer.tp_size * layer.num_attention_kv_heads * layer.attention_head_size, }) v = ColumnLinear( **{ **qkv_params, "out_features": layer.tp_size * layer.num_attention_kv_heads * layer.attention_head_size, }) q = quantize(q, model.config.quantization) k = quantize(k, model.config.quantization) v = quantize(v, model.config.quantization) out_features = q.out_features + k.out_features + v.out_features if isinstance(layer.qkv, ( WeightOnlyQuantLinear, WeightOnlyQuantRowLinear, WeightOnlyGroupwiseQuantLinear, WeightOnlyGroupwiseQuantRowLinear, )): out_dim = 1 else: out_dim = 0 if layer.qkv.weight.is_inited(): qkv_weight = layer.qkv.weight.raw_value weights = np.split(qkv_weight, [ qkv_weight.shape[out_dim] * q.out_features // out_features, qkv_weight.shape[out_dim] * (q.out_features + k.out_features) // out_features, ], axis=out_dim) for gemm, weight in zip([q, k, v], weights): gemm.weight.value = weight if layer.qkv.bias is not None and layer.qkv.bias.is_inited(): qkv_bias = layer.qkv.bias.raw_value biases = np.split(qkv_bias, [ qkv_bias.shape[out_dim] * q.out_features // out_features, qkv_bias.shape[out_dim] * (q.out_features + k.out_features) // out_features, ], axis=out_dim) for gemm, bias in zip([q, k, v], biases): gemm.bias.value = bias for name, parameter in layer.qkv._parameters.items(): if name not in ["weight", "bias"]: for gemm in [q, k, v]: setattr(gemm, name, parameter) layer.q = q layer.k = k layer.v = v layer.qkv = None return model def fuse_rg_lru(model: PretrainedModel) -> PretrainedModel: for name, rg_lru, parent in model.named_modules_with_parent(): if isinstance(rg_lru, RgLru): fused_layer = FusedRgLru(**get_init_params(rg_lru)) fused_layer.gate.weight.value = np.concatenate( [ rg_lru.input_gate.weight.raw_value, rg_lru.recurrent_gate.weight.raw_value, ], axis=-1, ) fused_layer.gate.bias.value = np.concatenate( [ rg_lru.input_gate.bias.raw_value, rg_lru.recurrent_gate.bias.raw_value, ], axis=-1, ) fused_layer.recurrent_param.value = rg_lru.recurrent_param.raw_value rg_lru_name = name.rsplit('.', 1)[-1] setattr(parent, rg_lru_name, fused_layer) return model def set_prompt_tuning(model: PretrainedModel) -> PretrainedModel: '''Replace the given models embedding layer with a PromptTuningEmbedding layer in-place, return the changed model Pre-conditions: vocab_embedding exists Post-conditions: isinstance(vocab_embedding, PromptTuningEmbedding) ''' for name, embedding, parent in model.named_modules_with_parent(): layer_name = name.rsplit('.', 1)[-1] if layer_name == "vocab_embedding" and isinstance(embedding, Embedding): ptuning_embedding = PromptTuningEmbedding( **get_init_params(embedding)) ptuning_embedding.weight.value = embedding.weight.raw_value parent.vocab_embedding = ptuning_embedding return model def add_lora(model: PretrainedModel, max_lora_rank: Optional[int]) -> PretrainedModel: ''' Add lora layers to the Attention/BertAttention/Linear/RowLinear/FusedGatedMLP layers to the given model, return the changed model ''' for name, layer in model.named_modules(): max_rank = max_lora_rank if isinstance(layer, (Attention, BertAttention)): if max_rank is None: max_rank = min( layer.hidden_size, layer.num_attention_heads * layer.attention_head_size, layer.num_attention_kv_heads * layer.attention_head_size) layer.qkv_lora = Lora( in_hidden_size=layer.hidden_size, out_hidden_sizes=[ layer.num_attention_heads * layer.attention_head_size, layer.num_attention_kv_heads * layer.attention_head_size, layer.num_attention_kv_heads * layer.attention_head_size ], max_low_rank=max_rank, ) if isinstance(layer, (Linear, RowLinear)): if max_rank is None: max_rank = min(layer.in_features, layer.out_features) layer.lora = Lora( in_hidden_size=layer.in_features, out_hidden_sizes=[layer.out_features], max_low_rank=max_rank, ) if isinstance(layer, (MLP, FusedGatedMLP)): if max_rank is None: max_rank = min(layer.hidden_size, layer.ffn_hidden_size // layer.tp_size) layer.lora = Lora( in_hidden_size=layer.hidden_size, out_hidden_sizes=[ layer.ffn_hidden_size // layer.tp_size, layer.ffn_hidden_size // layer.tp_size ], max_low_rank=max_rank, ) if isinstance(layer, MOE): if max_rank is None: max_rank = min(layer.hidden_size, layer.ffn_hidden_size // layer.tp_size) layer.max_low_rank = max_rank return model def to_ootb_moe(model: PretrainedModel) -> PretrainedModel: ''' Use OOTB MoE instead of MoE plugin, return the changed model ''' for name, layer, parent in model.named_modules_with_parent(): if isinstance(layer, MOE): layer_name = name.rsplit('.', 1)[-1] ootb_layer = layer.to(MoeOOTB, model.config.quantization) setattr(parent, layer_name, ootb_layer) return model def parallelize_embedding(model: PretrainedModel) -> PretrainedModel: for name, embedding, parent in model.named_modules_with_parent(): layer_name = name.rsplit('.', 1)[-1] if isinstance(embedding, Embedding) and embedding.tp_group is None: init_params = get_init_params(embedding) init_params["tp_group"] = model.config.mapping.tp_group init_params["tp_size"] = model.config.mapping.tp_size init_params["tp_rank"] = model.config.mapping.tp_rank init_params["sharding_dim"] = model.config.embedding_sharding_dim new_embedding = embedding.__class__(**init_params) setattr(parent, layer_name, new_embedding) return model def share_embedding(model: PretrainedModel) -> PretrainedModel: lm_head = None vocab_embedding = None for name, layer in model.named_modules(): layer_name = name.rsplit('.', 1)[-1] if layer_name == "lm_head": lm_head = layer if layer_name == "vocab_embedding": vocab_embedding = layer if lm_head is not None and vocab_embedding is not None: break if lm_head is not None and vocab_embedding is not None: lm_head.weight = vocab_embedding.weight if (hasattr(vocab_embedding, "per_token_scale") and vocab_embedding.per_token_scale is not None): lm_head.per_channel_scale = vocab_embedding.per_token_scale return model def set_fp8_context_fhma(model: PretrainedModel) -> PretrainedModel: for name, layer in model.named_modules(): if isinstance(layer, Attention): scale = [1.0] / layer.dense.activation_scaling_factor.raw_value layer.attention_output_orig_quant_scale = Parameter( value=scale.astype(np.float32)) return model def optimize_model( model: PretrainedModel, use_parallel_embedding: bool = False, share_embedding_table: bool = False, use_ootb_moe: bool = False, use_fused_mlp: bool = False, gemm_swiglu_plugin_dtype: Optional[str] = None, use_fused_rg_lru: bool = False, use_unfused_qkv_gemm: bool = False, use_prompt_tuning: bool = False, use_lora: bool = False, max_lora_rank: Optional[int] = None, use_fp8_context_fmha: bool = False, ) -> PretrainedModel: """ Run optimization passes on model. There are dependencies between some passes, so we always run passes in the order of arguments to guarantee the execution order. """ # before weight loading if use_parallel_embedding: model = parallelize_embedding(model) if share_embedding_table: model = share_embedding(model) # After weight loading if use_ootb_moe: model = to_ootb_moe(model) if use_fused_mlp: model = fuse_gate_mlp(model, gemm_swiglu_plugin_dtype) if use_fused_rg_lru: model = fuse_rg_lru(model) if use_unfused_qkv_gemm: model = unfuse_qkv_gemm(model) if use_prompt_tuning: model = set_prompt_tuning(model) if use_lora: model = add_lora(model, max_lora_rank) if use_fp8_context_fmha: model = set_fp8_context_fhma(model) return model def preprocess_weights(weights: Dict[str, torch.Tensor], model_config: PretrainedConfig, from_pruned=False) -> None: """This function in-place modifies weights and model_config, making them compatible with each other. Note: Typically, it should be called before model creation and weight loading. For example, preprocess_weights(weights, model_config) model = XXXForCausalLM(model_config) model.load(weights) """ quant_algo = model_config.quantization.quant_algo kv_cache_quant_algo = model_config.quantization.kv_cache_quant_algo exclude_modules = model_config.quantization.exclude_modules # INT4_AWQ if quant_algo == QuantAlgo.W4A8_AWQ or quant_algo == QuantAlgo.W4A16_AWQ: preprocessor = torch.ops.trtllm.preprocess_weights_for_mixed_gemm if quant_algo == QuantAlgo.W4A8_AWQ: activation_type = torch.float8_e4m3fn elif quant_algo == QuantAlgo.W4A16_AWQ: activation_type = torch.float16 for name, param in weights.items(): if from_pruned and param.numel() == 0: continue if name.endswith('weight') and param.dtype == torch.int8: dtype = torch.float16 if model_config.dtype == "bfloat16": dtype = torch.bfloat16 weights[name] = preprocessor(param.T.contiguous(), torch.quint4x2, activation_type).view(dtype) if name.endswith('weights_scaling_factor'): weights[name] = param.T.contiguous().to( str_dtype_to_torch(model_config.dtype)) if name.endswith('prequant_scaling_factor'): weights[name] = param.reshape(1, -1) if model_config.mapping.tp_rank > 0: if name.endswith('attention.dense.bias') or name.endswith( 'mlp.proj.bias'): weights[name] = torch.zeros_like(param) if quant_algo == QuantAlgo.W4A8_AWQ: for name in list(weights): if name.endswith('weights_scaling_factor'): activation_scaling_factor = weights.pop( name.replace('weights_scaling_factor', 'activation_scaling_factor')) weights_scaling_factor_2 = weights.pop( name.replace('weights_scaling_factor', 'weights_scaling_factor_2')) weights[name] /= weights_scaling_factor_2 weights[name.replace( 'weights_scaling_factor', 'prequant_scaling_factor')] /= activation_scaling_factor weights[name.replace( 'weights_scaling_factor', 'alpha' )] = activation_scaling_factor * weights_scaling_factor_2 # FP8 elif quant_algo == QuantAlgo.FP8: for name, param in weights.items(): if name.endswith('weight') and param.dtype == torch.int8: weights[name] = param.view(torch.float8_e4m3fn) # lm_head is not quantized to FP8 if "lm_head.weight" in weights: assert weights['lm_head.weight'].dtype == str_dtype_to_torch( model_config.dtype) weights.pop('lm_head.weights_scaling_factor', None) weights.pop('lm_head.activation_scaling_factor', None) elif quant_algo == QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN: for name, param in weights.items(): if name.endswith('weight') and param.dtype == torch.int8: weights[name] = param.view(torch.float8_e4m3fn) # lm_head is not quantized to FP8 if "lm_head.weight" in weights: assert weights['lm_head.weight'].dtype == str_dtype_to_torch( model_config.dtype) weights.pop('lm_head.weights_scaling_factor', None) weights.pop('lm_head.activation_scaling_factor', None) elif quant_algo in [QuantAlgo.W4A16, QuantAlgo.W8A16]: weights = weight_only_quantize_dict(weights=weights, quant_algo=quant_algo, exclude_modules=exclude_modules, plugin=True) # FP8 kv_cache_scaling_factor is always 1.0 if kv_cache_quant_algo == QuantAlgo.FP8: for name, param in weights.items(): if name.endswith('kv_cache_scaling_factor'): weights[name] = torch.tensor([1.0], dtype=torch.float32) # Parallel block rowlinear should not have duplicate bias. elif model_config.architecture == 'GPTJForCausalLM': if model_config.mapping.tp_rank > 0: for name, param in weights.items(): if 'attention.dense.bias' in name or 'mlp.proj.bias' in name: weights[name] = torch.zeros_like(param) # For share_embedding_table check_share_embedding(weights, model_config) def check_share_embedding(weights: Dict[str, torch.Tensor], model_config: PretrainedConfig): if model_config.share_embedding_table: if "lm_head.weight" in weights and "transformer.vocab_embedding.weight" in weights: if (weights["lm_head.weight"] - weights["transformer.vocab_embedding.weight"]).any(): logger.warning( "lm_head.weight and transformer.vocab_embedding.weight are not identical, " "share_embedding_table cannot be enabled; setting share_embedding_table=False." ) model_config.share_embedding_table = False else: weights.pop("lm_head.weight") def get_kv_cache_type_from_legacy(use_cache: bool, paged_kv_cache: bool) -> KVCacheType: if use_cache: if paged_kv_cache: return KVCacheType.PAGED else: return KVCacheType.CONTINUOUS else: return KVCacheType.DISABLED def save_config(config: PretrainedConfig, *, output_dir: str, log: bool) -> None: config_path = Path(output_dir) / "config.json" if log: logger.debug(f"Saving TensorRT-LLM configuration to {config_path}") config_path.parent.mkdir(exist_ok=True, parents=True) config_path.write_text(json.dumps(config.to_dict(), indent=4)) def save_checkpoint(*, output_dir: str, weights: dict, rank: int) -> None: """ Checkpoint saver for weight loader.""" safetensors.torch.save_file( weights, os.path.join(output_dir, f'rank{rank}.safetensors'))