TensorRT-LLMs/examples/gpt/build.py
2023-10-10 23:22:17 -07:00

618 lines
24 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import time
from pathlib import Path
from typing import List
import torch
import torch.multiprocessing as mp
import tensorrt_llm
from tensorrt_llm._utils import str_dtype_to_trt
from tensorrt_llm.builder import Builder
from tensorrt_llm.layers import PositionEmbeddingType
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import smooth_quantize, weight_only_quantize
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
from tensorrt_llm.quantization import QuantMode
from weight import load_from_ft, parse_ft_config, check_embedding_share # isort:skip
MODEL_NAME = "gpt"
def get_engine_name(model, dtype, tp_size, rank):
return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)
def find_engines(dir: Path,
model_name: str = "*",
dtype: str = "*",
tp_size: str = "*",
rank: str = "*") -> List[Path]:
template = f"{model_name}_{dtype}_tp{tp_size}_rank{rank}.engine"
return list(dir.glob(template))
def serialize_engine(engine, path):
logger.info(f'Serializing engine to {path}...')
tik = time.time()
with open(path, 'wb') as f:
f.write(bytearray(engine))
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
logger.info(f'Engine serialized. Total time: {t}')
def parse_arguments(args):
parser = argparse.ArgumentParser()
parser.add_argument('--world_size',
type=int,
default=1,
help='world size, only support tensor parallelism now')
parser.add_argument('--model_dir', type=str, default=None)
parser.add_argument('--dtype',
type=str,
default='float16',
choices=['float16', 'float32', 'bfloat16'])
parser.add_argument('--logits_dtype',
type=str,
default='float32',
choices=['float16', 'float32'])
parser.add_argument(
'--timing_cache',
type=str,
default='model.cache',
help=
'The path of to read timing cache from, will be ignored if the file does not exist'
)
parser.add_argument('--log_level', type=str, default='info')
parser.add_argument('--vocab_size', type=int, default=51200)
parser.add_argument('--n_layer', type=int, default=24)
parser.add_argument('--n_positions', type=int, default=1024)
parser.add_argument('--n_embd', type=int, default=1024)
parser.add_argument('--n_head', type=int, default=16)
parser.add_argument('--hidden_act', type=str, default='gelu')
parser.add_argument(
'--rotary_pct',
type=float,
default=0.0,
help="Setting this to a value > 0.0 (and <= 1.0) activates RoPE.")
parser.add_argument('--inter_size', type=int, default=None)
parser.add_argument('--no_bias', action="store_false")
parser.add_argument('--max_batch_size', type=int, default=256)
parser.add_argument('--max_input_len', type=int, default=200)
parser.add_argument('--max_output_len', type=int, default=200)
parser.add_argument('--max_beam_width', type=int, default=1)
parser.add_argument(
'--use_gpt_attention_plugin',
nargs='?',
const=None,
type=str,
default=False,
choices=['float16', 'float32', 'bfloat16'],
help=
"Activates attention plugin. You can specify the plugin dtype or leave blank to use the model dtype."
)
parser.add_argument(
'--use_gemm_plugin',
nargs='?',
const=None,
type=str,
default=False,
choices=['float16', 'float32', 'bfloat16'],
help=
"Activates GEMM plugin. You can specify the plugin dtype or leave blank to use the model dtype."
)
parser.add_argument(
'--use_layernorm_plugin',
nargs='?',
const=None,
type=str,
default=False,
choices=['float16', 'float32', 'bfloat16'],
help=
"Activates layernorm plugin. You can specify the plugin dtype or leave blank to use the model dtype."
)
parser.add_argument('--parallel_build', default=False, action='store_true')
parser.add_argument('--enable_context_fmha',
default=False,
action='store_true')
parser.add_argument('--enable_context_fmha_fp32_acc',
default=False,
action='store_true')
parser.add_argument('--gpus_per_node', type=int, default=8)
parser.add_argument('--builder_opt', type=int, default=None)
parser.add_argument(
'--output_dir',
type=Path,
default='gpt_outputs',
help=
'The path to save the serialized engine files, timing cache file and model configs'
)
parser.add_argument(
"--multi_query_mode",
"-mq",
default=False,
action='store_true',
help=
"Whether this model uses multi-query attention mechanism (default: False)"
)
parser.add_argument('--remove_input_padding',
default=False,
action='store_true')
# Arguments related to the quantization of the model.
parser.add_argument(
'--use_smooth_quant',
default=False,
action="store_true",
help=
'Use the SmoothQuant method to quantize activations and weights for the various GEMMs.'
'See --per_channel and --per_token for finer-grained quantization options.'
)
parser.add_argument(
'--use_weight_only',
default=False,
action="store_true",
help='Quantize weights for the various GEMMs to INT4/INT8.'
'See --weight_only_precision to set the precision')
parser.add_argument(
'--weight_only_precision',
const='int8',
type=str,
nargs='?',
default='int8',
choices=['int8', 'int4'],
help=
'Define the precision for the weights when using weight-only quantization.'
'You must also use --use_weight_only for that argument to have an impact.'
)
parser.add_argument(
'--per_channel',
default=False,
action="store_true",
help=
'By default, we use a single static scaling factor for the GEMM\'s result. '
'per_channel instead uses a different static scaling factor for each channel. '
'The latter is usually more accurate, but a little slower.')
parser.add_argument(
'--per_token',
default=False,
action="store_true",
help=
'By default, we use a single static scaling factor to scale activations in the int8 range. '
'per_token chooses at run time, and for each token, a custom scaling factor. '
'The latter is usually more accurate, but a little slower.')
parser.add_argument(
'--int8_kv_cache',
default=False,
action="store_true",
help=
'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
)
parser.add_argument(
'--random_seed',
type=int,
default=None,
help=
'Seed to use when initializing the random number generator for torch.')
parser.add_argument(
'--paged_kv_cache',
action="store_true",
default=False,
help=
'By default we use contiguous KV cache. By setting this flag you enable paged KV cache'
)
parser.add_argument('--tokens_per_block',
type=int,
default=64,
help='Number of tokens per block in paged KV cache')
parser.add_argument(
'--max_prompt_embedding_table_size',
type=int,
default=0,
help='Setting to a value > 0 enables support for prompt tuning.')
parser.add_argument(
'--use_inflight_batching',
action="store_true",
default=False,
help="Activates inflight batching mode of gptAttentionPlugin.")
parser.add_argument(
'--use_parallel_embedding',
action="store_true",
default=False,
help=
'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
)
parser.add_argument(
'--embedding_sharding_dim',
type=int,
default=0,
choices=[0, 1],
help=
'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
'To shard it along hidden dimension, set embedding_sharding_dim=1'
'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
)
parser.add_argument(
'--use_embedding_sharing',
action="store_true",
default=False,
help=
'Try to reduce the engine size by sharing the embedding lookup table between two layers.'
'Note: the flag might not take effect when the criteria are not met.')
parser.add_argument(
'--use_lookup_plugin',
nargs='?',
const=None,
default=False,
choices=['float16', 'float32', 'bfloat16'],
help="Activates the lookup plugin which enables embedding sharing.")
parser.add_argument('--gather_all_token_logits',
action='store_true',
default=False)
parser.add_argument('--enable_fp8', default=False, action='store_true')
parser.add_argument(
'--fp8_kv_cache',
default=False,
action="store_true",
help=
'By default, we use dtype for KV cache. fp8_kv_cache chooses fp8 quantization for KV'
)
parser.add_argument(
'--strongly_typed',
default=False,
action="store_true",
help=
'This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.'
)
parser.add_argument(
'--use_custom_all_reduce',
action='store_true',
help=
'Activates latency-optimized algorithm for all-reduce instead of NCCL.')
args = parser.parse_args(args)
logger.set_level(args.log_level)
if not args.remove_input_padding:
if args.use_gpt_attention_plugin:
logger.warning(
f"It is recommended to specify --remove_input_padding when using GPT attention plugin"
)
args.bias = not args.no_bias
if args.inter_size is None:
args.inter_size = 4 * args.n_embd
if args.model_dir is not None:
logger.info(f"Setting model configuration from {args.model_dir}.")
n_embd, n_head, n_layer, n_positions, vocab_size, _, hidden_act, rotary_pct, bias, inter_size, multi_query_mode, dtype, prompt_num_tasks, prompt_max_vocab_size = parse_ft_config(
Path(args.model_dir) / "config.ini")
args.n_embd = n_embd
args.n_head = n_head
args.n_layer = n_layer
args.n_positions = n_positions
args.vocab_size = vocab_size
args.hidden_act = hidden_act
args.rotary_pct = rotary_pct
args.bias = bias
args.dtype = dtype
args.inter_size = inter_size
args.multi_query_mode = multi_query_mode
plugins_args = [
'use_gpt_attention_plugin', 'use_gemm_plugin', 'use_layernorm_plugin',
'use_lookup_plugin'
]
for plugin_arg in plugins_args:
if getattr(args, plugin_arg) is None:
logger.info(
f"{plugin_arg} set, without specifying a value. Using {args.dtype} automatically."
)
setattr(args, plugin_arg, args.dtype)
assert not (
args.use_smooth_quant and args.use_weight_only
), "You cannot enable both SmoothQuant and INT8 weight-only together."
if args.use_inflight_batching:
if not args.use_gpt_attention_plugin:
args.use_gpt_attention_plugin = 'float16'
logger.info(
f"Using GPT attention plugin for inflight batching mode. Setting to default '{args.use_gpt_attention_plugin}'"
)
if not args.remove_input_padding:
args.remove_input_padding = True
logger.info(
"Using remove input padding for inflight batching mode.")
if not args.paged_kv_cache:
args.paged_kv_cache = True
logger.info("Using paged KV cache for inflight batching mode.")
if args.use_smooth_quant:
args.quant_mode = QuantMode.use_smooth_quant(args.per_token,
args.per_channel)
elif args.use_weight_only:
args.quant_mode = QuantMode.use_weight_only(
args.weight_only_precision == 'int4')
else:
args.quant_mode = QuantMode(0)
if args.int8_kv_cache:
args.quant_mode = args.quant_mode.set_int8_kv_cache()
if args.fp8_kv_cache:
assert (
args.use_gpt_attention_plugin or args.use_inflight_batching
), "You have to use GPT attention plugin when fp8 KV cache is set"
args.quant_mode = args.quant_mode.set_fp8_kv_cache()
if args.enable_fp8:
args.quant_mode = args.quant_mode.set_fp8_qdq()
return args
def build_rank_engine(builder: Builder,
builder_config: tensorrt_llm.builder.BuilderConfig,
engine_name, rank, args):
'''
@brief: Build the engine on the given rank.
@param rank: The rank to build the engine.
@param args: The cmd line arguments.
@return: The built engine.
'''
kv_dtype = str_dtype_to_trt(args.dtype)
# Share_embedding_table can be set True only when:
# 1) the weight for lm_head() does not exist while other weights exist
# 2) For multiple-processes, use_parallel_embedding=True and embedding_sharding_dim == 0.
# Besides, for TensorRT 9.0, we can observe the engine size reduction when the lookup and gemm plugin are enabled.
share_embedding_table = False
if args.use_embedding_sharing:
if args.world_size > 1:
if args.model_dir is not None and args.embedding_sharding_dim == 0 and args.use_parallel_embedding:
share_embedding_table = check_embedding_share(args.model_dir)
else:
if args.model_dir is not None:
share_embedding_table = check_embedding_share(args.model_dir)
if not share_embedding_table:
logger.warning(f'Cannot share the embedding lookup table.')
if share_embedding_table:
logger.info(
'Engine will share embedding and language modeling weights.')
# Initialize Module
tensorrt_llm_gpt = tensorrt_llm.models.GPTLMHeadModel(
num_layers=args.n_layer,
num_heads=args.n_head,
hidden_size=args.n_embd,
inter_size=args.inter_size,
vocab_size=args.vocab_size,
hidden_act=args.hidden_act,
max_position_embeddings=args.n_positions,
position_embedding_type=PositionEmbeddingType.learned_absolute
if args.rotary_pct == 0.0 else PositionEmbeddingType.rope_gpt_neox,
rotary_embedding_percentage=args.rotary_pct,
dtype=kv_dtype,
logits_dtype=args.logits_dtype,
mapping=Mapping(world_size=args.world_size,
rank=rank,
tp_size=args.world_size), # TP only
apply_query_key_layer_scaling=builder_config.
apply_query_key_layer_scaling,
quant_mode=args.quant_mode,
bias=args.bias,
multi_query_mode=args.multi_query_mode,
use_prompt_tuning=args.max_prompt_embedding_table_size > 0,
use_parallel_embedding=args.use_parallel_embedding,
embedding_sharding_dim=args.embedding_sharding_dim,
share_embedding_table=share_embedding_table)
if args.use_smooth_quant:
tensorrt_llm_gpt = smooth_quantize(tensorrt_llm_gpt, args.quant_mode)
elif args.use_weight_only:
tensorrt_llm_gpt = weight_only_quantize(tensorrt_llm_gpt,
args.quant_mode)
if args.model_dir is not None:
gpt_dummy_fp8_scaling_factors = {
'fc_act': [0.5 for _ in range(args.n_layer)],
'fc_weights': [0.5 for _ in range(args.n_layer)],
'proj_act': [0.5 for _ in range(args.n_layer)],
'proj_weights': [0.5 for _ in range(args.n_layer)],
'qkv_act': [0.5 for _ in range(args.n_layer)],
'qkv_weights': [0.5 for _ in range(args.n_layer)],
'qkv_output': [0.5 for _ in range(args.n_layer)],
'dense_act': [0.5 for _ in range(args.n_layer)],
'dense_weights': [0.5 for _ in range(args.n_layer)],
}
load_from_ft(tensorrt_llm_gpt,
args.model_dir,
rank,
args.world_size,
args.dtype,
args.use_parallel_embedding,
args.embedding_sharding_dim,
share_embedding_table,
scaling_factors=gpt_dummy_fp8_scaling_factors
if args.enable_fp8 else None)
# Module -> Network
network = builder.create_network()
network.trt_network.name = engine_name
if args.use_gpt_attention_plugin:
network.plugin_config.set_gpt_attention_plugin(
dtype=args.use_gpt_attention_plugin)
if args.use_gemm_plugin:
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
if args.use_layernorm_plugin:
network.plugin_config.set_layernorm_plugin(
dtype=args.use_layernorm_plugin)
assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc)
if args.enable_context_fmha:
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
if args.enable_context_fmha_fp32_acc:
network.plugin_config.set_context_fmha(
ContextFMHAType.enabled_with_fp32_acc)
if args.remove_input_padding:
network.plugin_config.enable_remove_input_padding()
if args.paged_kv_cache:
network.plugin_config.enable_paged_kv_cache(args.tokens_per_block)
# Quantization plugins.
if args.use_smooth_quant:
network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype)
network.plugin_config.set_layernorm_quantization_plugin(
dtype=args.dtype)
# FIXME(nkorobov)
# See https://nvbugs/4164762
# See https://nvbugs/4174113
network.plugin_config.set_quantize_tensor_plugin()
network.plugin_config.set_quantize_per_token_plugin()
elif args.use_weight_only:
network.plugin_config.set_weight_only_quant_matmul_plugin(
dtype=args.dtype)
if args.world_size > 1:
network.plugin_config.set_nccl_plugin(args.dtype,
args.use_custom_all_reduce)
if args.use_lookup_plugin:
# Use the plugin for the embedding parallelism and sharing
network.plugin_config.set_lookup_plugin(dtype=args.dtype)
with net_guard(network):
# Prepare
network.set_named_parameters(tensorrt_llm_gpt.named_parameters())
# Forward
inputs = tensorrt_llm_gpt.prepare_inputs(
args.max_batch_size,
args.max_input_len,
args.max_output_len,
True,
args.max_beam_width,
prompt_embedding_table_size=args.max_prompt_embedding_table_size,
gather_all_token_logits=args.gather_all_token_logits)
tensorrt_llm_gpt(*inputs)
tensorrt_llm.graph_rewriting.optimize(network)
engine = None
# Network -> Engine
engine = builder.build_engine(network, builder_config)
if rank == 0:
config_path = args.output_dir / 'config.json'
builder.save_config(builder_config, config_path)
return engine
def build(rank, args):
torch.cuda.set_device(rank % args.gpus_per_node)
tensorrt_llm.logger.set_level(args.log_level)
args.output_dir.mkdir(parents=True, exist_ok=True)
timing_cache_file = args.timing_cache if args.timing_cache else args.output_dir / "model.cache"
timing_cache = timing_cache_file
builder = Builder()
apply_query_key_layer_scaling = False
for cur_rank in range(args.world_size):
# skip other ranks if parallel_build is enabled
if args.parallel_build and cur_rank != rank:
continue
# NOTE(nkorobov): when only int8 kv cache is used together with paged kv cache no int8 tensors are exposed to TRT
int8_trt_flag = args.quant_mode.has_act_and_weight_quant() or (
args.paged_kv_cache == False
and args.quant_mode.has_int8_kv_cache())
builder_config = builder.create_builder_config(
name=MODEL_NAME,
precision=args.dtype,
timing_cache=timing_cache,
tensor_parallel=args.world_size, # TP only
parallel_build=args.parallel_build,
num_layers=args.n_layer,
num_heads=args.n_head,
num_kv_heads=1 if args.multi_query_mode else args.n_head,
hidden_size=args.n_embd,
vocab_size=args.vocab_size,
hidden_act=args.hidden_act,
max_position_embeddings=args.n_positions,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
max_batch_size=args.max_batch_size,
max_input_len=args.max_input_len,
max_output_len=args.max_output_len,
int8=int8_trt_flag,
opt_level=args.builder_opt,
multi_query_mode=args.multi_query_mode,
strongly_typed=args.strongly_typed,
use_prompt_tuning=args.max_prompt_embedding_table_size > 0,
gather_all_token_logits=args.gather_all_token_logits,
fp8=args.enable_fp8,
use_parallel_embedding=args.use_parallel_embedding)
engine_name = get_engine_name(MODEL_NAME, args.dtype, args.world_size,
cur_rank)
engine = build_rank_engine(builder, builder_config, engine_name,
cur_rank, args)
assert engine is not None, f'Failed to build engine for rank {cur_rank}'
if cur_rank == 0:
# Use in-memory timing cache for multiple builder passes.
if not args.parallel_build:
timing_cache = builder_config.trt_builder_config.get_timing_cache(
)
serialize_engine(engine, args.output_dir / engine_name)
if rank == 0:
ok = builder.save_timing_cache(builder_config, timing_cache_file)
assert ok, "Failed to save timing cache."
def run_build(args=None):
args = parse_arguments(args)
if args.random_seed is not None:
torch.manual_seed(args.random_seed)
logger.set_level(args.log_level)
tik = time.time()
if args.parallel_build and args.world_size > 1 and \
torch.cuda.device_count() >= args.world_size:
logger.warning(
f'Parallelly build TensorRT engines. Please make sure that all of the {args.world_size} GPUs are totally free.'
)
mp.spawn(build, nprocs=args.world_size, args=(args, ))
else:
args.parallel_build = False
logger.info('Serially build TensorRT engines.')
build(0, args)
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
logger.info(f'Total time of building all {args.world_size} engines: {t}')
if __name__ == '__main__':
run_build()