TensorRT-LLMs/examples/enc_dec/build.py
Kaiyu Xie b2fd493c16
Update TensorRT-LLM (#349)
* Update TensorRT-LLM

---------

Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
2023-11-10 22:30:31 +08:00

378 lines
14 KiB
Python

import argparse
import time
from pathlib import Path
import torch
import tensorrt_llm
from tensorrt_llm._utils import str_dtype_to_trt
from tensorrt_llm.builder import Builder
from tensorrt_llm.logger import logger
from tensorrt_llm.network import net_guard
from weight import load_t5_from_pytorch, parse_config # isort:skip
# Model name passed to the builder config (`name=`) and used as the prefix of
# every engine file name produced by this script (see get_engine_name calls).
MODEL_NAME = "enc_dec"
def get_engine_name(model, dtype, tp_size, rank):
    """Return the engine file name for one tensor-parallel rank.

    The result has the form ``<model>_<dtype>_tp<tp_size>_rank<rank>.engine``.
    """
    return f'{model}_{dtype}_tp{tp_size}_rank{rank}.engine'
def serialize_engine(engine, path):
    """Write a serialized engine to ``path`` and log how long it took.

    Args:
        engine: The serialized engine produced by ``Builder.build_engine``.
            Any object supporting the buffer protocol (e.g. ``bytes`` or a
            TensorRT ``IHostMemory``) is accepted.
        path: Destination file path (``str`` or ``pathlib.Path``).
    """
    logger.info(f'Serializing engine to {path}...')
    tik = time.time()
    with open(path, 'wb') as f:
        # Write straight from the engine's buffer. Wrapping it in
        # ``bytearray`` first would make a full extra copy of a potentially
        # very large engine in host memory just to write it out.
        f.write(engine)
    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Engine serialized. Total time: {t}')
def _add_plugin_argument(parser, flag, help_text):
    """Register an optional-dtype plugin flag on ``parser``.

    With ``nargs='?'`` the flag evaluates to ``False`` when absent, to
    ``None`` when given without a value, and to the chosen dtype string when
    given with one. ``parse_arguments`` later replaces ``None`` with the
    model dtype.
    """
    parser.add_argument(flag,
                        nargs='?',
                        const=None,
                        type=str,
                        default=False,
                        choices=['float16', 'float32', 'bfloat16'],
                        help=help_text)


def parse_arguments(args, component):
    """Parse the command-line options for building one enc-dec component.

    Args:
        args: List of argument strings to parse, or ``None`` to parse
            ``sys.argv`` (the ``argparse`` default).
        component: Component name (``'encoder'`` or ``'decoder'``), forwarded
            to ``parse_config`` when ``--model_dir`` is given.

    Returns:
        The parsed namespace. When ``--model_dir`` is given, it is the
        namespace returned by ``parse_config`` for ``<model_dir>/config.ini``.
        Plugin flags passed without a value are resolved to ``args.dtype``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default=None)
    parser.add_argument('--dtype',
                        type=str,
                        default='float16',
                        choices=['float16', 'float32', 'bfloat16'])
    parser.add_argument('--logits_dtype',
                        type=str,
                        default='float32',
                        choices=['float16', 'float32'])
    parser.add_argument(
        '--timing_cache',
        type=str,
        default='model.cache',
        help=
        'The path to read timing cache from, will be ignored if the file does not exist'
    )
    parser.add_argument('--log_level', type=str, default='info')
    parser.add_argument('--vocab_size', type=int, default=32128)
    parser.add_argument('--n_layer', type=int, default=6)
    parser.add_argument('--n_positions', type=int, default=1024)
    parser.add_argument('--n_embd', type=int, default=1024)
    parser.add_argument('--n_head', type=int, default=8)
    parser.add_argument('--hidden_act', type=str, default='gelu')
    parser.add_argument('--inter_size', type=int, default=None)
    # store_true: passing --no_bias must yield args.bias == False below.
    # (store_false inverted this: bias was off by default and turned *on*
    # by --no_bias.)
    parser.add_argument('--no_bias', action='store_true')
    parser.add_argument('--max_batch_size', type=int, default=8)
    parser.add_argument('--max_encoder_input_len', type=int, default=1024)
    parser.add_argument('--max_input_len', type=int, default=200)
    parser.add_argument('--max_output_len', type=int, default=200)
    parser.add_argument('--max_beam_width', type=int, default=1)
    _add_plugin_argument(
        parser, '--use_bert_attention_plugin',
        "Activates BERT attention plugin. You can specify the plugin dtype or leave blank to use the model dtype."
    )
    _add_plugin_argument(
        parser, '--use_gpt_attention_plugin',
        "Activates attention plugin. You can specify the plugin dtype or leave blank to use the model dtype."
    )
    _add_plugin_argument(
        parser, '--use_gemm_plugin',
        "Activates GEMM plugin. You can specify the plugin dtype or leave blank to use the model dtype."
    )
    _add_plugin_argument(
        parser, '--use_layernorm_plugin',
        "Activates layernorm plugin. You can specify the plugin dtype or leave blank to use the model dtype."
    )
    parser.add_argument('--enable_qk_half_accum',
                        default=False,
                        action='store_true')
    parser.add_argument('--gpus_per_node', type=int, default=8)
    parser.add_argument('--builder_opt', type=int, default=None)
    parser.add_argument(
        '--output_dir',
        type=Path,
        default='trt_engines',
        help=
        'The path to save the serialized engine files, timing cache file and model configs'
    )
    parser.add_argument('--remove_input_padding',
                        default=False,
                        action='store_true')
    parser.add_argument(
        '--random_seed',
        type=int,
        default=None,
        help=
        'Seed to use when initializing the random number generator for torch.')
    _add_plugin_argument(
        parser, '--use_lookup_plugin',
        "Activates the lookup plugin which enables embedding sharing.")
    parser.add_argument(
        '--strongly_typed',
        default=False,
        action="store_true",
        help=
        'This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.'
    )
    args = parser.parse_args(args)
    logger.set_level(args.log_level)
    args.bias = not args.no_bias
    if args.inter_size is None:
        args.inter_size = 4 * args.n_embd
    if args.model_dir is not None:
        logger.info(f"Setting model configuration from {args.model_dir}.")
        args = parse_config(
            Path(args.model_dir) / "config.ini", component, args)

    # A plugin flag given without a value parses to None; fall back to the
    # model dtype in that case.
    plugins_args = [
        'use_bert_attention_plugin', 'use_gpt_attention_plugin',
        'use_gemm_plugin', 'use_layernorm_plugin', 'use_lookup_plugin'
    ]
    for plugin_arg in plugins_args:
        if getattr(args, plugin_arg) is None:
            logger.info(
                f"{plugin_arg} set, without specifying a value. Using {args.dtype} automatically."
            )
            setattr(args, plugin_arg, args.dtype)
    return args
def build_rank_engine(builder: Builder,
                      builder_config: tensorrt_llm.builder.BuilderConfig,
                      engine_name, rank, args):
    '''
    @brief: Build the TensorRT engine for one enc-dec component on the given rank.
    @param builder: The builder used to create the network and the engine.
    @param builder_config: The builder config; on rank 0 it is saved to
        <output_dir>/<component>/config.json.
    @param engine_name: Name assigned to the TensorRT network.
    @param rank: The rank to build the engine.
    @param args: The cmd line arguments; args.component selects
        'encoder' or 'decoder'.
    @return: The built engine.
    @raises ValueError: If args.component is neither 'encoder' nor 'decoder'.
    '''
    if args.component not in ('encoder', 'decoder'):
        # Fail loudly instead of hitting an UnboundLocalError on tllm_model.
        raise ValueError(f"Unsupported component: {args.component!r}")

    kv_dtype = str_dtype_to_trt(args.dtype)

    # Initialize Module. Encoder and decoder share these constructor args.
    common_kwargs = dict(
        num_layers=args.n_layer,
        num_heads=args.n_head,
        hidden_size=args.hidden_size,
        ffn_hidden_size=args.ffn_hidden_size,
        vocab_size=args.vocab_size,
        max_position_embeddings=args.n_positions,
        has_position_embedding=args.has_position_embedding,
        relative_attention=args.relative_attention,
        max_distance=args.max_distance,
        num_buckets=args.num_buckets,
        has_embedding_layernorm=args.has_embedding_layernorm,
        has_embedding_scale=args.has_embedding_scale,
        q_scaling=args.q_scaling,
        has_attention_qkvo_bias=args.has_attention_qkvo_bias,
        has_mlp_bias=args.has_mlp_bias,
        has_model_final_layernorm=args.has_model_final_layernorm,
        layernorm_eps=args.layernorm_eps,
        layernorm_position=args.layernorm_position,
        layernorm_type=args.layernorm_type,
        hidden_act=args.hidden_act,
        dtype=kv_dtype,
    )
    if args.component == 'encoder':
        tllm_model = tensorrt_llm.models.EncoderModel(**common_kwargs)
    else:
        # The decoder additionally needs the encoder's geometry (for
        # cross-attention) and the dtype of its output logits.
        tllm_model = tensorrt_llm.models.DecoderModel(
            encoder_hidden_size=args.encoder_hidden_size,
            encoder_num_heads=args.encoder_num_heads,
            logits_dtype=args.logits_dtype,
            **common_kwargs)

    # No support for relative attention bias in plain TRT mode
    # (If to add such support, need to add into
    # Attention and BertAttention at tensorrt_llm/layers/attention.py)
    if args.relative_attention:
        assert args.use_bert_attention_plugin, "Relative attention bias is only supported when using BertAttention Plugin"
        assert args.use_gpt_attention_plugin, "Relative attention bias is only supported when using GPTAttention Plugin"

    if args.model_dir is not None:
        load_t5_from_pytorch(tllm_model,
                             args.model_dir,
                             args.component,
                             dtype=args.dtype)

    # Module -> Network
    network = builder.create_network()
    network.trt_network.name = engine_name
    # Each plugin flag holds False (disabled) or a dtype string; plugin
    # configuration must be set before the network is populated below.
    if args.use_bert_attention_plugin:
        network.plugin_config.set_bert_attention_plugin(
            dtype=args.use_bert_attention_plugin)
    if args.use_gpt_attention_plugin:
        network.plugin_config.set_gpt_attention_plugin(
            dtype=args.use_gpt_attention_plugin)
    if args.use_gemm_plugin:
        network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
    if args.use_layernorm_plugin:
        network.plugin_config.set_layernorm_plugin(
            dtype=args.use_layernorm_plugin)
    if args.enable_qk_half_accum:
        network.plugin_config.enable_qk_half_accum()
    if args.remove_input_padding:
        network.plugin_config.enable_remove_input_padding()
    if args.use_lookup_plugin:
        # Use the plugin for the embedding parallelism and sharing.
        # NOTE(review): this passes the model dtype (args.dtype), not the
        # value given to --use_lookup_plugin; confirm that is intended.
        network.plugin_config.set_lookup_plugin(dtype=args.dtype)

    with net_guard(network):
        # Prepare
        network.set_named_parameters(tllm_model.named_parameters())

        # Forward
        if args.component == 'encoder':
            inputs = tllm_model.prepare_inputs(
                args.max_batch_size,
                args.max_input_len,
            )
        else:
            inputs = tllm_model.prepare_inputs(
                args.n_layer,
                args.max_batch_size,
                args.max_beam_width,
                args.max_input_len,
                args.max_output_len,
                args.max_encoder_input_len,
            )

        tllm_model(*inputs)

        # Adding debug outputs into the network --------------------------
        for k, v in tllm_model.named_network_outputs():
            network._mark_output(v, k,
                                 tensorrt_llm.str_dtype_to_trt(args.dtype))
        # ----------------------------------------------------------------

    # Network -> Engine
    engine = builder.build_engine(network, builder_config)
    if rank == 0:
        config_path = args.output_dir / args.component / 'config.json'
        builder.save_config(builder_config, config_path)

    tensorrt_llm.tools.cleanup(network, tllm_model)

    return engine
def build(rank, args):
    """Build and serialize the engine(s) for ``args.component``.

    Selects CUDA device ``rank % args.gpus_per_node``, writes each engine to
    ``<output_dir>/<component>/``, and on ``rank == 0`` saves the timing
    cache. The cache path is ``args.timing_cache`` when set, otherwise
    ``<output_dir>/<component>/model.cache``. Only a single rank
    (tensor-parallel size 1) is built for now.
    """
    torch.cuda.set_device(rank % args.gpus_per_node)
    tensorrt_llm.logger.set_level(args.log_level)

    out_dir = args.output_dir / args.component
    out_dir.mkdir(parents=True, exist_ok=True)

    if args.timing_cache:
        cache_path = args.timing_cache
    else:
        cache_path = out_dir / "model.cache"
    # Starts out as the cache file path; after rank 0 is built it is replaced
    # by the builder's in-memory cache so later passes can reuse it.
    cache = cache_path

    builder = Builder()
    # Currently only support single GPU
    tp_size = 1
    for engine_rank in range(tp_size):
        config = builder.create_builder_config(
            name=MODEL_NAME,
            precision=args.dtype,
            timing_cache=cache,
            tensor_parallel=tp_size,  # TP only
            num_layers=args.n_layer,
            num_heads=args.n_head,
            hidden_size=args.hidden_size,
            vocab_size=args.vocab_size,
            hidden_act=args.hidden_act,
            max_position_embeddings=args.n_positions,
            apply_query_key_layer_scaling=False,
            max_batch_size=args.max_batch_size,
            max_input_len=args.max_input_len,
            max_output_len=args.max_output_len,
            opt_level=args.builder_opt,
            cross_attention=(args.component == 'decoder'),
            has_position_embedding=args.has_position_embedding,
            has_token_type_embedding=args.has_token_type_embedding,
            strongly_typed=args.strongly_typed)
        name = get_engine_name(MODEL_NAME, args.dtype, tp_size, engine_rank)
        engine = build_rank_engine(builder, config, name, engine_rank, args)
        assert engine is not None, f'Failed to build engine for rank {engine_rank}'

        if engine_rank == 0:
            # Use in-memory timing cache for multiple builder passes.
            cache = config.trt_builder_config.get_timing_cache()

        serialize_engine(engine, out_dir / name)

    if rank == 0:
        ok = builder.save_timing_cache(config, cache_path)
        assert ok, "Failed to save timing cache."
def run_build(component, args=None):
    """Parse options and build the engine for a single component.

    Args:
        component: Either ``'encoder'`` or ``'decoder'``; anything else fails
            the assertion before any work is done.
        args: Optional list of argument strings; ``None`` parses ``sys.argv``.
    """
    assert component in ('encoder', 'decoder'), 'Unsupported component!'

    parsed = parse_arguments(args, component)
    parsed.component = component

    seed = parsed.random_seed
    if seed is not None:
        torch.manual_seed(seed)
    logger.set_level(parsed.log_level)

    start = time.time()
    # Currently only support single GPU serial build
    logger.info('Serially build TensorRT engines.')
    build(0, parsed)
    elapsed = time.strftime('%H:%M:%S', time.gmtime(time.time() - start))
    logger.info(f'Total time of building all engines: {elapsed}')
if __name__ == '__main__':
    # Build the encoder engine first, then the decoder engine.
    for _component in ('encoder', 'decoder'):
        run_build(component=_component)