TensorRT-LLMs/examples/opt/build.py
2023-12-01 22:27:51 +08:00

409 lines
16 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import time
from pathlib import Path
import torch
import torch.multiprocessing as mp
import tensorrt_llm
from tensorrt_llm._utils import str_dtype_to_trt
from tensorrt_llm.builder import Builder
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import quantize_model
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
from tensorrt_llm.quantization import QuantMode
from weight import load_from_ft, parse_ft_config, check_embedding_share # isort:skip
# Model identifier: passed as the builder-config name and used as the prefix
# of every serialized engine file name produced by this script.
MODEL_NAME = "opt"
def get_engine_name(model, dtype, tp_size, rank):
    '''
    @brief: Compose the file name of the engine built for one rank.
    @param model: Model name used as the file-name prefix.
    @param dtype: Data type string embedded in the name.
    @param tp_size: Tensor-parallel size embedded after "tp".
    @param rank: Rank embedded after "rank".
    @return: "<model>_<dtype>_tp<tp_size>_rank<rank>.engine".
    '''
    return f'{model}_{dtype}_tp{tp_size}_rank{rank}.engine'
def serialize_engine(engine, path):
    '''
    @brief: Write an engine to disk and log how long the write took.
    @param engine: Object handed directly to the binary file's write();
                   presumably the serialized engine buffer - confirm with
                   the caller.
    @param path: Destination file; opened in 'wb' mode, so an existing file
                 is overwritten.
    '''
    logger.info(f'Serializing engine to {path}...')
    started_at = time.time()
    with open(path, 'wb') as engine_file:
        engine_file.write(engine)
    elapsed_seconds = time.time() - started_at
    # Render the elapsed seconds as HH:MM:SS for the log message.
    t = time.strftime('%H:%M:%S', time.gmtime(elapsed_seconds))
    logger.info(f'Engine serialized. Total time: {t}')
def parse_arguments():
    '''
    @brief: Parse the command line options of the OPT engine build script.

    Besides the raw options, the returned namespace carries:
      * quant_mode: a QuantMode derived from --use_weight_only,
        --weight_only_precision and --int8_kv_cache.
      * n_embd, n_head, n_layer, n_positions, vocab_size and
        do_layer_norm_before overwritten with the values read from
        "<model_dir>/config.ini" when --model_dir is given.

    @return: The populated argparse namespace.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size',
                        type=int,
                        default=1,
                        help='world size, only support tensor parallelism now')
    parser.add_argument('--model_dir', type=str, default=None)
    parser.add_argument('--dtype',
                        type=str,
                        default='float16',
                        choices=['float16', 'float32'])
    parser.add_argument(
        '--timing_cache',
        type=str,
        default='model.cache',
        help=
        'The path of to read timing cache from, will be ignored if the file does not exist'
    )
    parser.add_argument('--log_level', type=str, default='info')

    # Model hyper-parameters; replaced by config.ini values when --model_dir
    # is supplied (see the end of this function).
    parser.add_argument('--vocab_size', type=int, default=50272)
    parser.add_argument('--n_layer', type=int, default=12)
    parser.add_argument('--n_positions', type=int, default=2048)
    parser.add_argument('--n_embd', type=int, default=768)
    parser.add_argument('--n_head', type=int, default=12)
    parser.add_argument('--hidden_act', type=str, default='relu')

    # Build-time shape limits.
    parser.add_argument('--max_batch_size', type=int, default=256)
    parser.add_argument('--max_input_len', type=int, default=200)
    parser.add_argument('--max_output_len', type=int, default=200)
    parser.add_argument('--max_beam_width', type=int, default=1)

    # --pre_norm / --post_norm write to the same destination (pre_norm).
    parser.add_argument('--pre_norm', action='store_true')
    parser.add_argument('--post_norm', dest='pre_norm', action='store_false')
    parser.add_argument('--do_layer_norm_before',
                        default=False,
                        action='store_true')

    # Plugin switches: absent -> False, bare flag -> 'float16',
    # "--flag <dtype>" -> the given dtype.
    parser.add_argument('--use_gpt_attention_plugin',
                        nargs='?',
                        const='float16',
                        type=str,
                        default=False,
                        choices=['float16', 'float32'])
    parser.add_argument('--use_gemm_plugin',
                        nargs='?',
                        const='float16',
                        type=str,
                        default=False,
                        choices=['float16', 'float32'])
    parser.add_argument('--use_layernorm_plugin',
                        nargs='?',
                        const='float16',
                        type=str,
                        default=False,
                        choices=['float16', 'float32'])
    parser.add_argument('--parallel_build', default=False, action='store_true')
    parser.add_argument('--enable_context_fmha',
                        default=False,
                        action='store_true')
    parser.add_argument('--enable_context_fmha_fp32_acc',
                        default=False,
                        action='store_true')
    parser.add_argument(
        '--multi_block_mode',
        default=False,
        action='store_true',
        help=
        'Split long kv sequence into multiple blocks (applied to generation MHA kernels). '
        'It is beneifical when batchxnum_heads cannot fully utilize GPU.')
    parser.add_argument('--gpus_per_node', type=int, default=8)
    parser.add_argument(
        '--output_dir',
        type=str,
        default='engine_outputs',
        help=
        'The path to save the serialized engine files, timing cache file and model configs'
    )
    parser.add_argument('--remove_input_padding',
                        default=False,
                        action='store_true')
    parser.add_argument(
        '--int8_kv_cache',
        default=False,
        action="store_true",
        help=
        'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
    )
    parser.add_argument(
        '--use_weight_only',
        default=False,
        action="store_true",
        help='Quantize weights for the various GEMMs to INT4/INT8.'
        'See --weight_only_precision to set the precision')
    parser.add_argument(
        '--weight_only_precision',
        const='int8',
        type=str,
        nargs='?',
        default='int8',
        choices=['int8', 'int4'],
        help=
        'Define the precision for the weights when using weight-only quantization.'
        'You must also use --use_weight_only for that argument to have an impact.'
    )
    parser.add_argument(
        '--max_prompt_embedding_table_size',
        type=int,
        default=0,
        help='Setting to a value > 0 enables support for prompt tuning.')
    parser.add_argument(
        '--use_parallel_embedding',
        action="store_true",
        default=False,
        help=
        'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
    )
    parser.add_argument(
        '--embedding_sharding_dim',
        type=int,
        default=0,
        choices=[0, 1],
        help=
        'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
        'To shard it along hidden dimension, set embedding_sharding_dim=1'
        'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
    )
    parser.add_argument(
        '--use_embedding_sharing',
        action="store_true",
        default=False,
        help=
        'Try to reduce the engine size by sharing the embedding lookup table between two layers.'
        'Note: the flag might not take effect when the criteria are not met.')
    # const must be truthy: with const=None a bare "--use_lookup_plugin"
    # parsed to None, which the later truthiness check treated exactly like
    # "flag not given", so the plugin was silently left disabled. 'float16'
    # mirrors the other --use_*_plugin options above.
    parser.add_argument(
        '--use_lookup_plugin',
        nargs='?',
        const='float16',
        default=False,
        choices=['float16', 'float32', 'bfloat16'],
        help=
        "Activates the lookup plugin which enables embedding sharing. It is also required for language modeling embedding weight sharing."
    )
    parser.add_argument(
        '--strongly_typed',
        default=False,
        action="store_true",
        help=
        'This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.'
    )

    args = parser.parse_args()

    # Derive the quantization mode from the weight-only / KV-cache options.
    if args.use_weight_only:
        args.quant_mode = QuantMode.use_weight_only(
            args.weight_only_precision == 'int4')
    else:
        args.quant_mode = QuantMode(0)

    if args.int8_kv_cache:
        args.quant_mode = args.quant_mode.set_int8_kv_cache()

    # A checkpoint directory is authoritative for the model shape: its
    # config.ini overrides whatever was passed on the command line.
    if args.model_dir is not None:
        n_embd, n_head, n_layer, n_positions, vocab_size, do_layer_norm_before = parse_ft_config(
            Path(args.model_dir) / "config.ini")
        args.n_embd = n_embd
        args.n_head = n_head
        args.n_layer = n_layer
        args.n_positions = n_positions
        args.vocab_size = vocab_size
        args.do_layer_norm_before = do_layer_norm_before

    return args
def _resolve_embedding_sharing(args):
    '''
    @brief: Decide whether the embedding lookup table is shared.
    @param args: The cmd line arguments.
    @return: False when --use_embedding_sharing is not set; otherwise the
             result of check_embedding_share(), or False when the
             preconditions below are not met.
    '''
    # Share_embedding_table can be set True only when:
    # 1) the weight for lm_head() does not exist while other weights exist
    # 2) For multiple-processes, use_parallel_embedding=True and embedding_sharding_dim == 0.
    # Besides, for TensorRT 9.0, we can observe the engine size reduction when the lookup and gemm plugin are enabled.
    if not args.use_embedding_sharing:
        return False

    multi_process = args.world_size > 1
    layout_allows_sharing = (not multi_process) or (
        args.embedding_sharding_dim == 0 and args.use_parallel_embedding)

    shared = False
    if args.model_dir is not None and layout_allows_sharing:
        shared = check_embedding_share(args.model_dir)
    if not shared:
        logger.warning('Cannot share the embedding lookup table.')
    return shared


def _configure_plugins(network, args):
    '''
    @brief: Apply the plugin-related cmd line options to network.plugin_config.
    @param network: The network whose plugin_config is updated.
    @param args: The cmd line arguments.
    '''
    plugins = network.plugin_config

    # For these three options the flag's value is forwarded as the plugin dtype.
    if args.use_gpt_attention_plugin:
        plugins.set_gpt_attention_plugin(args.use_gpt_attention_plugin)
    if args.use_gemm_plugin:
        plugins.set_gemm_plugin(args.use_gemm_plugin)
    if args.use_layernorm_plugin:
        plugins.set_layernorm_plugin(args.use_layernorm_plugin)
    if args.use_lookup_plugin:
        # Use the plugin for the embedding parallelism and sharing.
        # Note: the model dtype (args.dtype) is passed, not the flag's value.
        plugins.set_lookup_plugin(dtype=args.dtype)

    # The two context-FMHA switches are mutually exclusive.
    assert not args.enable_context_fmha or not args.enable_context_fmha_fp32_acc
    if args.enable_context_fmha:
        plugins.set_context_fmha(ContextFMHAType.enabled)
    if args.enable_context_fmha_fp32_acc:
        plugins.set_context_fmha(ContextFMHAType.enabled_with_fp32_acc)

    if args.multi_block_mode:
        plugins.enable_mmha_multi_block_mode()

    if args.use_weight_only:
        # Weight-only quantization is only accepted together with float16.
        assert (args.dtype == 'float16')
        plugins.set_weight_only_quant_matmul_plugin(dtype=args.dtype)

    if args.world_size > 1:
        plugins.set_nccl_plugin(args.dtype)
    if args.remove_input_padding:
        plugins.enable_remove_input_padding()


def build_rank_engine(builder: Builder,
                      builder_config: tensorrt_llm.builder.BuilderConfig,
                      engine_name, rank, args):
    '''
    @brief: Build the engine on the given rank.
    @param builder: Builder used to create the network and build the engine.
    @param builder_config: Config passed to build_engine(); also saved as
                           "config.json" in args.output_dir by rank 0.
    @param engine_name: Name assigned to the TensorRT network.
    @param rank: The rank to build the engine.
    @param args: The cmd line arguments.
    @return: The built engine.
    '''
    model_dtype = str_dtype_to_trt(args.dtype)
    share_embedding_table = _resolve_embedding_sharing(args)

    # Initialize Module
    tp_mapping = Mapping(world_size=args.world_size,
                         rank=rank,
                         tp_size=args.world_size)  # TP only
    opt_model = tensorrt_llm.models.OPTLMHeadModel(
        num_layers=args.n_layer,
        num_heads=args.n_head,
        hidden_size=args.n_embd,
        vocab_size=args.vocab_size,
        hidden_act=args.hidden_act,
        max_position_embeddings=args.n_positions,
        dtype=model_dtype,
        mapping=tp_mapping,
        pre_norm=args.pre_norm,
        do_layer_norm_before=args.do_layer_norm_before,
        use_prompt_tuning=args.max_prompt_embedding_table_size > 0,
        use_parallel_embedding=args.use_parallel_embedding,
        embedding_sharding_dim=args.embedding_sharding_dim,
        share_embedding_table=share_embedding_table)

    # Quantize before loading weights, as in the original ordering.
    if args.use_weight_only:
        opt_model = quantize_model(opt_model, args.quant_mode)

    if args.model_dir is not None:
        load_from_ft(opt_model,
                     args.model_dir,
                     rank,
                     args.world_size,
                     fp16=(args.dtype == 'float16'),
                     use_parallel_embedding=args.use_parallel_embedding,
                     sharding_dim=args.embedding_sharding_dim,
                     share_embedding_table=share_embedding_table)

    # Module -> Network
    network = builder.create_network()
    network.trt_network.name = engine_name
    _configure_plugins(network, args)

    with net_guard(network):
        # Prepare
        network.set_named_parameters(opt_model.named_parameters())

        # Forward
        model_inputs = opt_model.prepare_inputs(
            args.max_batch_size,
            args.max_input_len,
            args.max_output_len,
            True,  # positional flag; presumably use_cache - TODO confirm
            args.max_beam_width,
            prompt_embedding_table_size=args.max_prompt_embedding_table_size)
        opt_model(*model_inputs)

    tensorrt_llm.graph_rewriting.optimize(network)

    # Network -> Engine
    engine = builder.build_engine(network, builder_config)

    # Only rank 0 writes the shared config file.
    if rank == 0:
        builder.save_config(builder_config,
                            os.path.join(args.output_dir, 'config.json'))
    return engine
def build(rank, args):
    '''
    @brief: Build and serialize the engine(s) handled by this process.

    With args.parallel_build set, only the engine whose rank equals `rank`
    is built; otherwise the engines of all ranks in range(args.world_size)
    are built one after another.

    @param rank: Rank of the calling process; also selects the CUDA device
                 (rank % args.gpus_per_node).
    @param args: The cmd line arguments.
    '''
    torch.cuda.set_device(rank % args.gpus_per_node)
    tensorrt_llm.logger.set_level(args.log_level)
    os.makedirs(args.output_dir, exist_ok=True)

    builder = Builder()
    in_memory_cache = None

    for cur_rank in range(args.world_size):
        # skip other ranks if parallel_build is enabled
        handled_here = (not args.parallel_build) or cur_rank == rank
        if not handled_here:
            continue

        # Prefer the in-memory timing cache captured from rank 0 (if any)
        # over the timing cache path given on the command line.
        timing_cache = (in_memory_cache
                        if in_memory_cache is not None else args.timing_cache)
        needs_int8 = (args.quant_mode.has_act_or_weight_quant()
                      or args.quant_mode.has_int8_kv_cache())
        builder_config = builder.create_builder_config(
            name=MODEL_NAME,
            precision=args.dtype,
            timing_cache=timing_cache,
            tensor_parallel=args.world_size,  # TP only
            parallel_build=args.parallel_build,
            num_layers=args.n_layer,
            num_heads=args.n_head,
            hidden_size=args.n_embd,
            vocab_size=args.vocab_size,
            hidden_act=args.hidden_act,
            max_position_embeddings=args.n_positions,
            max_batch_size=args.max_batch_size,
            max_beam_width=args.max_beam_width,
            max_input_len=args.max_input_len,
            max_output_len=args.max_output_len,
            max_prompt_embedding_table_size=args.
            max_prompt_embedding_table_size,
            int8=needs_int8,
            strongly_typed=args.strongly_typed)

        engine_name = get_engine_name(MODEL_NAME, args.dtype, args.world_size,
                                      cur_rank)
        engine = build_rank_engine(builder, builder_config, engine_name,
                                   cur_rank, args)
        assert engine is not None, f'Failed to build engine for rank {cur_rank}'

        # Use in-memory timing cache for multiple builder passes.
        if cur_rank == 0 and not args.parallel_build:
            in_memory_cache = builder_config.trt_builder_config.get_timing_cache(
            )

        engine_path = os.path.join(args.output_dir, engine_name)
        serialize_engine(engine, engine_path)
        # Drop the reference before the next rank's build.
        del engine

    # Only the rank-0 process persists the timing cache, using the last
    # builder_config created in the loop above.
    if rank == 0:
        cache_path = os.path.join(args.output_dir, "model.cache")
        ok = builder.save_timing_cache(builder_config, cache_path)
        assert ok, "Failed to save timing cache."
if __name__ == '__main__':
    # Script entry point: parse options, build all engines, report total time.
    args = parse_arguments()
    logger.set_level(args.log_level)
    started_at = time.time()

    # Parallel build needs more than one rank and at least one visible GPU
    # per rank; otherwise fall back to building the ranks one by one.
    build_in_parallel = (args.parallel_build and args.world_size > 1
                         and torch.cuda.device_count() >= args.world_size)
    if not build_in_parallel:
        args.parallel_build = False
        logger.info('Serially build TensorRT engines.')
        build(0, args)
    else:
        logger.warning(
            f'Parallelly build TensorRT engines. Please make sure that all of the {args.world_size} GPUs are totally free.'
        )
        # One process per rank; spawn passes the process index as `rank`.
        mp.spawn(build, nprocs=args.world_size, args=(args, ))

    elapsed_seconds = time.time() - started_at
    t = time.strftime('%H:%M:%S', time.gmtime(elapsed_seconds))
    logger.info(f'Total time of building all {args.world_size} engines: {t}')