# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse import os import time import onnx import tensorrt as trt import torch import torch.multiprocessing as mp from onnx import TensorProto, helper from transformers import AutoConfig, AutoModelForCausalLM import tensorrt_llm from tensorrt_llm._utils import str_dtype_to_trt from tensorrt_llm.builder import Builder from tensorrt_llm.layers.attention import PositionEmbeddingType from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping from tensorrt_llm.models import BaichuanForCausalLM, weight_only_quantize from tensorrt_llm.network import net_guard from tensorrt_llm.plugin.plugin import ContextFMHAType from tensorrt_llm.quantization import QuantMode from weight import load_from_hf_baichuan # isort:skip # 2 routines: get_engine_name, serialize_engine # are direct copy from gpt example, TODO: put in utils? 
def _format_elapsed(seconds):
    """Render a duration given in seconds as an ``HH:MM:SS`` string."""
    return time.strftime('%H:%M:%S', time.gmtime(seconds))


# Architecture hyper-parameters applied when no HF checkpoint directory is
# given. 'v1_13b' is intentionally absent: the argparse defaults already
# describe that model.
_MODEL_VERSION_DEFAULTS = {
    'v1_7b': dict(inter_size=11008,
                  n_embd=4096,
                  n_head=32,
                  n_layer=32,
                  n_positions=4096,
                  vocab_size=64000,
                  hidden_act='silu'),
    'v2_7b': dict(inter_size=11008,
                  n_embd=4096,
                  n_head=32,
                  n_layer=32,
                  n_positions=4096,
                  vocab_size=125696,
                  hidden_act='silu'),
    'v2_13b': dict(inter_size=13696,
                   n_embd=5120,
                   n_head=40,
                   n_layer=40,
                   n_positions=4096,
                   vocab_size=125696,
                   hidden_act='silu'),
}


def trt_dtype_to_onnx(dtype):
    """Translate a TensorRT data type into the ONNX ``TensorProto`` data type.

    Only ``trt.float16``, ``trt.float32`` and ``trt.int32`` are handled.

    @param dtype: The TensorRT data type to translate.
    @return: The corresponding ``TensorProto.DataType`` value.
    @raise TypeError: If ``dtype`` is not one of the handled types.
    """
    # Equality comparison (not hashing) is kept on purpose so the lookup
    # behaves exactly like a chain of ``==`` tests.
    supported = (
        (trt.float16, TensorProto.DataType.FLOAT16),
        (trt.float32, TensorProto.DataType.FLOAT),
        (trt.int32, TensorProto.DataType.INT32),
    )
    for trt_type, onnx_type in supported:
        if dtype == trt_type:
            return onnx_type
    raise TypeError("%s is not supported" % dtype)


def to_onnx(network, path):
    """Dump the layer graph of a TensorRT network to an ONNX file.

    Every network layer becomes one ONNX node in the ``com.nvidia`` domain
    whose op type is the string form of the layer type. No initializers
    (weights) are written.

    @param network: The TensorRT network definition to export.
    @param path: Destination path of the ONNX file.
    """

    def value_info(tensor):
        # Describe a network input/output tensor as an ONNX value info.
        return helper.make_tensor_value_info(tensor.name,
                                             trt_dtype_to_onnx(tensor.dtype),
                                             list(tensor.shape))

    inputs = [
        value_info(network.get_input(i)) for i in range(network.num_inputs)
    ]
    outputs = [
        value_info(network.get_output(i)) for i in range(network.num_outputs)
    ]

    nodes = []
    for i in range(network.num_layers):
        layer = network.get_layer(i)

        # Input slots may be empty (None); skip those. The tensor fetched
        # for the check is reused instead of being looked up a second time.
        layer_inputs = []
        for j in range(layer.num_inputs):
            ipt = layer.get_input(j)
            if ipt is not None:
                layer_inputs.append(ipt.name)

        layer_outputs = [
            layer.get_output(j).name for j in range(layer.num_outputs)
        ]
        nodes.append(
            helper.make_node(str(layer.type),
                             name=layer.name,
                             inputs=layer_inputs,
                             outputs=layer_outputs,
                             domain="com.nvidia"))

    onnx_model = helper.make_model(helper.make_graph(nodes,
                                                     'attention',
                                                     inputs,
                                                     outputs,
                                                     initializer=None),
                                   producer_name='NVIDIA')
    onnx.save(onnx_model, path)


def get_engine_name(model, dtype, tp_size, rank):
    """Return the engine file name for one rank.

    @param model: The model name used as file name prefix.
    @param dtype: The data type string.
    @param tp_size: The tensor parallel size.
    @param rank: The rank the engine belongs to.
    @return: ``'<model>_<dtype>_tp<tp_size>_rank<rank>.engine'``.
    """
    return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)


def serialize_engine(engine, path):
    """Write a built engine to ``path`` and log how long the write took.

    @param engine: The built engine; it is converted with ``bytearray()``
                   before being written.
    @param path: Destination file path (opened in binary write mode).
    """
    logger.info(f'Serializing engine to {path}...')
    tik = time.time()
    with open(path, 'wb') as f:
        f.write(bytearray(engine))
    tok = time.time()
    t = _format_elapsed(tok - tik)
    logger.info(f'Engine serialized. Total time: {t}')


def parse_arguments():
    """Parse the command line and derive the dependent build settings.

    On top of the plain argparse result this function:
      * sets ``args.quant_mode`` from the weight-only flags,
      * turns on the GPT attention plugin, input padding removal and paged
        KV cache when inflight batching is requested,
      * overrides the architecture hyper-parameters from the HF config
        found in ``--model_dir``.

    @return: The populated ``argparse.Namespace``.
    @raise AssertionError: If ``--max_num_tokens`` is given without
                           ``--enable_context_fmha``, or if dtype is
                           bfloat16 and the gemm plugin is not enabled.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size',
                        type=int,
                        default=1,
                        help='world size, only support tensor parallelism now')
    parser.add_argument('--model_dir',
                        type=str,
                        default='baichuan-inc/Baichuan-13B-Chat')
    parser.add_argument('--model_version',
                        type=str,
                        default='v1_13b',
                        choices=['v1_7b', 'v1_13b', 'v2_7b', 'v2_13b'])
    parser.add_argument('--dtype',
                        type=str,
                        default='float16',
                        choices=['float32', 'bfloat16', 'float16'])
    parser.add_argument(
        '--timing_cache',
        type=str,
        default='model.cache',
        help=
        'The path of to read timing cache from, will be ignored if the file does not exist'
    )
    parser.add_argument('--log_level', type=str, default='info')
    parser.add_argument('--vocab_size', type=int, default=64000)
    parser.add_argument('--n_layer', type=int, default=40)
    parser.add_argument('--n_positions', type=int, default=4096)
    parser.add_argument('--n_embd', type=int, default=5120)
    parser.add_argument('--n_head', type=int, default=40)
    parser.add_argument('--inter_size', type=int, default=13696)
    parser.add_argument('--hidden_act', type=str, default='silu')
    parser.add_argument('--max_batch_size', type=int, default=1)
    parser.add_argument('--max_input_len', type=int, default=1024)
    parser.add_argument('--max_output_len', type=int, default=1024)
    parser.add_argument('--max_beam_width', type=int, default=1)
    parser.add_argument('--use_gpt_attention_plugin',
                        nargs='?',
                        const='float16',
                        type=str,
                        default=False,
                        choices=['float16', 'bfloat16', 'float32'])
    parser.add_argument('--use_gemm_plugin',
                        nargs='?',
                        const='float16',
                        type=str,
                        default=False,
                        choices=['float16', 'bfloat16', 'float32'])
    parser.add_argument('--enable_context_fmha',
                        default=False,
                        action='store_true')
    parser.add_argument('--enable_context_fmha_fp32_acc',
                        default=False,
                        action='store_true')
    parser.add_argument('--parallel_build', default=False, action='store_true')
    parser.add_argument('--visualize', default=False, action='store_true')
    parser.add_argument('--enable_debug_output',
                        default=False,
                        action='store_true')
    parser.add_argument('--gpus_per_node', type=int, default=8)
    parser.add_argument(
        '--output_dir',
        type=str,
        default='baichuan_outputs',
        help=
        'The path to save the serialized engine files, timing cache file and model configs'
    )
    parser.add_argument('--remove_input_padding',
                        default=False,
                        action='store_true')
    parser.add_argument(
        '--use_weight_only',
        default=False,
        action="store_true",
        help='Quantize weights for the various GEMMs to INT4/INT8.'
        'See --weight_only_precision to set the precision')
    parser.add_argument(
        '--weight_only_precision',
        const='int8',
        type=str,
        nargs='?',
        default='int8',
        choices=['int8', 'int4'],
        help=
        'Define the precision for the weights when using weight-only quantization.'
        'You must also use --use_weight_only for that argument to have an impact.'
    )
    parser.add_argument(
        '--use_inflight_batching',
        action="store_true",
        default=False,
        help="Activates inflight batching mode of gptAttentionPlugin.")
    parser.add_argument(
        '--paged_kv_cache',
        action="store_true",
        default=False,
        help=
        'By default we use contiguous KV cache. By setting this flag you enable paged KV cache'
    )
    parser.add_argument('--tokens_per_block',
                        type=int,
                        default=64,
                        help='Number of tokens per block in paged KV cache')
    parser.add_argument(
        '--max_num_tokens',
        type=int,
        default=None,
        help='Define the max number of tokens supported by the engine')

    args = parser.parse_args()

    if args.use_weight_only:
        args.quant_mode = QuantMode.use_weight_only(
            args.weight_only_precision == 'int4')
    else:
        args.quant_mode = QuantMode(0)

    # Inflight batching implies the attention plugin, packed (unpadded)
    # inputs and paged KV cache; switch on whatever the user left off.
    if args.use_inflight_batching:
        if not args.use_gpt_attention_plugin:
            args.use_gpt_attention_plugin = 'float16'
            logger.info(
                f"Using GPT attention plugin for inflight batching mode. Setting to default '{args.use_gpt_attention_plugin}'"
            )
        if not args.remove_input_padding:
            args.remove_input_padding = True
            logger.info(
                "Using remove input padding for inflight batching mode.")
        if not args.paged_kv_cache:
            args.paged_kv_cache = True
            logger.info("Using paged KV cache for inflight batching mode.")

    if args.max_num_tokens is not None:
        assert args.enable_context_fmha, \
            "--max_num_tokens requires --enable_context_fmha"

    if args.model_dir is not None:
        hf_config = AutoConfig.from_pretrained(args.model_dir,
                                               trust_remote_code=True)
        # override the inter_size for Baichuan
        args.inter_size = hf_config.intermediate_size
        args.n_embd = hf_config.hidden_size
        args.n_head = hf_config.num_attention_heads
        args.n_layer = hf_config.num_hidden_layers
        # The 7b configs are read via max_position_embeddings, the 13b
        # configs via model_max_length.
        if args.model_version in ('v1_7b', 'v2_7b'):
            args.n_positions = hf_config.max_position_embeddings
        else:
            args.n_positions = hf_config.model_max_length
        args.vocab_size = hf_config.vocab_size
        args.hidden_act = hf_config.hidden_act
    else:
        # NOTE: --model_dir has a non-None default, so this branch is not
        # reachable from the command line as the parser is defined today.
        # default values are based on v1_13b, change them based on
        # model_version
        for key, value in _MODEL_VERSION_DEFAULTS.get(args.model_version,
                                                      {}).items():
            setattr(args, key, value)

    if args.dtype == 'bfloat16':
        assert args.use_gemm_plugin, "Please use gemm plugin when dtype is bfloat16"

    return args


def build_rank_engine(builder: Builder,
                      builder_config: tensorrt_llm.builder.BuilderConfig,
                      engine_name, rank, args):
    '''
       @brief: Build the engine on the given rank.
       @param builder: The builder used to create the network and the engine.
       @param builder_config: The builder config the engine is built with.
       @param engine_name: The name assigned to the TensorRT network.
       @param rank: The rank to build the engine.
       @param args: The cmd line arguments.
       @return: The built engine.
    '''
    kv_dtype = str_dtype_to_trt(args.dtype)

    # 7b variants are built with RoPE (GPT-NeoX style), 13b ones with ALiBi.
    if args.model_version in ('v1_7b', 'v2_7b'):
        position_embedding_type = PositionEmbeddingType.rope_gpt_neox
    else:
        position_embedding_type = PositionEmbeddingType.alibi

    # Initialize Module
    tensorrt_llm_baichuan = BaichuanForCausalLM(
        num_layers=args.n_layer,
        num_heads=args.n_head,
        hidden_size=args.n_embd,
        vocab_size=args.vocab_size,
        hidden_act=args.hidden_act,
        max_position_embeddings=args.n_positions,
        position_embedding_type=position_embedding_type,
        dtype=kv_dtype,
        mlp_hidden_size=args.inter_size,
        mapping=Mapping(world_size=args.world_size,
                        rank=rank,
                        tp_size=args.world_size))
    if args.use_weight_only and args.weight_only_precision == 'int8':
        tensorrt_llm_baichuan = weight_only_quantize(
            tensorrt_llm_baichuan, QuantMode.use_weight_only())
    elif args.use_weight_only and args.weight_only_precision == 'int4':
        tensorrt_llm_baichuan = weight_only_quantize(
            tensorrt_llm_baichuan,
            QuantMode.use_weight_only(use_int4_weights=True))

    if args.model_dir is not None:
        logger.info(
            f'Loading HF Baichuan {args.model_version} ... from {args.model_dir}'
        )
        tik = time.time()
        hf_baichuan = AutoModelForCausalLM.from_pretrained(
            args.model_dir,
            device_map={
                "model": "cpu",
                "lm_head": "cpu"
            },  # Load to CPU memory
            torch_dtype="auto",
            trust_remote_code=True)
        tok = time.time()
        t = _format_elapsed(tok - tik)
        logger.info(f'HF Baichuan {args.model_version} loaded. Total time: {t}')
        load_from_hf_baichuan(tensorrt_llm_baichuan,
                              hf_baichuan,
                              args.model_version,
                              rank,
                              args.world_size,
                              dtype=args.dtype)
        # Drop the reference to the HF model once its weights are copied.
        del hf_baichuan

    # Module -> Network
    network = builder.create_network()
    network.trt_network.name = engine_name
    if args.use_gpt_attention_plugin:
        network.plugin_config.set_gpt_attention_plugin(
            dtype=args.use_gpt_attention_plugin)
    if args.use_gemm_plugin:
        network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
    assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc), \
        "--enable_context_fmha and --enable_context_fmha_fp32_acc are mutually exclusive"
    if args.enable_context_fmha:
        network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
    if args.enable_context_fmha_fp32_acc:
        network.plugin_config.set_context_fmha(
            ContextFMHAType.enabled_with_fp32_acc)
    if args.use_weight_only:
        network.plugin_config.set_weight_only_quant_matmul_plugin(
            dtype='float16')
    if args.world_size > 1:
        network.plugin_config.set_nccl_plugin(args.dtype)
    if args.remove_input_padding:
        network.plugin_config.enable_remove_input_padding()
    if args.paged_kv_cache:
        network.plugin_config.enable_paged_kv_cache(args.tokens_per_block)

    with net_guard(network):
        # Prepare
        network.set_named_parameters(tensorrt_llm_baichuan.named_parameters())

        # Forward
        inputs = tensorrt_llm_baichuan.prepare_inputs(args.max_batch_size,
                                                      args.max_input_len,
                                                      args.max_output_len, True,
                                                      args.max_beam_width,
                                                      args.max_num_tokens)
        tensorrt_llm_baichuan(*inputs)
        if args.enable_debug_output:
            # mark intermediate nodes' outputs
            for k, v in tensorrt_llm_baichuan.named_network_outputs():
                v = v.trt_tensor
                v.name = k
                network.trt_network.mark_output(v)
                v.dtype = kv_dtype
        if args.visualize:
            model_path = os.path.join(args.output_dir, 'test.onnx')
            to_onnx(network.trt_network, model_path)

    tensorrt_llm.graph_rewriting.optimize(network)

    # Network -> Engine
    engine = builder.build_engine(network, builder_config)
    # Only rank 0 writes the config file.
    if rank == 0:
        config_path = os.path.join(args.output_dir, 'config.json')
        builder.save_config(builder_config, config_path)
    return engine


def build(rank, args):
    """Build and serialize the engine(s) handled by this process.

    In a serial build, this single call loops over all ranks. In a parallel
    build, each spawned process only handles the rank equal to its own
    ``rank``.

    @param rank: The rank of the calling process; also selects the CUDA
                 device (``rank % args.gpus_per_node``).
    @param args: The cmd line arguments.
    @raise AssertionError: If an engine fails to build or the timing cache
                           cannot be saved.
    """
    torch.cuda.set_device(rank % args.gpus_per_node)
    tensorrt_llm.logger.set_level(args.log_level)
    # exist_ok avoids a check-then-create race: with --parallel_build
    # several spawned processes reach this point concurrently.
    os.makedirs(args.output_dir, exist_ok=True)

    # when doing serializing build, all ranks share one engine
    builder = Builder()

    cache = None
    model_name = 'baichuan'
    for cur_rank in range(args.world_size):
        # skip other ranks if parallel_build is enabled
        if args.parallel_build and cur_rank != rank:
            continue
        builder_config = builder.create_builder_config(
            name=model_name,
            precision=args.dtype,
            timing_cache=args.timing_cache if cache is None else cache,
            tensor_parallel=args.world_size,  # TP only
            parallel_build=args.parallel_build,
            num_layers=args.n_layer,
            num_heads=args.n_head,
            hidden_size=args.n_embd,
            vocab_size=args.vocab_size,
            hidden_act=args.hidden_act,
            max_position_embeddings=args.n_positions,
            max_batch_size=args.max_batch_size,
            max_input_len=args.max_input_len,
            max_output_len=args.max_output_len,
            max_num_tokens=args.max_num_tokens,
            int8=args.quant_mode.has_act_and_weight_quant())
        engine_name = get_engine_name(model_name, args.dtype, args.world_size,
                                      cur_rank)
        engine = build_rank_engine(builder, builder_config, engine_name,
                                   cur_rank, args)
        assert engine is not None, f'Failed to build engine for rank {cur_rank}'

        if cur_rank == 0:
            # Use in-memory timing cache for multiple builder passes.
            if not args.parallel_build:
                cache = builder_config.trt_builder_config.get_timing_cache()

        serialize_engine(engine, os.path.join(args.output_dir, engine_name))

    # Only the rank-0 process persists the timing cache to disk.
    if rank == 0:
        ok = builder.save_timing_cache(
            builder_config, os.path.join(args.output_dir, "model.cache"))
        assert ok, "Failed to save timing cache."


if __name__ == '__main__':
    args = parse_arguments()
    logger.set_level(args.log_level)
    tik = time.time()
    # Build in parallel only when requested, when there is more than one
    # rank, and when enough GPUs are visible; otherwise fall back to a
    # serial build in this process.
    if args.parallel_build and args.world_size > 1 and \
            torch.cuda.device_count() >= args.world_size:
        logger.warning(
            f'Parallelly build TensorRT engines. Please make sure that all of the {args.world_size} GPUs are totally free.'
        )
        mp.spawn(build, nprocs=args.world_size, args=(args, ))
    else:
        args.parallel_build = False
        logger.info('Serially build TensorRT engines.')
        build(0, args)

    tok = time.time()
    t = _format_elapsed(tok - tik)
    logger.info(f'Total time of building all {args.world_size} engines: {t}')