#!/usr/bin/env python3
import argparse
import time
from enum import Enum
from pathlib import Path
from typing import Dict, Optional, Type

import tensorrt_llm
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.gemma.config import GEMMA_ARCHITECTURE, GemmaConfig
from tensorrt_llm.models.gemma.convert import (HfParser, JAXParser,
                                               KerasParser, Parsers,
                                               QuantizeModifiers, TorchParser,
                                               load_gemma_weights,
                                               non_modelopt_quantize_if_needed)
from tensorrt_llm.models.gemma.model import GemmaForCausalLM
from tensorrt_llm.models.modeling_utils import (QuantConfig, save_checkpoint,
                                                save_config)
from tensorrt_llm.quantization import QuantAlgo


class CheckpointType(str, Enum):
    jax = "jax"
    keras = "keras"
    torch = "torch"
    hf = "hf"


def parse_arguments() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt-type",
                        type=CheckpointType,
                        choices=list(CheckpointType))
    parser.add_argument("--model-dir", type=Path, required=True)
    parser.add_argument("--output-model-dir", type=Path, required=True)
    parser.add_argument(
        "--world-size",
        type=int,
        default=1,
        help="World size; only tensor parallelism is supported for now.")
    parser.add_argument(
        "--use-weight-only-with-precision",
        choices=["int8", "int4", "w4a8_awq", "w4a16_awq"],
        help=
        "Quantize weights for the various GEMMs to INT4/INT8 and define the precision for the weights.",
    )
    parser.add_argument(
        "--use-int8-weight-only-embedding",
        action="store_true",
        help=
        "Use weight-only quantization on the embedding table and lm_head. (Only supported on Hopper GPUs.)",
    )
    parser.add_argument("--dtype",
                        type=str,
                        choices=["float32", "bfloat16", "float16"])
    parser.add_argument(
        "--enable_fp8",
        action="store_true",
        help="Use FP8 Linear layers for attention QKV/Dense and MLP.")
    parser.add_argument(
        "--fp8_kv_cache",
        action="store_true",
        help=
        "By default, the KV cache uses dtype. fp8_kv_cache chooses FP8 quantization for the KV cache instead.",
    )
    parser.add_argument(
        "--quant_ckpt_path",
        default=None,
        help=
        "Path of a directory holding quantized model checkpoints in .safetensors format, "
        "or path of a single quantized model checkpoint in .npz format.")
    parser.add_argument(
        '--calib_dataset',
        type=str,
        default='ccdv/cnn_dailymail',
        help=
        "The Hugging Face dataset name or the local directory of the dataset used for calibration."
    )
    parser.add_argument('--use_smooth_quant',
                        action="store_true",
                        help="Use SmoothQuant.")
    parser.add_argument(
        "--int8_kv_cache",
        "--calibrate_kv_cache",
        "-kv",
        action="store_true",
        help=
        "Generate scaling factors for the KV cache. Used for storing the KV cache in INT8."
    )
    parser.add_argument(
        '--per_channel',
        action="store_true",
        help=
        "By default, we use a single static scaling factor for the GEMM's result. "
        'per_channel instead uses a different static scaling factor for each channel. '
        'The latter is usually more accurate, but a little slower.')
    parser.add_argument(
        '--per_token',
        action="store_true",
        help=
        'By default, we use a single static scaling factor to scale activations into the int8 range. '
        'per_token instead chooses, at run time, a custom scaling factor for each token. '
        'The latter is usually more accurate, but a little slower.')
    parser.add_argument(
        "--smoothquant",
        "--use_smooth_quant_plugin",
        "-sq",
        type=float,
        default=None,
        help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
        " to SmoothQuant the model and output int8 weights."
        " A good first try is 0.5. Must be in [0, 1].")
    parser.add_argument(
        '--tokenizer_dir',
        default=None,
        help='Tokenizer path; defaults to model_dir if left unspecified.')
    parser.add_argument("--load_model_on_cpu", action="store_true")

    args = parser.parse_args()
    return args
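
# Illustrative invocation (a sketch only: the script file name and paths are
# placeholders, and the right flag values depend on the source checkpoint).
# The converted checkpoint directory is what `trtllm-build` consumes next:
#
#   python3 convert_checkpoint.py \
#       --ckpt-type hf \
#       --model-dir ./gemma-2b \
#       --output-model-dir ./trt_ckpt/gemma-2b/bf16/tp1 \
#       --dtype bfloat16 \
#       --world-size 1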

CKPT_PARSER: Dict[CheckpointType, Type[Parsers]] = {
    CheckpointType.jax: JAXParser,
    CheckpointType.keras: KerasParser,
    CheckpointType.torch: TorchParser,
    CheckpointType.hf: HfParser,
}


def compute_quant_algo(args: argparse.Namespace) -> Optional[QuantAlgo]:
    if args.use_weight_only_with_precision:
        return {
            "int8": QuantAlgo.W8A16,
            "int4": QuantAlgo.W4A16,
            "w4a8_awq": QuantAlgo.W4A8_AWQ,
            "w4a16_awq": QuantAlgo.W4A16_AWQ,
        }[args.use_weight_only_with_precision]
    elif args.enable_fp8:
        return QuantAlgo.FP8

    if args.use_smooth_quant:
        return QuantAlgo.W8A8_SQ_PER_CHANNEL
    elif args.smoothquant is not None:
        if args.per_token and args.per_channel:
            return QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN
        elif not args.per_token and not args.per_channel:
            return QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN
        elif not args.per_token and args.per_channel:
            return QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN
        elif args.per_token and not args.per_channel:
            return QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN
    return None
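
# Quick sanity check of the SmoothQuant branches above (hypothetical argument
# values, shown only to illustrate the flag-to-algorithm matrix):
#
#   ns = argparse.Namespace(use_weight_only_with_precision=None,
#                           enable_fp8=False, use_smooth_quant=False,
#                           smoothquant=0.5, per_token=True, per_channel=False)
#   assert compute_quant_algo(ns) == QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN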

def create_quant_config(args: argparse.Namespace) -> QuantConfig:
    quant_algo = compute_quant_algo(args)
    GemmaForCausalLM.assert_valid_quant_algo(quant_algo)
    quant_config = QuantConfig(quant_algo=quant_algo,
                               smoothquant_val=args.smoothquant)
    if args.fp8_kv_cache:
        quant_config.kv_cache_quant_algo = QuantAlgo.FP8
    if args.int8_kv_cache:
        quant_config.kv_cache_quant_algo = QuantAlgo.INT8
    if args.use_weight_only_with_precision:
        use_awq = args.use_weight_only_with_precision.endswith("awq")
        use_int4 = args.use_weight_only_with_precision.endswith("int4")
        if use_awq:
            quant_config.group_size = 128
        if use_awq or use_int4 or not args.use_int8_weight_only_embedding:
            quant_config.has_zero_point = False
            quant_config.pre_quant_scale = True
        else:
            quant_config.exclude_modules = ['router']
    return quant_config


def main() -> None:
    args = parse_arguments()
    tik = time.time()

    quant_config = create_quant_config(args)
    ckpt_parser = CKPT_PARSER[args.ckpt_type]()

    mapping = Mapping(
        world_size=args.world_size,
        tp_size=args.world_size,
        pp_size=1,  # we don't support pipeline parallelism for Gemma yet
    )

    if isinstance(ckpt_parser, HfParser):
        trt_llm_config = GemmaConfig.from_hugging_face(
            args.model_dir,
            args.dtype,
            mapping=mapping,
            quant_config=quant_config,
        )
    else:
        print(f"Loading source parameters from {args.model_dir.absolute()}")
        ckpt_params = ckpt_parser.load_parameters(args.model_dir)
        input_embedding_weights = ckpt_parser.embedding_weights(ckpt_params)
        num_embed, _ = input_embedding_weights.shape
        ckpt_params_dtype = str(
            input_embedding_weights.dtype).split(".")[-1]  # np.bfloat16 -> bfloat16
        ckpt_config = ckpt_parser.get_config(args.model_dir, ckpt_params,
                                             num_embed)
        # e.g. 2B: TransformerConfig(num_layers=18, num_embed=256128,
        #          embed_dim=2048, hidden_dim=16384, num_heads=8,
        #          head_dim=256, num_kv_heads=1)
        #      7B: TransformerConfig(...)
        del ckpt_params

        print(f"Source configuration determined from parameters: {ckpt_config}")

        trt_llm_config = tensorrt_llm.models.GemmaConfig(
            architecture=GEMMA_ARCHITECTURE,
            dtype=args.dtype or ckpt_params_dtype,
            logits_dtype="float32",
            vocab_size=ckpt_config.num_embed,
            max_position_embeddings=8192,
            hidden_size=ckpt_config.embed_dim,
            num_hidden_layers=ckpt_config.num_layers,
            num_attention_heads=ckpt_config.num_heads,
            num_key_value_heads=ckpt_config.num_kv_heads,
            head_size=ckpt_config.head_dim,
            hidden_act="gelu",
            intermediate_size=ckpt_config.hidden_dim,
            norm_epsilon=1e-6,  # hard-coded in RMSNorm from gemma/layers.py
            position_embedding_type="rope_gpt_neox",
            mapping=mapping,
            gpus_per_node=8,
            quantization=quant_config,
            use_parallel_embedding=mapping.tp_size > 1,
            share_embedding_table=True,
        )

        if hasattr(ckpt_config,
                   "model_type") and ckpt_config.model_type == "gemma2":
            trt_llm_config.inter_layernorms = True
            trt_llm_config.final_logit_softcapping = ckpt_config.final_logit_softcapping
            trt_llm_config.attn_logit_softcapping = ckpt_config.attn_logit_softcapping
            trt_llm_config.query_pre_attn_scalar = ckpt_config.query_pre_attn_scalar

    trt_llm_config_dict = trt_llm_config.to_dict()
    print(f"Determined TensorRT-LLM configuration {trt_llm_config_dict}")

    save_config(trt_llm_config, output_dir=args.output_model_dir, log=True)

    # Convert and save the weights rank by rank (one checkpoint shard per TP rank).
    for config in trt_llm_config.for_each_rank():
        hf_weights = load_gemma_weights(
            parameters_or_model_dir=args.model_dir,
            trt_llm_config=config,
            ckpt_parser=ckpt_parser,
            load_model_on_cpu=args.load_model_on_cpu)
        ranked_weights = non_modelopt_quantize_if_needed(
            hf_weights,
            model_dir=args.model_dir,
            quantize_modifiers=QuantizeModifiers.from_args(args),
            trt_llm_config=config)
        save_checkpoint(output_dir=args.output_model_dir,
                        weights=ranked_weights,
                        rank=config.mapping.rank)

    elapsed = time.strftime("%H:%M:%S", time.gmtime(time.time() - tik))
    print(f"Total time of converting checkpoints: {elapsed}")


if __name__ == "__main__":
    main()