# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse import os import time from pathlib import Path import onnx import tensorrt as trt import torch import torch.multiprocessing as mp from onnx import TensorProto, helper from transformers import AutoConfig, AutoModelForCausalLM import tensorrt_llm from tensorrt_llm._utils import str_dtype_to_trt from tensorrt_llm.builder import Builder from tensorrt_llm.layers.attention import PositionEmbeddingType from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping from tensorrt_llm.models import BaichuanForCausalLM, quantize_model from tensorrt_llm.network import net_guard from tensorrt_llm.plugin.plugin import ContextFMHAType from tensorrt_llm.quantization import QuantMode from weight import load_from_hf_baichuan, load_from_binary, parse_bin_config # isort:skip # 2 routines: get_engine_name, serialize_engine # are direct copy from gpt example, TODO: put in utils? 
def trt_dtype_to_onnx(dtype):
    """Map a TensorRT data type to the matching ONNX ``TensorProto`` type.

    Used by :func:`to_onnx` to annotate the inputs/outputs of the
    visualization graph.

    Args:
        dtype: A ``tensorrt.DataType`` value.

    Returns:
        The corresponding ``onnx.TensorProto.DataType`` enum value.

    Raises:
        TypeError: If ``dtype`` has no entry in the table below.
    """
    # (tensorrt attribute, TensorProto attribute) pairs. Types that only exist
    # in some TensorRT releases (e.g. bfloat16, int64) are looked up with
    # getattr so the table works across versions. Plain ``==`` comparison is
    # used (as in the original chain of if/elif) rather than a dict, so we do
    # not depend on how TensorRT's enum type hashes.
    supported = (
        ('float16', 'FLOAT16'),
        ('float32', 'FLOAT'),
        ('int32', 'INT32'),
        ('int8', 'INT8'),
        ('bool', 'BOOL'),
        ('bfloat16', 'BFLOAT16'),
        ('int64', 'INT64'),
    )
    for trt_name, onnx_name in supported:
        trt_type = getattr(trt, trt_name, None)
        if trt_type is not None and dtype == trt_type:
            return getattr(TensorProto.DataType, onnx_name)
    raise TypeError("%s is not supported" % dtype)


def to_onnx(network, path):
    """Dump a coarse ONNX view of a TensorRT network for visualization.

    Each TensorRT layer becomes one ONNX node whose op type is the layer type
    (domain ``com.nvidia``), wired together by tensor names. No initializers
    (weights) are written, so the file is only meant for graph viewers.

    Args:
        network: The TensorRT network definition to dump.
        path: Destination file path for the ONNX model.
    """

    def _value_info(tensor):
        # Name, dtype and shape of one graph input/output.
        return helper.make_tensor_value_info(tensor.name,
                                             trt_dtype_to_onnx(tensor.dtype),
                                             list(tensor.shape))

    inputs = [
        _value_info(network.get_input(i)) for i in range(network.num_inputs)
    ]
    outputs = [
        _value_info(network.get_output(i))
        for i in range(network.num_outputs)
    ]

    nodes = []
    for i in range(network.num_layers):
        layer = network.get_layer(i)
        # get_input() may return None for an unset optional input; skip it.
        layer_inputs = [
            ipt.name
            for ipt in (layer.get_input(j) for j in range(layer.num_inputs))
            if ipt is not None
        ]
        layer_outputs = [
            layer.get_output(j).name for j in range(layer.num_outputs)
        ]
        nodes.append(
            helper.make_node(str(layer.type),
                             name=layer.name,
                             inputs=layer_inputs,
                             outputs=layer_outputs,
                             domain="com.nvidia"))

    onnx_model = helper.make_model(helper.make_graph(nodes,
                                                     'attention',
                                                     inputs,
                                                     outputs,
                                                     initializer=None),
                                   producer_name='NVIDIA')
    onnx.save(onnx_model, path)


def get_engine_name(model, dtype, tp_size, rank):
    """Return the engine file name, e.g. ``baichuan_float16_tp2_rank0.engine``."""
    return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)


def serialize_engine(engine, path):
    """Write a serialized TensorRT engine to ``path`` and log the elapsed time."""
    logger.info(f'Serializing engine to {path}...')
    tik = time.time()
    with open(path, 'wb') as f:
        f.write(bytearray(engine))
    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Engine serialized. Total time: {t}')


def parse_arguments():
    """Parse and post-process the command line.

    Beyond plain parsing this:
      * rejects incompatible flag combinations (via ``parser.error``),
      * forces the options that inflight batching requires,
      * derives ``args.quant_mode`` from the quantization flags,
      * overwrites the model dimensions from, in order of precedence, the HF
        config in ``--model_dir``, ``config.ini`` in ``--bin_model_dir``, or
        a built-in preset for ``--model_version``.

    Returns:
        The populated ``argparse.Namespace`` (with the extra ``quant_mode``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size',
                        type=int,
                        default=1,
                        help='world size, only support tensor parallelism now')
    parser.add_argument('--model_dir', type=str, default=None)
    parser.add_argument('--bin_model_dir', type=str, default=None)
    parser.add_argument('--model_version',
                        type=str,
                        default='v1_13b',
                        choices=['v1_7b', 'v1_13b', 'v2_7b', 'v2_13b'])
    parser.add_argument('--dtype',
                        type=str,
                        default='float16',
                        choices=['float32', 'bfloat16', 'float16'])
    parser.add_argument(
        '--timing_cache',
        type=str,
        default='model.cache',
        help=
        'The path to read the timing cache from, will be ignored if the file does not exist'
    )
    parser.add_argument('--log_level', type=str, default='info')
    # Model dimension defaults correspond to v1_13b.
    parser.add_argument('--vocab_size', type=int, default=64000)
    parser.add_argument('--n_layer', type=int, default=40)
    parser.add_argument('--n_positions', type=int, default=4096)
    parser.add_argument('--n_embd', type=int, default=5120)
    parser.add_argument('--n_head', type=int, default=40)
    parser.add_argument('--inter_size', type=int, default=13696)
    parser.add_argument('--hidden_act', type=str, default='silu')
    parser.add_argument('--max_batch_size', type=int, default=1)
    parser.add_argument('--max_input_len', type=int, default=1024)
    parser.add_argument('--max_output_len', type=int, default=1024)
    parser.add_argument('--max_beam_width', type=int, default=1)
    # Bare flag -> 'float16'; omitted -> False (plugin disabled).
    parser.add_argument('--use_gpt_attention_plugin',
                        nargs='?',
                        const='float16',
                        type=str,
                        default=False,
                        choices=['float16', 'bfloat16', 'float32'])
    parser.add_argument('--use_gemm_plugin',
                        nargs='?',
                        const='float16',
                        type=str,
                        default=False,
                        choices=['float16', 'bfloat16', 'float32'])
    parser.add_argument('--enable_context_fmha',
                        default=False,
                        action='store_true')
    parser.add_argument('--enable_context_fmha_fp32_acc',
                        default=False,
                        action='store_true')
    parser.add_argument('--parallel_build', default=False, action='store_true')
    parser.add_argument('--visualize', default=False, action='store_true')
    parser.add_argument('--enable_debug_output',
                        default=False,
                        action='store_true')
    parser.add_argument('--gpus_per_node', type=int, default=8)
    parser.add_argument(
        '--output_dir',
        type=str,
        default='baichuan_outputs',
        help=
        'The path to save the serialized engine files, timing cache file and model configs'
    )
    parser.add_argument('--remove_input_padding',
                        default=False,
                        action='store_true')

    # Arguments related to the quantization of the model.
    parser.add_argument(
        '--use_smooth_quant',
        default=False,
        action="store_true",
        help=
        'Use the SmoothQuant method to quantize activations and weights for the various GEMMs. '
        'See --per_channel and --per_token for finer-grained quantization options.'
    )
    parser.add_argument(
        '--per_channel',
        default=False,
        action="store_true",
        help=
        'By default, we use a single static scaling factor for the GEMM\'s result. '
        'per_channel instead uses a different static scaling factor for each channel. '
        'The latter is usually more accurate, but a little slower.')
    parser.add_argument(
        '--per_token',
        default=False,
        action="store_true",
        help=
        'By default, we use a single static scaling factor to scale activations in the int8 range. '
        'per_token chooses at run time, and for each token, a custom scaling factor. '
        'The latter is usually more accurate, but a little slower.')
    parser.add_argument(
        '--int8_kv_cache',
        default=False,
        action="store_true",
        help=
        'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
    )
    parser.add_argument(
        '--use_weight_only',
        default=False,
        action="store_true",
        help='Quantize weights for the various GEMMs to INT4/INT8. '
        'See --weight_only_precision to set the precision')
    parser.add_argument(
        '--weight_only_precision',
        const='int8',
        type=str,
        nargs='?',
        default='int8',
        choices=['int8', 'int4'],
        help=
        'Define the precision for the weights when using weight-only quantization. '
        'You must also use --use_weight_only for that argument to have an impact.'
    )
    parser.add_argument(
        '--use_inflight_batching',
        action="store_true",
        default=False,
        help="Activates inflight batching mode of gptAttentionPlugin.")
    parser.add_argument(
        '--paged_kv_cache',
        action="store_true",
        default=False,
        help=
        'By default we use contiguous KV cache. By setting this flag you enable paged KV cache'
    )
    parser.add_argument('--tokens_per_block',
                        type=int,
                        default=64,
                        help='Number of tokens per block in paged KV cache')
    parser.add_argument(
        '--max_num_tokens',
        type=int,
        default=None,
        help='Define the max number of tokens supported by the engine')
    args = parser.parse_args()

    # Use parser.error (not assert) so validation survives `python -O` and
    # the user gets a usage message instead of a traceback.
    if args.use_smooth_quant and args.use_weight_only:
        parser.error(
            "You cannot enable both SmoothQuant and INT8 weight-only together."
        )
    # build_rank_engine also rejects this, but only after loading the
    # checkpoint; fail fast here instead.
    if args.enable_context_fmha and args.enable_context_fmha_fp32_acc:
        parser.error('--enable_context_fmha and --enable_context_fmha_fp32_acc '
                     'are mutually exclusive.')

    if not args.remove_input_padding and args.use_gpt_attention_plugin:
        logger.warning(
            "It is recommended to specify --remove_input_padding when using GPT attention plugin"
        )

    # Inflight batching needs the GPT attention plugin, packed (unpadded)
    # inputs and a paged KV cache; switch them on if the user did not.
    if args.use_inflight_batching:
        if not args.use_gpt_attention_plugin:
            args.use_gpt_attention_plugin = 'float16'
            logger.info(
                f"Using GPT attention plugin for inflight batching mode. Setting to default '{args.use_gpt_attention_plugin}'"
            )
        if not args.remove_input_padding:
            args.remove_input_padding = True
            logger.info(
                "Using remove input padding for inflight batching mode.")
        if not args.paged_kv_cache:
            args.paged_kv_cache = True
            logger.info("Using paged KV cache for inflight batching mode.")

    if args.max_num_tokens is not None and not args.enable_context_fmha:
        parser.error('--max_num_tokens requires --enable_context_fmha')

    if args.use_smooth_quant:
        args.quant_mode = QuantMode.use_smooth_quant(args.per_token,
                                                     args.per_channel)
    elif args.use_weight_only:
        args.quant_mode = QuantMode.use_weight_only(
            args.weight_only_precision == 'int4')
    else:
        args.quant_mode = QuantMode(0)

    if args.int8_kv_cache:
        args.quant_mode = args.quant_mode.set_int8_kv_cache()

    if args.model_dir is not None:
        hf_config = AutoConfig.from_pretrained(args.model_dir,
                                               trust_remote_code=True)
        # override the inter_size for Baichuan
        args.inter_size = hf_config.intermediate_size
        args.n_embd = hf_config.hidden_size
        args.n_head = hf_config.num_attention_heads
        args.n_layer = hf_config.num_hidden_layers
        # The 7B configs are read via max_position_embeddings, the others
        # via model_max_length.
        if args.model_version in ('v1_7b', 'v2_7b'):
            args.n_positions = hf_config.max_position_embeddings
        else:
            args.n_positions = hf_config.model_max_length
        args.vocab_size = hf_config.vocab_size
        args.hidden_act = hf_config.hidden_act
    elif args.bin_model_dir is not None:
        n_embd, n_head, n_layer, n_positions, vocab_size, hidden_act, inter_size, _ = parse_bin_config(
            Path(args.bin_model_dir) / "config.ini")
        args.inter_size = inter_size
        args.n_embd = n_embd
        args.n_head = n_head
        args.n_layer = n_layer
        args.n_positions = n_positions
        args.vocab_size = vocab_size
        args.hidden_act = hidden_act
    else:
        # Built-in presets. The argparse defaults already describe v1_13b,
        # so that version (absent from the table) keeps the parsed values.
        presets = {
            'v1_7b':
            dict(inter_size=11008,
                 n_embd=4096,
                 n_head=32,
                 n_layer=32,
                 n_positions=4096,
                 vocab_size=64000,
                 hidden_act='silu'),
            'v2_7b':
            dict(inter_size=11008,
                 n_embd=4096,
                 n_head=32,
                 n_layer=32,
                 n_positions=4096,
                 vocab_size=125696,
                 hidden_act='silu'),
            'v2_13b':
            dict(inter_size=13696,
                 n_embd=5120,
                 n_head=40,
                 n_layer=40,
                 n_positions=4096,
                 vocab_size=125696,
                 hidden_act='silu'),
        }
        for key, value in presets.get(args.model_version, {}).items():
            setattr(args, key, value)

    return args


def build_rank_engine(builder: Builder,
                      builder_config: tensorrt_llm.builder.BuilderConfig,
                      engine_name, rank, args):
    '''
       @brief: Build the TensorRT engine for one tensor-parallel rank.
       @param builder: The shared tensorrt_llm Builder.
       @param builder_config: Builder config created for this rank.
       @param engine_name: Name given to the TensorRT network.
       @param rank: The rank to build the engine for.
       @param args: The post-processed cmd line arguments.
       @return: The built engine (None if the TensorRT build failed).
    '''
    dtype = str_dtype_to_trt(args.dtype)
    mapping = Mapping(world_size=args.world_size,
                      rank=rank,
                      tp_size=args.world_size)

    # 7B variants use GPT-NeoX style rotary embeddings, the others ALiBi.
    if args.model_version in ('v1_7b', 'v2_7b'):
        position_embedding_type = PositionEmbeddingType.rope_gpt_neox
    else:
        position_embedding_type = PositionEmbeddingType.alibi

    # Initialize Module
    tensorrt_llm_baichuan = BaichuanForCausalLM(
        num_layers=args.n_layer,
        num_heads=args.n_head,
        num_kv_heads=None,
        hidden_size=args.n_embd,
        vocab_size=args.vocab_size,
        hidden_act=args.hidden_act,
        max_position_embeddings=args.n_positions,
        position_embedding_type=position_embedding_type,
        dtype=dtype,
        mlp_hidden_size=args.inter_size,
        mapping=mapping,
        quant_mode=args.quant_mode)
    if args.use_smooth_quant or args.use_weight_only:
        tensorrt_llm_baichuan = quantize_model(tensorrt_llm_baichuan,
                                               args.quant_mode)

    # Load weights. If neither directory is given, no weights are loaded here.
    if args.model_dir is not None:
        logger.info(
            f'Loading HF Baichuan {args.model_version} ... from {args.model_dir}'
        )
        tik = time.time()
        hf_baichuan = AutoModelForCausalLM.from_pretrained(
            args.model_dir,
            device_map={
                "model": "cpu",
                "lm_head": "cpu"
            },  # Load to CPU memory
            torch_dtype="auto",
            trust_remote_code=True)
        tok = time.time()
        t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
        logger.info(f'HF Baichuan {args.model_version} loaded. Total time: {t}')
        load_from_hf_baichuan(tensorrt_llm_baichuan,
                              hf_baichuan,
                              args.model_version,
                              rank,
                              args.world_size,
                              dtype=args.dtype)
        del hf_baichuan
    elif args.bin_model_dir is not None:
        load_from_binary(tensorrt_llm_baichuan,
                         args.bin_model_dir,
                         mapping,
                         fp16=(args.dtype == 'float16'),
                         multi_query_mode=False)

    # Module -> Network
    network = builder.create_network()
    network.trt_network.name = engine_name
    if args.use_gpt_attention_plugin:
        network.plugin_config.set_gpt_attention_plugin(
            dtype=args.use_gpt_attention_plugin)
    if args.use_gemm_plugin:
        network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
    # Quantization plugins.
    if args.use_smooth_quant:
        network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype)
        network.plugin_config.set_rmsnorm_quantization_plugin(dtype=args.dtype)
        network.plugin_config.set_quantize_tensor_plugin()
        network.plugin_config.set_quantize_per_token_plugin()
    assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc)
    if args.enable_context_fmha:
        network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
    if args.enable_context_fmha_fp32_acc:
        network.plugin_config.set_context_fmha(
            ContextFMHAType.enabled_with_fp32_acc)
    if args.use_weight_only:
        # NOTE(review): the plugin dtype is hard-coded to float16 regardless
        # of --dtype; confirm this is intended for bfloat16/float32 builds.
        network.plugin_config.set_weight_only_quant_matmul_plugin(
            dtype='float16')
    if args.world_size > 1:
        network.plugin_config.set_nccl_plugin(args.dtype)
    if args.remove_input_padding:
        network.plugin_config.enable_remove_input_padding()
    if args.paged_kv_cache:
        network.plugin_config.enable_paged_kv_cache(args.tokens_per_block)

    with net_guard(network):
        # Prepare
        network.set_named_parameters(tensorrt_llm_baichuan.named_parameters())

        # Forward
        inputs = tensorrt_llm_baichuan.prepare_inputs(args.max_batch_size,
                                                      args.max_input_len,
                                                      args.max_output_len,
                                                      True,
                                                      args.max_beam_width,
                                                      args.max_num_tokens)
        tensorrt_llm_baichuan(*inputs)
        if args.enable_debug_output:
            # mark intermediate nodes' outputs
            for k, v in tensorrt_llm_baichuan.named_network_outputs():
                v = v.trt_tensor
                v.name = k
                network.trt_network.mark_output(v)
                v.dtype = dtype
        if args.visualize:
            model_path = os.path.join(args.output_dir, 'test.onnx')
            to_onnx(network.trt_network, model_path)

    tensorrt_llm.graph_rewriting.optimize(network)

    # Network -> Engine
    engine = builder.build_engine(network, builder_config)
    if rank == 0:
        config_path = os.path.join(args.output_dir, 'config.json')
        builder.save_config(builder_config, config_path)

    tensorrt_llm.tools.cleanup(network, tensorrt_llm_baichuan)

    return engine


def build(rank, args):
    """Build and serialize the engines handled by process ``rank``.

    Serial build (``args.parallel_build`` False): this single process builds
    every rank's engine in turn, reusing rank 0's timing cache in memory.
    Parallel build: one process per rank (see ``mp.spawn`` below), each
    building only its own engine on GPU ``rank % gpus_per_node``.
    Process 0 also writes the timing cache to ``<output_dir>/model.cache``.
    """
    torch.cuda.set_device(rank % args.gpus_per_node)
    tensorrt_llm.logger.set_level(args.log_level)
    # exist_ok avoids a FileExistsError race when several spawned ranks
    # create the directory concurrently.
    os.makedirs(args.output_dir, exist_ok=True)

    # When building serially, all ranks share one Builder.
    builder = Builder()
    cache = None
    model_name = 'baichuan'
    for cur_rank in range(args.world_size):
        # skip other ranks if parallel_build is enabled
        if args.parallel_build and cur_rank != rank:
            continue
        # NOTE(nkorobov): when only int8 kv cache is used together with paged kv cache no int8 tensors are exposed to TRT
        int8_trt_flag = args.quant_mode.has_act_or_weight_quant() or (
            not args.paged_kv_cache and args.quant_mode.has_int8_kv_cache())
        builder_config = builder.create_builder_config(
            name=model_name,
            precision=args.dtype,
            timing_cache=args.timing_cache if cache is None else cache,
            tensor_parallel=args.world_size,  # TP only
            parallel_build=args.parallel_build,
            num_layers=args.n_layer,
            num_heads=args.n_head,
            hidden_size=args.n_embd,
            vocab_size=args.vocab_size,
            hidden_act=args.hidden_act,
            max_position_embeddings=args.n_positions,
            max_batch_size=args.max_batch_size,
            max_input_len=args.max_input_len,
            max_output_len=args.max_output_len,
            max_num_tokens=args.max_num_tokens,
            int8=int8_trt_flag,
            quant_mode=args.quant_mode)
        engine_name = get_engine_name(model_name, args.dtype, args.world_size,
                                      cur_rank)
        engine = build_rank_engine(builder, builder_config, engine_name,
                                   cur_rank, args)
        assert engine is not None, f'Failed to build engine for rank {cur_rank}'

        if cur_rank == 0:
            # Use in-memory timing cache for multiple builder passes.
            if not args.parallel_build:
                cache = builder_config.trt_builder_config.get_timing_cache()

        serialize_engine(engine, os.path.join(args.output_dir, engine_name))
        del engine

    if rank == 0:
        ok = builder.save_timing_cache(
            builder_config, os.path.join(args.output_dir, "model.cache"))
        assert ok, "Failed to save timing cache."


if __name__ == '__main__':
    args = parse_arguments()
    logger.set_level(args.log_level)
    tik = time.time()
    # Parallel build needs one free GPU per rank; otherwise build serially.
    if args.parallel_build and args.world_size > 1 and \
            torch.cuda.device_count() >= args.world_size:
        logger.warning(
            f'Parallelly build TensorRT engines. Please make sure that all of the {args.world_size} GPUs are totally free.'
        )
        mp.spawn(build, nprocs=args.world_size, args=(args, ))
    else:
        args.parallel_build = False
        logger.info('Serially build TensorRT engines.')
        build(0, args)

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Total time of building all {args.world_size} engines: {t}')