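"""Build script for the TensorRT-LLM encoder-decoder (enc_dec) example.

Builds the encoder and decoder TensorRT engines for a T5-style model,
optionally loading weights from a PyTorch checkpoint directory (--model_dir).
"""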
import argparse
import time
from pathlib import Path

import torch

import tensorrt_llm
from tensorrt_llm._utils import str_dtype_to_trt
from tensorrt_llm.builder import Builder
from tensorrt_llm.logger import logger
from tensorrt_llm.network import net_guard

from weight import load_t5_from_pytorch, parse_config  # isort:skip

MODEL_NAME = "enc_dec"


def get_engine_name(model, dtype, tp_size, rank):
    return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)


def serialize_engine(engine, path):
    logger.info(f'Serializing engine to {path}...')
    tik = time.time()
    with open(path, 'wb') as f:
        f.write(bytearray(engine))
    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Engine serialized. Total time: {t}')


def parse_arguments(args, component):
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default=None)
    parser.add_argument('--dtype',
                        type=str,
                        default='float16',
                        choices=['float16', 'float32', 'bfloat16'])
    parser.add_argument('--logits_dtype',
                        type=str,
                        default='float32',
                        choices=['float16', 'float32'])
    parser.add_argument(
        '--timing_cache',
        type=str,
        default='model.cache',
        help=
        'The path to read the timing cache from; ignored if the file does not exist'
    )
    parser.add_argument('--log_level', type=str, default='info')
    parser.add_argument('--vocab_size', type=int, default=32128)
    parser.add_argument('--n_layer', type=int, default=6)
    parser.add_argument('--n_positions', type=int, default=1024)
    parser.add_argument('--n_embd', type=int, default=1024)
    parser.add_argument('--n_head', type=int, default=8)
    parser.add_argument('--hidden_act', type=str, default='gelu')
    parser.add_argument('--inter_size', type=int, default=None)
    # Passing --no_bias disables bias terms (args.bias = not args.no_bias below).
    parser.add_argument('--no_bias', action="store_true")
    parser.add_argument('--max_batch_size', type=int, default=8)
    parser.add_argument('--max_encoder_input_len', type=int, default=1024)
    parser.add_argument('--max_input_len', type=int, default=200)
    parser.add_argument('--max_output_len', type=int, default=200)
    parser.add_argument('--max_beam_width', type=int, default=1)
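    # The plugin flags below take an optional dtype value: passing the flag with
    # no value stores None (nargs='?', const=None), and parse_arguments later
    # replaces None with the model dtype (see the loop over plugins_args).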
    parser.add_argument(
        '--use_bert_attention_plugin',
        nargs='?',
        const=None,
        type=str,
        default=False,
        choices=['float16', 'float32', 'bfloat16'],
        help=
        "Activates BERT attention plugin. You can specify the plugin dtype or leave blank to use the model dtype."
    )
    parser.add_argument(
        '--use_gpt_attention_plugin',
        nargs='?',
        const=None,
        type=str,
        default=False,
        choices=['float16', 'float32', 'bfloat16'],
        help=
        "Activates GPT attention plugin. You can specify the plugin dtype or leave blank to use the model dtype."
    )
    parser.add_argument(
        '--use_gemm_plugin',
        nargs='?',
        const=None,
        type=str,
        default=False,
        choices=['float16', 'float32', 'bfloat16'],
        help=
        "Activates GEMM plugin. You can specify the plugin dtype or leave blank to use the model dtype."
    )
    parser.add_argument(
        '--use_layernorm_plugin',
        nargs='?',
        const=None,
        type=str,
        default=False,
        choices=['float16', 'float32', 'bfloat16'],
        help=
        "Activates layernorm plugin. You can specify the plugin dtype or leave blank to use the model dtype."
    )
    parser.add_argument('--enable_qk_half_accum',
                        default=False,
                        action='store_true')
    parser.add_argument('--gpus_per_node', type=int, default=8)
    parser.add_argument('--builder_opt', type=int, default=None)
    parser.add_argument(
        '--output_dir',
        type=Path,
        default='trt_engines',
        help=
        'The path to save the serialized engine files, timing cache file and model configs'
    )
    parser.add_argument('--remove_input_padding',
                        default=False,
                        action='store_true')
    parser.add_argument(
        '--random_seed',
        type=int,
        default=None,
        help=
        'Seed to use when initializing the random number generator for torch.')
    parser.add_argument(
        '--use_lookup_plugin',
        nargs='?',
        const=None,
        default=False,
        choices=['float16', 'float32', 'bfloat16'],
        help="Activates the lookup plugin which enables embedding sharing.")
    parser.add_argument(
        '--strongly_typed',
        default=False,
        action="store_true",
        help=
        'Introduced with TensorRT 9.1.0.1+; significantly reduces engine build time for fp8.'
    )

    args = parser.parse_args(args)
    logger.set_level(args.log_level)

    args.bias = not args.no_bias
    if args.inter_size is None:
        args.inter_size = 4 * args.n_embd

    if args.model_dir is not None:
        logger.info(f"Setting model configuration from {args.model_dir}.")
        args = parse_config(
            Path(args.model_dir) / "config.ini", component, args)
    plugins_args = [
        'use_bert_attention_plugin', 'use_gpt_attention_plugin',
        'use_gemm_plugin', 'use_layernorm_plugin', 'use_lookup_plugin'
    ]
    for plugin_arg in plugins_args:
        if getattr(args, plugin_arg) is None:
            logger.info(
                f"{plugin_arg} set without a value; using {args.dtype} (the model dtype)."
            )
            setattr(args, plugin_arg, args.dtype)

    return args


def build_rank_engine(builder: Builder,
                      builder_config: tensorrt_llm.builder.BuilderConfig,
                      engine_name, rank, args):
    '''
    @brief: Build the engine on the given rank.
    @param rank: The rank to build the engine.
    @param args: The cmd line arguments.
    @return: The built engine.
    '''
    kv_dtype = str_dtype_to_trt(args.dtype)

    # Initialize Module
    if args.component == 'encoder':
        tllm_model = tensorrt_llm.models.EncoderModel(
            num_layers=args.n_layer,
            num_heads=args.n_head,
            hidden_size=args.hidden_size,
            ffn_hidden_size=args.ffn_hidden_size,
            vocab_size=args.vocab_size,
            max_position_embeddings=args.n_positions,
            has_position_embedding=args.has_position_embedding,
            relative_attention=args.relative_attention,
            max_distance=args.max_distance,
            num_buckets=args.num_buckets,
            has_embedding_layernorm=args.has_embedding_layernorm,
            has_embedding_scale=args.has_embedding_scale,
            q_scaling=args.q_scaling,
            has_attention_qkvo_bias=args.has_attention_qkvo_bias,
            has_mlp_bias=args.has_mlp_bias,
            has_model_final_layernorm=args.has_model_final_layernorm,
            layernorm_eps=args.layernorm_eps,
            layernorm_position=args.layernorm_position,
            layernorm_type=args.layernorm_type,
            hidden_act=args.hidden_act,
            dtype=kv_dtype)
    elif args.component == 'decoder':
        tllm_model = tensorrt_llm.models.DecoderModel(
            num_layers=args.n_layer,
            num_heads=args.n_head,
            hidden_size=args.hidden_size,
            ffn_hidden_size=args.ffn_hidden_size,
            encoder_hidden_size=args.encoder_hidden_size,
            encoder_num_heads=args.encoder_num_heads,
            vocab_size=args.vocab_size,
            max_position_embeddings=args.n_positions,
            has_position_embedding=args.has_position_embedding,
            relative_attention=args.relative_attention,
            max_distance=args.max_distance,
            num_buckets=args.num_buckets,
            has_embedding_layernorm=args.has_embedding_layernorm,
            has_embedding_scale=args.has_embedding_scale,
            q_scaling=args.q_scaling,
            has_attention_qkvo_bias=args.has_attention_qkvo_bias,
            has_mlp_bias=args.has_mlp_bias,
            has_model_final_layernorm=args.has_model_final_layernorm,
            layernorm_eps=args.layernorm_eps,
            layernorm_position=args.layernorm_position,
            layernorm_type=args.layernorm_type,
            hidden_act=args.hidden_act,
            dtype=kv_dtype,
            logits_dtype=args.logits_dtype)

    # No support for relative attention bias in plain TRT mode.
    # (To add such support, extend Attention and BertAttention in
    # tensorrt_llm/layers/attention.py.)
    if args.relative_attention:
        assert args.use_bert_attention_plugin, "Relative attention bias is only supported when using BertAttention Plugin"
        assert args.use_gpt_attention_plugin, "Relative attention bias is only supported when using GPTAttention Plugin"

    if args.model_dir is not None:
        load_t5_from_pytorch(tllm_model,
                             args.model_dir,
                             args.component,
                             dtype=args.dtype)

    # Module -> Network
    network = builder.create_network()
    network.trt_network.name = engine_name
    if args.use_bert_attention_plugin:
        network.plugin_config.set_bert_attention_plugin(
            dtype=args.use_bert_attention_plugin)
    if args.use_gpt_attention_plugin:
        network.plugin_config.set_gpt_attention_plugin(
            dtype=args.use_gpt_attention_plugin)
    if args.use_gemm_plugin:
        network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
    if args.use_layernorm_plugin:
        network.plugin_config.set_layernorm_plugin(
            dtype=args.use_layernorm_plugin)
    if args.enable_qk_half_accum:
        network.plugin_config.enable_qk_half_accum()
    if args.remove_input_padding:
        network.plugin_config.enable_remove_input_padding()

    if args.use_lookup_plugin:
        # Use the plugin for the embedding parallelism and sharing
        network.plugin_config.set_lookup_plugin(dtype=args.dtype)

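    # Inside net_guard, tracing the module builds the TensorRT network:
    # prepare_inputs() declares the input tensors and tllm_model(*inputs)
    # populates the graph with the model's layers.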
    with net_guard(network):
        # Prepare
        network.set_named_parameters(tllm_model.named_parameters())

        # Forward
        if args.component == 'encoder':
            inputs = tllm_model.prepare_inputs(
                args.max_batch_size,
                args.max_input_len,
            )
        elif args.component == 'decoder':
            inputs = tllm_model.prepare_inputs(
                args.n_layer,
                args.max_batch_size,
                args.max_beam_width,
                args.max_input_len,
                args.max_output_len,
                args.max_encoder_input_len,
            )

        tllm_model(*inputs)

        # Adding debug outputs into the network --------------------------
        for k, v in tllm_model.named_network_outputs():
            network._mark_output(v, k,
                                 tensorrt_llm.str_dtype_to_trt(args.dtype))
        # ----------------------------------------------------------------

    # Network -> Engine
    engine = builder.build_engine(network, builder_config)
    if rank == 0:
        config_path = args.output_dir / args.component / 'config.json'
        builder.save_config(builder_config, config_path)

    tensorrt_llm.tools.cleanup(network, tllm_model)

    return engine


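# build() creates one engine per rank (currently world_size == 1), reusing the
# in-memory timing cache between builder passes and saving it to disk at the end.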
def build(rank, args):
    torch.cuda.set_device(rank % args.gpus_per_node)
    tensorrt_llm.logger.set_level(args.log_level)
    component_dir = args.output_dir / args.component
    component_dir.mkdir(parents=True, exist_ok=True)
    timing_cache_file = args.timing_cache if args.timing_cache else component_dir / "model.cache"
    timing_cache = timing_cache_file

    builder = Builder()
    apply_query_key_layer_scaling = False

    # Currently only a single GPU is supported
    world_size = 1
    for cur_rank in range(world_size):
        builder_config = builder.create_builder_config(
            name=MODEL_NAME,
            precision=args.dtype,
            timing_cache=timing_cache,
            tensor_parallel=world_size,  # TP only
            num_layers=args.n_layer,
            num_heads=args.n_head,
            hidden_size=args.hidden_size,
            vocab_size=args.vocab_size,
            hidden_act=args.hidden_act,
            max_position_embeddings=args.n_positions,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            max_batch_size=args.max_batch_size,
            max_input_len=args.max_input_len,
            max_output_len=args.max_output_len,
            opt_level=args.builder_opt,
            cross_attention=(args.component == 'decoder'),
            has_position_embedding=args.has_position_embedding,
            has_token_type_embedding=args.has_token_type_embedding,
            strongly_typed=args.strongly_typed)

        engine_name = get_engine_name(MODEL_NAME, args.dtype, world_size,
                                      cur_rank)
        engine = build_rank_engine(builder, builder_config, engine_name,
                                   cur_rank, args)
        assert engine is not None, f'Failed to build engine for rank {cur_rank}'

        if cur_rank == 0:
            # Use in-memory timing cache for multiple builder passes.
            timing_cache = builder_config.trt_builder_config.get_timing_cache()

        serialize_engine(engine, component_dir / engine_name)

    if rank == 0:
        ok = builder.save_timing_cache(builder_config, timing_cache_file)
        assert ok, "Failed to save timing cache."


def run_build(component, args=None):
    assert component == 'encoder' or component == 'decoder', 'Unsupported component!'
    args = parse_arguments(args, component)
    args.component = component

    if args.random_seed is not None:
        torch.manual_seed(args.random_seed)

    logger.set_level(args.log_level)
    tik = time.time()

    # Currently only a single-GPU, serial build is supported
    logger.info('Serially building TensorRT engines.')
    build(0, args)

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Total time of building all engines: {t}')


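# Example invocation (assuming this file is saved as build.py; paths are placeholders):
#   python build.py --model_dir <pytorch_model_dir> --dtype float16 \
#       --use_bert_attention_plugin --use_gpt_attention_plugin --use_gemm_plugin
# Running as a script builds the encoder engine first, then the decoder engine.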
if __name__ == '__main__':
    run_build(component='encoder')
    run_build(component='decoder')