# TensorRT-LLM/examples/enc_dec/build.py
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import configparser
import time
from pathlib import Path
import torch
import torch.multiprocessing as mp
from run import get_engine_name
import tensorrt_llm
from tensorrt_llm._utils import str_dtype_to_trt
from tensorrt_llm.builder import Builder
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
from t5.weight import parse_t5_config, load_from_hf_t5, load_from_binary_t5 # isort:skip
from bart.weight import parse_bart_config, load_from_binary_bart # isort:skip
from nmt.weight import parse_nmt_config, load_from_binary_nmt # isort:skip
def serialize_engine(engine, path):
logger.info(f'Serializing engine to {path}...')
tik = time.time()
with open(path, 'wb') as f:
f.write(engine)
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
logger.info(f'Engine serialized. Total time: {t}')
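# Read the converted checkpoint's config.ini and merge the model settings into args.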
def parse_config(ini_file, component, args):
config = configparser.ConfigParser()
assert ini_file.exists(), f"Missing config file {ini_file}"
config.read(ini_file)
model_type = config.get('structure', 'model_type')
args.model_type = model_type
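    # Dispatch by name to the model-specific config parser imported above
    # (parse_t5_config / parse_bart_config / parse_nmt_config).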
args = globals()[f'parse_{model_type}_config'](config, component, args)
return args
def parse_arguments(component):
parser = argparse.ArgumentParser()
parser.add_argument('--world_size',
type=int,
default=1,
help='MPI world size (must equal TP * PP)')
parser.add_argument('--tp_size',
type=int,
default=1,
help='N-way tensor parallelism size')
parser.add_argument('--pp_size',
type=int,
default=1,
help='N-way pipeline parallelism size')
parser.add_argument(
'--gpus_per_node',
type=int,
default=8,
        help=
        'Number of GPUs on each node in a multi-node setup. This is a cluster spec and can be greater or smaller than the world size.'
    )
parser.add_argument('--parallel_build', default=False, action='store_true')
parser.add_argument('--weight_dir',
'-i',
type=str,
default=None,
help='Path to the converted weight file')
parser.add_argument(
'--output_dir',
'-o',
type=Path,
default='trt_engines',
help=
'The path to save the serialized engine files, timing cache file and model configs'
)
parser.add_argument(
'--weight_from_pytorch_ckpt',
default=False,
action='store_true',
        help=
        'Load weights from a PyTorch checkpoint. weight_dir must point to the checkpoint directory.'
    )
parser.add_argument('--engine_name',
'-n',
type=str,
default='enc_dec',
help='TensorRT engine name prefix')
parser.add_argument('--debug_mode', action='store_true')
parser.add_argument(
'--timing_cache',
type=str,
default='model.cache',
        help=
        'The path to read the timing cache from; it is ignored if the file does not exist.'
    )
parser.add_argument(
'--profiling_verbosity',
type=str,
default='layer_names_only',
choices=['layer_names_only', 'detailed', 'none'],
        help=
        'The profiling verbosity for the generated TRT engine. Setting it to "detailed" allows inspecting tactic choices and kernel parameters.'
    )
parser.add_argument('--model_type',
type=str,
choices=['t5', 'bart', 'nmt'],
default='t5')
parser.add_argument(
'--dtype',
type=str,
default='float16',
choices=['float16', 'float32', 'bfloat16'],
        help=
        'Target inference dtype. Weights and computation will be in this dtype, regardless of the original dtype of the weight checkpoint.'
    )
parser.add_argument('--logits_dtype',
type=str,
default='float32',
choices=['float16', 'float32'])
parser.add_argument('--log_level', type=str, default='info')
parser.add_argument('--max_batch_size', type=int, default=8)
parser.add_argument('--max_encoder_input_len', type=int, default=1024)
parser.add_argument(
'--max_decoder_input_len',
type=int,
default=1,
        help=
        'To use the decoder_forced_input_ids feature, set this to a value greater than 1. Otherwise, the encoder-decoder model starts from decoder_start_token_id with length 1.'
    )
parser.add_argument('--max_output_len', type=int, default=200)
parser.add_argument('--max_beam_width', type=int, default=1)
parser.add_argument(
'--use_bert_attention_plugin',
nargs='?',
const=None,
type=str,
default=False,
choices=['float16', 'float32', 'bfloat16'],
        help=
        "Activates the BERT attention plugin. You can specify the plugin dtype or leave it blank to use the model dtype."
    )
parser.add_argument(
'--use_gpt_attention_plugin',
nargs='?',
const=None,
type=str,
default=False,
choices=['float16', 'float32', 'bfloat16'],
        help=
        "Activates the GPT attention plugin. You can specify the plugin dtype or leave it blank to use the model dtype."
    )
parser.add_argument(
'--use_gemm_plugin',
nargs='?',
const=None,
type=str,
default=False,
choices=['float16', 'float32', 'bfloat16'],
        help=
        "Activates the GEMM plugin. You can specify the plugin dtype or leave it blank to use the model dtype."
    )
parser.add_argument(
'--use_lookup_plugin',
nargs='?',
const=None,
default=False,
choices=['float16', 'float32', 'bfloat16'],
help="Activates the lookup plugin which enables embedding sharding.")
parser.add_argument('--enable_qk_half_accum',
default=False,
action='store_true')
parser.add_argument('--enable_context_fmha',
default=False,
action='store_true')
parser.add_argument('--enable_context_fmha_fp32_acc',
default=False,
action='store_true')
parser.add_argument('--builder_opt', type=int, default=None)
parser.add_argument('--remove_input_padding',
default=False,
action='store_true')
parser.add_argument(
'--random_seed',
type=int,
default=None,
help=
'Seed to use when initializing the random number generator for torch.')
parser.add_argument(
'--max_prompt_embedding_table_size',
'--max_multimodal_len',
type=int,
default=0,
help=
'Setting to a value > 0 enables support for prompt tuning or multimodal input.'
)
parser.add_argument(
'--use_parallel_embedding',
action="store_true",
default=False,
        help=
        'Embedding parallelism is disabled by default; set this flag to enable it.'
    )
parser.add_argument(
'--embedding_sharding_dim',
type=int,
default=0,
choices=[0, 1],
        help=
        'By default, the embedding lookup table is sharded along the vocab dimension (embedding_sharding_dim=0). '
        'To shard it along the hidden dimension, set embedding_sharding_dim=1. '
        'Note: embedding sharing is only enabled when embedding_sharding_dim = 0.'
    )
parser.add_argument(
'--use_custom_all_reduce',
action='store_true',
help=
'Activates latency-optimized algorithm for all-reduce instead of NCCL.')
parser.add_argument(
'--strongly_typed',
default=False,
action="store_true",
        help=
        'This option is introduced with TRT 9.1.0.1+ and significantly reduces the engine build time for fp8.'
    )
parser.add_argument(
'--gather_all_token_logits',
action='store_true',
default=False,
help='Enable both gather_context_logits and gather_generation_logits')
parser.add_argument('--gather_context_logits',
action='store_true',
default=False,
help='Gather context logits')
parser.add_argument('--gather_generation_logits',
action='store_true',
default=False,
help='Gather generation logits')
parser.add_argument(
'--skip_encoder',
'--nougat',
default=False,
action="store_true",
        help=
        'Skip building the encoder for the Nougat model; in Nougat, the encoder is not an LLM.'
    )
parser.add_argument(
'--use_lora_plugin',
nargs='?',
const=None,
default=False,
choices=['float16', 'float32', 'bfloat16'],
        help='Activates the LoRA plugin, which enables running LoRA adapters.')
# parse cmdline args
args = parser.parse_args()
logger.set_level(args.log_level)
if component == 'encoder' and args.skip_encoder:
# Skip further processing
return args
# parse model config and add to args
if args.weight_dir is not None:
logger.info(f"Setting model configuration from {args.weight_dir}.")
args = parse_config(
Path(args.weight_dir) / "config.ini", component, args)
assert args.pp_size * args.tp_size == args.world_size
plugins_args = [
'use_bert_attention_plugin', 'use_gpt_attention_plugin',
'use_gemm_plugin', 'use_lora_plugin', 'use_lookup_plugin'
]
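    # These flags use nargs='?' with const=None: an absent flag stays False, a bare
    # flag yields None (replaced with the model dtype below), and an explicit value
    # selects that plugin dtype.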
for plugin_arg in plugins_args:
if getattr(args, plugin_arg) is None:
            logger.info(
                f"{plugin_arg} set without specifying a value; using {args.dtype} automatically."
            )
setattr(args, plugin_arg, args.dtype)
if args.gather_all_token_logits:
args.gather_context_logits = True
args.gather_generation_logits = True
return args
def build_rank_engine(builder: Builder,
builder_config: tensorrt_llm.builder.BuilderConfig,
engine_name, rank, args):
'''
@brief: Build the engine on the given rank.
@param rank: The rank to build the engine.
@param args: The cmd line arguments.
@return: The built engine.
'''
dtype = str_dtype_to_trt(args.dtype)
mapping = Mapping(world_size=args.world_size,
rank=rank,
tp_size=args.tp_size,
pp_size=args.pp_size)
assert args.n_layer % args.pp_size == 0, \
f"num_layers {args.n_layer} must be a multiple of pipeline parallelism size {args.pp_size}"
fp16_clamping = (args.dtype == 'float16') and (args.model_type == 't5')
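    # T5 in float16 is prone to activation overflow, so intermediate activations
    # are clamped to the fp16 range when building T5 engines in float16.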
# Initialize Module
if args.component == 'encoder':
tllm_model = tensorrt_llm.models.EncoderModel(
num_layers=args.n_layer,
num_heads=args.n_head,
num_kv_heads=args.n_head,
head_size=args.head_size,
hidden_size=args.hidden_size,
ffn_hidden_size=args.ffn_hidden_size,
vocab_size=args.vocab_size,
max_position_embeddings=args.n_positions,
has_position_embedding=args.has_position_embedding,
relative_attention=args.relative_attention,
max_distance=args.max_distance,
num_buckets=args.num_buckets,
has_embedding_layernorm=args.has_embedding_layernorm,
has_embedding_scale=args.has_embedding_scale,
q_scaling=args.q_scaling,
has_attention_qkvo_bias=args.has_attention_qkvo_bias,
has_mlp_bias=args.has_mlp_bias,
has_model_final_layernorm=args.has_model_final_layernorm,
layernorm_eps=args.layernorm_eps,
layernorm_position=args.layernorm_position,
layernorm_type=args.layernorm_type,
hidden_act=args.hidden_act,
mlp_type=args.mlp_type,
dtype=dtype,
use_prompt_tuning=args.max_prompt_embedding_table_size > 0,
use_parallel_embedding=args.use_parallel_embedding,
embedding_sharding_dim=args.embedding_sharding_dim,
mapping=mapping,
fp16_clamping=fp16_clamping,
max_lora_rank=args.max_lora_rank if args.use_lora_plugin else 0)
elif args.component == 'decoder':
tllm_model = tensorrt_llm.models.DecoderModel(
num_layers=args.n_layer,
num_heads=args.n_head,
num_kv_heads=args.n_head,
head_size=args.head_size,
hidden_size=args.hidden_size,
ffn_hidden_size=args.ffn_hidden_size,
encoder_hidden_size=args.encoder_hidden_size,
encoder_num_heads=args.encoder_num_heads,
encoder_head_size=args.encoder_head_size,
vocab_size=args.vocab_size,
max_position_embeddings=args.n_positions,
has_position_embedding=args.has_position_embedding,
relative_attention=args.relative_attention,
max_distance=args.max_distance,
num_buckets=args.num_buckets,
has_embedding_layernorm=args.has_embedding_layernorm,
has_embedding_scale=args.has_embedding_scale,
q_scaling=args.q_scaling,
has_attention_qkvo_bias=args.has_attention_qkvo_bias,
has_mlp_bias=args.has_mlp_bias,
has_model_final_layernorm=args.has_model_final_layernorm,
layernorm_eps=args.layernorm_eps,
layernorm_position=args.layernorm_position,
layernorm_type=args.layernorm_type,
hidden_act=args.hidden_act,
mlp_type=args.mlp_type,
use_parallel_embedding=args.use_parallel_embedding,
embedding_sharding_dim=args.embedding_sharding_dim,
max_lora_rank=args.max_lora_rank if args.use_lora_plugin else 0,
mapping=mapping,
rescale_before_lm_head=args.rescale_before_lm_head,
dtype=dtype,
logits_dtype=args.logits_dtype,
fp16_clamping=fp16_clamping)
if args.weight_from_pytorch_ckpt:
        assert args.tp_size == 1, "Loading the framework model via memory is for demonstration purposes. For multi-GPU inference, please load from binary for better performance."
globals()[f'load_from_hf_{args.model_type}'](tllm_model,
args.weight_dir,
args.component,
dtype=args.dtype)
else:
globals()[f'load_from_binary_{args.model_type}'](tllm_model,
args.weight_dir,
args,
mapping=mapping,
dtype=args.dtype)
# Module -> Network
network = builder.create_network()
network.trt_network.name = engine_name
network.plugin_config.to_legacy_setting()
if args.use_bert_attention_plugin:
network.plugin_config.set_bert_attention_plugin(
dtype=args.use_bert_attention_plugin)
if args.use_gpt_attention_plugin:
network.plugin_config.set_gpt_attention_plugin(
dtype=args.use_gpt_attention_plugin)
if args.use_gemm_plugin:
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
if args.enable_qk_half_accum:
network.plugin_config.enable_qk_half_accum()
assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc)
if args.enable_context_fmha and not args.relative_attention:
logger.warning("Only non-T5 enc-dec models support FMHA")
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
if args.enable_context_fmha_fp32_acc and not args.relative_attention:
logger.warning("Only non-T5 enc-dec models support FMHA")
network.plugin_config.set_context_fmha(
ContextFMHAType.enabled_with_fp32_acc)
if args.remove_input_padding:
network.plugin_config.enable_remove_input_padding()
if args.use_lookup_plugin:
# Use the plugin for the embedding parallelism and sharding
network.plugin_config.set_lookup_plugin(dtype=args.dtype)
if args.use_lora_plugin:
network.plugin_config.set_lora_plugin(dtype=args.use_lora_plugin)
if args.world_size > 1:
network.plugin_config.set_nccl_plugin(args.dtype,
args.use_custom_all_reduce)
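    # net_guard makes `network` the active default network: layers created while
    # tracing the model's forward pass below are recorded into this TRT network.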
with net_guard(network):
# Prepare
network.set_named_parameters(tllm_model.named_parameters())
# Forward
if args.component == 'encoder':
inputs = tllm_model.prepare_inputs(
max_batch_size=args.max_batch_size,
max_input_len=args.max_encoder_input_len,
prompt_embedding_table_size=args.
max_prompt_embedding_table_size,
lora_target_modules=args.lora_target_modules
if args.use_lora_plugin else None,
)
elif args.component == 'decoder':
inputs = tllm_model.prepare_inputs(
max_batch_size=args.max_batch_size,
max_beam_width=args.max_beam_width,
max_decoder_input_len=args.max_decoder_input_len,
max_new_tokens=args.max_output_len,
max_encoder_input_len=args.max_encoder_input_len,
gather_context_logits=args.gather_context_logits,
gather_generation_logits=args.gather_generation_logits,
lora_target_modules=args.lora_target_modules
if args.use_lora_plugin else None,
)
tllm_model(*inputs)
# Adding debug outputs into the network --------------------------
if args.debug_mode:
for k, v in tllm_model.named_network_outputs():
network._mark_output(v, k,
tensorrt_llm.str_dtype_to_trt(args.dtype))
# ----------------------------------------------------------------
# Network -> Engine
engine = builder.build_engine(network, builder_config)
return engine
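# Build engines for every rank in this process (serial mode) or for this rank only
# (parallel mode), reusing the in-memory timing cache across serial builds.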
def build(rank, args):
torch.cuda.set_device(rank % args.gpus_per_node)
tensorrt_llm.logger.set_level(args.log_level)
component_dir = args.output_dir / args.dtype / f"tp{args.tp_size}" / args.component
component_dir.mkdir(parents=True, exist_ok=True)
builder = Builder()
apply_query_key_layer_scaling = False
cache = None
for cur_rank in range(args.world_size):
# skip other ranks if parallel_build is enabled
if args.parallel_build and cur_rank != rank:
continue
builder_config = builder.create_builder_config(
name=args.engine_name,
precision=args.dtype,
timing_cache=component_dir /
args.timing_cache if cache is None else cache,
profiling_verbosity=args.profiling_verbosity,
tensor_parallel=args.tp_size,
pipeline_parallel=args.pp_size,
gpus_per_node=args.gpus_per_node,
parallel_build=args.parallel_build,
num_layers=args.n_layer,
num_heads=args.n_head,
hidden_size=args.hidden_size,
head_size=args.head_size,
vocab_size=args.vocab_size,
hidden_act=args.hidden_act,
max_position_embeddings=args.n_positions,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
max_batch_size=args.max_batch_size,
max_beam_width=args.max_beam_width,
max_decoder_input_len=args.max_decoder_input_len,
max_output_len=args.max_output_len,
max_encoder_input_len=args.max_encoder_input_len,
opt_level=args.builder_opt,
cross_attention=(args.component == 'decoder'),
has_position_embedding=args.has_position_embedding,
has_token_type_embedding=args.has_token_type_embedding,
strongly_typed=args.strongly_typed,
gather_context_logits=args.gather_context_logits,
gather_generation_logits=args.gather_generation_logits,
max_prompt_embedding_table_size=(
args.max_prompt_embedding_table_size
if args.component == 'encoder' else 0),
lora_target_modules=args.lora_target_modules
if args.use_lora_plugin else None,
hf_modules_to_trtllm_modules=args.hf_modules_to_trtllm_modules
if args.use_lora_plugin else None,
trtllm_modules_to_hf_modules=args.trtllm_modules_to_hf_modules
if args.use_lora_plugin else None,
)
engine_name = get_engine_name(args.engine_name, args.dtype,
args.tp_size, args.pp_size, cur_rank)
engine = build_rank_engine(builder, builder_config, engine_name,
cur_rank, args)
assert engine is not None, f'Failed to build engine for rank {cur_rank}'
if cur_rank == 0:
# save build config
config_path = component_dir / 'config.json'
builder.save_config(builder_config, config_path)
# Use in-memory timing cache for multiple builder passes.
if not args.parallel_build:
cache = builder_config.trt_builder_config.get_timing_cache()
serialize_engine(engine, component_dir / engine_name)
if rank == 0:
# save timing cache to speedup future use
ok = builder.save_timing_cache(builder_config,
component_dir / args.timing_cache)
assert ok, "Failed to save timing cache."
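# Parse arguments and build all engines for one component ('encoder' or 'decoder').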
def run_build(component):
    assert component in ('encoder', 'decoder'), 'Unsupported component!'
args = parse_arguments(component)
# special handling in prompt tuning / multimodal cases
if args.max_prompt_embedding_table_size > 0:
if component == 'decoder' and args.skip_encoder:
            # For a Nougat-like structure that uses only the decoder of the enc-dec model, the encoder_output length equals the multimodal length, so max_encoder_input_len == max_encoder_output_len == max_multimodal_len == max_prompt_embedding_table_size MUST hold.
args.max_encoder_input_len = args.max_prompt_embedding_table_size
logger.warning(
"Forcing max_encoder_input_len equal to max_prompt_embedding_table_size"
)
        # Otherwise, e.g. for BLIP2-T5, the entire enc-dec model is used, so the multimodal length (visual output length) and the encoder_output length (LLM input length) are two different things.
if component == 'encoder' and args.skip_encoder:
logger.warning("Skipping build of encoder for Nougat model")
return
args.component = component
if args.random_seed is not None:
torch.manual_seed(args.random_seed)
logger.set_level(args.log_level)
tik = time.time()
if args.parallel_build and args.world_size > 1 and \
torch.cuda.device_count() >= args.world_size:
        logger.warning(
            f'Building TensorRT engines in parallel. Please make sure that all of the {args.world_size} GPUs are completely free.'
        )
mp.spawn(build, nprocs=args.world_size, args=(args, ))
else:
args.parallel_build = False
        logger.info('Building TensorRT engines serially.')
build(0, args)
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
logger.info(f'Total time of building all engines: {t}')
if __name__ == '__main__':
run_build(component='encoder')
run_build(component='decoder')
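
# Example invocation (a sketch; the weight directory below is a placeholder for the
# output of this example's weight conversion step, and the plugin flags are
# optional -- see parse_arguments() for the full list):
#
#   python build.py -i /path/to/converted_t5_weights -o trt_engines \
#       --model_type t5 --dtype float16 \
#       --use_bert_attention_plugin --use_gpt_attention_plugin --use_gemm_plugin
#
# One run builds both engines: the encoder first, then the decoder.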