# TensorRT-LLM/examples/chatglm/build.py
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import time
from pathlib import Path
from typing import List
# isort: off
import torch
import torch.multiprocessing as mp
import tensorrt as trt
# isort: on
from visualize import to_onnx
from weight import get_scaling_factors, load_from_hf
import tensorrt_llm
from tensorrt_llm._utils import str_dtype_to_trt
from tensorrt_llm.builder import Builder
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import ChatGLMHeadModel, quantize_model
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
from tensorrt_llm.profiler import check_gpt_mem_usage
from tensorrt_llm.quantization import QuantMode
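# Build flow: parse_arguments() reads the HuggingFace config.json and derives the
# model hyperparameters, build() creates one engine per rank, and
# build_rank_engine() constructs the TensorRT-LLM network, loads the HF weights,
# and compiles the network into a serialized TensorRT engine.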
def get_engine_name(model, dtype, tp_size, pp_size, rank):
if pp_size == 1:
return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)
return '{}_{}_tp{}_pp{}_rank{}.engine'.format(model, dtype, tp_size,
pp_size, rank)
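# Example: get_engine_name("chatglm3_6b", "float16", 2, 1, 0)
# returns "chatglm3_6b_float16_tp2_rank0.engine".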
def find_engines(dir: Path,
model_name: str = "*",
dtype: str = "*",
tp_size: str = "*",
rank: str = "*") -> List[Path]:
template = f"{model_name}_{dtype}_tp{tp_size}_rank{rank}.engine"
return list(dir.glob(template))
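# Note: find_engines() globs for names produced by get_engine_name() when
# pp_size == 1; engines built with pipeline parallelism carry an extra
# "_pp{pp_size}" component and are not matched by this template.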
def serialize_engine(engine, path):
logger.info(f'Serializing engine to {path}...')
tik = time.time()
with open(path, 'wb') as f:
f.write(bytearray(engine))
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
logger.info(f'Engine serialized. Total time: {t}')
def truncate_input_output_len(
max_input_len,
max_output_len,
max_seq_length_from_config,
is_fixed_max_position_length=False,
):
max_seq_length = max_seq_length_from_config
if max_input_len >= max_seq_length_from_config:
print("Truncate max_input_len as %d" % (max_seq_length_from_config - 1))
max_input_len = max_seq_length_from_config - 1
max_output_len = 1
elif max_input_len + max_output_len > max_seq_length_from_config:
print("Truncate max_output_len as %d" %
(max_seq_length_from_config - max_input_len))
max_output_len = max_seq_length_from_config - max_input_len
elif not is_fixed_max_position_length:
max_seq_length = max_input_len + max_output_len
return max_input_len, max_output_len, max_seq_length
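# Example: with max_seq_length_from_config=2048, max_input_len=1024 and
# max_output_len=2048, the second branch applies and the function returns
# (1024, 1024, 2048), i.e. max_output_len is truncated to 2048 - 1024 = 1024
# and max_seq_length keeps the value from the config.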
def parse_arguments(args):
parser = argparse.ArgumentParser()
parser.add_argument(
'--model_name',
'-m',
type=str,
required=True,
choices=[
"chatglm_6b", "chatglm2_6b", "chatglm2_6b_32k", "chatglm3_6b",
"chatglm3_6b_base", "chatglm3_6b_32k", "glm_10b"
],
help=
'The name of the model; use "_" rather than "-" to connect the name parts.'
)
parser.add_argument(
'--world_size',
type=int,
default=1,
help='World size (only tensor parallelism is supported now)',
)
parser.add_argument('--tp_size', type=int, default=1)
parser.add_argument('--pp_size', type=int, default=1)
parser.add_argument('--model_dir', type=Path, default=None)
parser.add_argument('--quant_ckpt_path', type=str, default="awq/")
parser.add_argument(
'--dtype',
type=str,
default='float16',
choices=['float32', 'float16', 'bfloat16'],
)
parser.add_argument(
'--logits_dtype',
type=str,
default='float32',
choices=['float16', 'float32'],
)
parser.add_argument(
'--timing_cache',
type=str,
default='model.cache',
help=
'The path to read the timing cache from; it will be ignored if the file does not exist.'
)
parser.add_argument(
'--log_level',
type=str,
default='info',
choices=['verbose', 'info', 'warning', 'error', 'internal_error'],
)
parser.add_argument('--max_batch_size', type=int, default=8)
parser.add_argument('--max_input_len', type=int, default=1024)
parser.add_argument('--max_output_len', type=int, default=1024)
parser.add_argument('--max_beam_width', type=int, default=1)
parser.add_argument(
'--use_gpt_attention_plugin',
nargs='?',
const='float16',
default='float16',
choices=['float32', 'float16', 'bfloat16', False],
help=
"Activates attention plugin. You can specify the plugin dtype or leave blank to use the model dtype."
)
parser.add_argument(
'--use_gemm_plugin',
nargs='?',
const='float16',
type=str,
default='float16',
choices=['float32', 'float16', 'bfloat16', False],
help=
"Activates GEMM plugin. You can specify the plugin dtype or leave blank to use the model dtype."
)
parser.add_argument(
'--use_layernorm_plugin',
nargs='?',
const='float16',
type=str,
default='float16',
choices=['float32', 'float16', 'bfloat16', False],
help=
"Activates layernorm plugin for ChatGLM-6B / GLM-10B models. You can specify the plugin dtype or leave blank to use the model dtype."
)
parser.add_argument(
'--use_rmsnorm_plugin',
nargs='?',
const='float16',
type=str,
default='float16',
choices=['float32', 'float16', 'bfloat16', False],
help=
"Activates rmsnorm plugin for ChatGLM2-6B* / ChatGLM3-6B* models. You can specify the plugin dtype or leave blank to use the model dtype."
)
parser.add_argument('--gather_all_token_logits',
action='store_true',
default=False)
parser.add_argument('--parallel_build', default=False, action='store_true')
parser.add_argument(
'--enable_context_fmha',
default=False,
action='store_true',
)
parser.add_argument(
'--enable_context_fmha_fp32_acc',
default=False,
action='store_true',
)
parser.add_argument(
'--multi_block_mode',
default=False,
action='store_true',
help=
'Split long KV sequences into multiple blocks (applied to generation MHA kernels). \
It is beneficial when batch x num_heads cannot fully utilize the GPU.'
)
parser.add_argument('--visualize', default=False, action='store_true')
parser.add_argument(
'--enable_debug_output',
default=False,
action='store_true',
)
parser.add_argument('--gpus_per_node', type=int, default=8)
parser.add_argument('--builder_opt', type=int, default=None)
parser.add_argument(
'--output_dir',
type=Path,
default=None,
help=
'The path to save the serialized engine files, timing cache file and model configs'
)
parser.add_argument(
'--strongly_typed',
default=False,
action="store_true",
help=
'This option is introduced with TensorRT 9.1.0.1+ and significantly reduces the engine build time for FP8.'
)
parser.add_argument(
'--remove_input_padding',
default=False,
action='store_true',
)
parser.add_argument(
'--paged_kv_cache',
action="store_true",
default=False,
help=
'By default we use a contiguous KV cache. Setting this flag enables the paged KV cache.'
)
parser.add_argument(
'--use_inflight_batching',
action="store_true",
default=False,
help="Activates inflight batching mode of gptAttentionPlugin.",
)
# Arguments related to the quantization of the model.
parser.add_argument(
'--use_smooth_quant',
default=False,
action="store_true",
help=
'Use the SmoothQuant method to quantize activations and weights for the various GEMMs. '
'See --per_channel and --per_token for finer-grained quantization options.'
)
parser.add_argument(
'--use_weight_only',
default=False,
action="store_true",
help='Quantize weights for the various GEMMs to INT4/INT8. '
'See --weight_only_precision to set the precision.',
)
parser.add_argument(
'--weight_only_precision',
const='int8',
type=str,
nargs='?',
default='int8',
choices=['int8', 'int4', 'int4_awq'],
help=
'Define the precision for the weights when using weight-only quantization. '
'You must also use --use_weight_only for this argument to have an impact.',
)
parser.add_argument(
'--per_channel',
default=False,
action="store_true",
help=
'By default, we use a single static scaling factor for the GEMM\'s result. '
'per_channel instead uses a different static scaling factor for each channel. '
'The latter is usually more accurate, but a little slower.',
)
parser.add_argument(
'--per_token',
default=False,
action="store_true",
help=
'By default, we use a single static scaling factor to scale activations in the int8 range. '
'per_token chooses at run time, and for each token, a custom scaling factor. '
'The latter is usually more accurate, but a little slower.',
)
parser.add_argument(
'--per_group',
default=False,
action="store_true",
help=
'By default, we use a single static scaling factor to scale weights in the int4 range. '
'per_group chooses at run time, and for each group, a custom scaling factor. '
'The flag is built for GPTQ/AWQ quantization.',
)
parser.add_argument(
'--group_size',
type=int,
default=128,
help='Group size used in GPTQ/AWQ quantization.',
)
parser.add_argument(
'--int8_kv_cache',
default=False,
action="store_true",
help=
'By default, we use the model dtype for the KV cache; int8_kv_cache enables INT8 quantization for the KV cache.'
)
parser.add_argument(
'--random_seed',
type=int,
default=None,
help=
'Seed to use when initializing the random number generator for torch.',
)
parser.add_argument(
'--tokens_per_block',
type=int,
default=64,
help='Number of tokens per block in paged KV cache',
)
parser.add_argument(
'--enable_fp8',
default=False,
action='store_true',
help='Use FP8 Linear layer for Attention QKV/Dense and MLP.',
)
parser.add_argument(
'--fp8_kv_cache',
default=False,
action="store_true",
help=
'By default, we use the model dtype for the KV cache; fp8_kv_cache enables FP8 quantization for the KV cache.'
)
parser.add_argument(
'--max_num_tokens',
type=int,
default=None,
help='Define the max number of tokens supported by the engine',
)
parser.add_argument(
'--use_custom_all_reduce',
action='store_true',
help=
'Activates latency-optimized algorithm for all-reduce instead of NCCL.',
)
args = parser.parse_args(args)
logger.set_level(args.log_level)
plugins_args = [
'use_gpt_attention_plugin',
'use_gemm_plugin',
'use_layernorm_plugin',
'use_rmsnorm_plugin',
]
for plugin_arg in plugins_args:
if getattr(args, plugin_arg) is None:
logger.info(
f"{plugin_arg} set, without specifying a value. Using {args.dtype} automatically."
)
setattr(args, plugin_arg, args.dtype)
assert args.world_size == args.tp_size * args.pp_size # only TP is supported now
if args.model_dir is None:
args.model_dir = Path(args.model_name)
if args.output_dir is None:
args.output_dir = Path("output_" + args.model_name)
with open(args.model_dir / "config.json", "r") as f:
js = json.loads(f.read())
if args.model_name in ["chatglm_6b", "glm_10b"]:
assert args.max_input_len < js["max_sequence_length"]
if args.model_name in ["chatglm_6b"]:
args.apply_query_key_layer_scaling = False
args.apply_residual_connection_post_layernorm = False
args.ffn_hidden_size = js["inner_hidden_size"]
args.hidden_act = 'gelu'
args.hidden_size = js["hidden_size"]
args.linear_bias = True
args.max_input_len, args.max_output_len, args.max_seq_length = truncate_input_output_len(
args.max_input_len,
args.max_output_len,
js["max_sequence_length"],
)
args.multi_query_mode = False
args.norm_epsilon = js["layernorm_epsilon"]
args.num_heads = js["num_attention_heads"]
args.num_kv_heads = js["num_attention_heads"]
args.num_layers = js["num_layers"]
args.qkv_bias = True
args.rmsnorm = False
args.rotary_embedding_scaling = 1.0
args.use_cache = js["use_cache"]
args.vocab_size = js["vocab_size"]
elif args.model_name in [
"chatglm2_6b",
"chatglm2_6b_32k",
"chatglm3_6b",
"chatglm3_6b_base",
"chatglm3_6b_32k",
]:
args.apply_query_key_layer_scaling = False
args.apply_residual_connection_post_layernorm = js[
"apply_residual_connection_post_layernorm"]
args.ffn_hidden_size = js["ffn_hidden_size"]
args.hidden_act = 'swiglu'
args.hidden_size = js["hidden_size"]
args.linear_bias = js["add_bias_linear"]
args.max_input_len, args.max_output_len, args.max_seq_length = truncate_input_output_len(
args.max_input_len,
args.max_output_len,
js["seq_length"],
)
args.multi_query_mode = js["multi_query_attention"]
args.norm_epsilon = js["layernorm_epsilon"]
args.num_heads = js["num_attention_heads"]
args.num_kv_heads = js["multi_query_group_num"]
args.num_layers = js["num_layers"]
args.qkv_bias = js["add_qkv_bias"]
args.rmsnorm = js["rmsnorm"]
if args.model_name in ["chatglm2_6b_32k", "chatglm3_6b_32k"]:
args.rotary_embedding_scaling = js["rope_ratio"]
else:
args.rotary_embedding_scaling = 1.0
args.use_cache = js["use_cache"]
args.vocab_size = js["padded_vocab_size"]
elif args.model_name in ["glm_10b"]:
args.apply_query_key_layer_scaling = False
args.apply_residual_connection_post_layernorm = False
args.ffn_hidden_size = 4 * js["hidden_size"]
args.hidden_act = 'gelu'
args.hidden_size = js["hidden_size"]
args.linear_bias = True
args.max_input_len, args.max_output_len, args.max_seq_length = truncate_input_output_len(
args.max_input_len,
args.max_output_len,
js["max_sequence_length"],
True,
)
args.multi_query_mode = False
args.norm_epsilon = 1.0e-5
args.num_heads = js["num_attention_heads"]
args.num_kv_heads = js["num_attention_heads"]
args.num_layers = js["num_layers"]
args.qkv_bias = True
args.rmsnorm = False
args.rotary_embedding_scaling = 1.0
args.use_cache = True
args.vocab_size = js["vocab_size"]
if args.use_inflight_batching:
if not args.use_gpt_attention_plugin:
args.use_gpt_attention_plugin = 'float16'
logger.info(
f"Using GPT attention plugin for inflight batching mode. Setting to default '{args.use_gpt_attention_plugin}'"
)
if not args.remove_input_padding:
args.remove_input_padding = True
logger.info(
"Using remove input padding for inflight batching mode.")
if not args.paged_kv_cache:
args.paged_kv_cache = True
logger.info("Using paged KV cache for inflight batching mode.")
assert not (
args.use_smooth_quant and args.use_weight_only
), "You cannot enable both SmoothQuant and INT8 weight-only together."
if args.use_smooth_quant:
args.quant_mode = QuantMode.use_smooth_quant(args.per_token,
args.per_channel)
elif args.use_weight_only:
args.quant_mode = QuantMode.use_weight_only(
args.weight_only_precision == 'int4')
else:
args.quant_mode = QuantMode(0)
if args.int8_kv_cache:
args.quant_mode = args.quant_mode.set_int8_kv_cache()
elif args.fp8_kv_cache:
args.quant_mode = args.quant_mode.set_fp8_kv_cache()
if args.enable_fp8:
args.quant_mode = args.quant_mode.set_fp8_qdq()
if args.max_num_tokens is not None:
assert args.enable_context_fmha
logger.info(' Build Arguments '.center(100, '='))
for k, v in vars(args).items():
logger.info(f' - {k.ljust(30, ".")}: {v}')
logger.info('=' * 100)
return args
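# Note on checkpoints: load_from_hf() below reads the HuggingFace checkpoint in
# args.model_dir. The FP8 path (--enable_fp8 / --fp8_kv_cache) additionally reads
# args.quantized_fp8_model_path, for which this script defines no CLI flag, so it
# has to be set on the args namespace before build() is called.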
def build_rank_engine(
builder: Builder,
builder_config: tensorrt_llm.builder.BuilderConfig,
engine_name: str,
rank: int,
args: argparse.Namespace,
) -> trt.IHostMemory:
'''
@brief: Build the TensorRT engine for the given rank.
@param builder: The TensorRT-LLM Builder object.
@param builder_config: The builder configuration shared by all ranks.
@param engine_name: The file name of the serialized engine.
@param rank: The rank to build the engine for.
@param args: The command line arguments.
@return: The built engine.
'''
# Initialize Module
args.mapping = Mapping(
world_size=args.world_size,
rank=rank,
tp_size=args.tp_size,
)
assert args.num_layers % args.pp_size == 0, \
f"num_layers {args.n_layer} must be a multiple of pipeline "\
f"parallelism size {args.pp_size}"
trtllm_model = ChatGLMHeadModel(
apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
apply_residual_connection_post_layernorm=args.
apply_residual_connection_post_layernorm,
dtype=args.dtype,
enable_debug_output=args.enable_debug_output,
ffn_hidden_size=args.ffn_hidden_size,
hidden_act=args.hidden_act,
hidden_size=args.hidden_size,
linear_bias=args.linear_bias,
logits_dtype=args.logits_dtype,
mapping=args.mapping,
max_input_len=args.max_input_len,
max_output_len=args.max_output_len,
max_seq_length=args.max_seq_length,
model_name=args.model_name,
norm_epsilon=args.norm_epsilon,
num_heads=args.num_heads,
num_kv_heads=args.num_kv_heads,
num_layers=args.num_layers,
qkv_bias=args.qkv_bias,
quant_mode=args.quant_mode,
rmsnorm=args.rmsnorm,
rotary_embedding_scaling=args.rotary_embedding_scaling,
tokens_per_block=args.tokens_per_block,
use_cache=args.use_cache,
vocab_size=args.vocab_size,
)
if args.use_smooth_quant or args.use_weight_only:
trtllm_model = quantize_model(trtllm_model, args.quant_mode)
elif args.enable_fp8 or args.fp8_kv_cache:
logger.info(f'Loading scaling factors from '
f'{args.quantized_fp8_model_path}')
quant_scales = get_scaling_factors(args.quantized_fp8_model_path,
num_layers=args.num_layers,
quant_mode=args.quant_mode)
trtllm_model = quantize_model(trtllm_model,
quant_mode=args.quant_mode,
quant_scales=quant_scales)
trtllm_model = load_from_hf(
trtllm_model,
args.model_dir,
mapping=args.mapping,
dtype=args.dtype,
model_name=args.model_name,
)
# Module -> Network
network = builder.create_network()
network.trt_network.name = engine_name
if args.use_gpt_attention_plugin:
network.plugin_config.set_gpt_attention_plugin(
dtype=args.use_gpt_attention_plugin)
if args.use_gemm_plugin:
if not args.enable_fp8:
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
else:
logger.info(
"Gemm plugin does not support FP8. Disabled Gemm plugin.")
if args.use_rmsnorm_plugin:
network.plugin_config.set_rmsnorm_plugin(dtype=args.use_rmsnorm_plugin)
# Quantization plugins.
if args.use_smooth_quant:
network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype)
network.plugin_config.set_rmsnorm_quantization_plugin(dtype=args.dtype)
network.plugin_config.set_quantize_tensor_plugin()
network.plugin_config.set_quantize_per_token_plugin()
assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc)
if args.enable_context_fmha:
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
if args.enable_context_fmha_fp32_acc:
network.plugin_config.set_context_fmha(
ContextFMHAType.enabled_with_fp32_acc)
if args.multi_block_mode:
network.plugin_config.enable_mmha_multi_block_mode()
if args.use_weight_only:
if args.per_group:
network.plugin_config.set_weight_only_groupwise_quant_matmul_plugin(
dtype='float16')
else:
network.plugin_config.set_weight_only_quant_matmul_plugin(
dtype='float16')
if args.world_size > 1:
network.plugin_config.set_nccl_plugin(args.dtype,
args.use_custom_all_reduce)
if args.remove_input_padding:
network.plugin_config.enable_remove_input_padding()
if args.paged_kv_cache:
network.plugin_config.enable_paged_kv_cache(args.tokens_per_block)
with net_guard(network):
# Prepare
network.set_named_parameters(trtllm_model.named_parameters())
# Forward
inputs = trtllm_model.prepare_inputs(
max_batch_size=args.max_batch_size,
max_input_len=args.max_input_len,
max_new_tokens=args.max_output_len,
use_cache=True,
max_beam_width=args.max_beam_width,
)
trtllm_model(*inputs)
if args.enable_debug_output:
# mark intermediate nodes' outputs
for k, v in trtllm_model.named_network_outputs():
v = v.trt_tensor
v.name = k
network.trt_network.mark_output(v)
v.dtype = str_dtype_to_trt(args.dtype)
if args.visualize:
model_path = args.output_dir / 'test.onnx'
to_onnx(network.trt_network, model_path)
tensorrt_llm.graph_rewriting.optimize(network)
# Network -> Engine
engine = None
engine = builder.build_engine(network, builder_config)
if rank == 0:
config_path = args.output_dir / 'config.json'
builder.save_config(builder_config, config_path)
return engine
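# build() constructs one engine per rank. Without --parallel_build the ranks are
# built serially in a single process, reusing the in-memory timing cache obtained
# after rank 0; check_gpt_mem_usage() then sanity-checks the expected engine and
# KV-cache memory usage for the chosen build settings.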
def build(rank, args):
torch.cuda.set_device(rank % args.gpus_per_node)
logger.set_level(args.log_level)
args.output_dir.mkdir(parents=True, exist_ok=True)
timing_cache_file = args.output_dir / "model.cache"
timing_cache = timing_cache_file
builder = Builder()
for cur_rank in range(args.world_size):
# skip other ranks if parallel_build is enabled
if args.parallel_build and cur_rank != rank:
continue
# NOTE: when only the INT8 KV cache is used together with the paged KV cache, no INT8 tensors are exposed to TRT
int8_trt_flag = args.quant_mode.has_act_or_weight_quant() or (
not args.paged_kv_cache and args.quant_mode.has_int8_kv_cache())
builder_config = builder.create_builder_config(
precision=args.dtype,
timing_cache=timing_cache,
tensor_parallel=args.tp_size,
pipeline_parallel=args.pp_size,
int8=int8_trt_flag,
fp8=args.enable_fp8,
strongly_typed=args.strongly_typed,
opt_level=args.builder_opt,
hardware_compatibility=None,
apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
gather_all_token_logits=args.gather_all_token_logits,
hidden_act=args.hidden_act,
hidden_size=args.hidden_size,
max_batch_size=args.max_batch_size,
max_beam_width=args.max_beam_width,
max_input_len=args.max_input_len,
max_num_tokens=args.max_output_len + args.max_input_len,
max_output_len=args.max_output_len,
max_position_embeddings=args.max_seq_length,
multi_query_mode=args.multi_query_mode,
name=args.model_name,
num_heads=args.num_heads,
num_kv_heads=args.num_kv_heads,
num_layers=args.num_layers,
paged_kv_cache=args.paged_kv_cache,
parallel_build=args.parallel_build,
quant_mode=args.quant_mode,
remove_input_padding=args.remove_input_padding,
vocab_size=args.vocab_size,
)
engine_name = get_engine_name(
args.model_name,
args.dtype,
args.world_size,
args.pp_size,
cur_rank,
)
engine = build_rank_engine(
builder,
builder_config,
engine_name,
cur_rank,
args,
)
assert engine is not None, f'Failed to build engine for rank {cur_rank}'
local_num_kv_heads = (args.num_kv_heads + args.world_size -
1) // args.world_size
kv_dtype = str_dtype_to_trt(args.dtype)
if args.quant_mode.has_int8_kv_cache():
kv_dtype = str_dtype_to_trt('int8')
elif args.quant_mode.has_fp8_kv_cache():
kv_dtype = str_dtype_to_trt('fp8')
check_gpt_mem_usage(
engine=engine,
kv_dtype=kv_dtype,
use_gpt_attention_plugin=args.use_gpt_attention_plugin,
paged_kv_cache=args.paged_kv_cache,
max_batch_size=args.max_batch_size,
max_beam_width=args.max_beam_width,
max_input_len=args.max_input_len,
max_output_len=args.max_output_len,
local_num_kv_heads=local_num_kv_heads,
head_size=args.hidden_size // args.num_heads,
num_layers=args.num_layers)
if cur_rank == 0:
# Use in-memory timing cache for multiple builder passes.
if not args.parallel_build:
timing_cache = builder_config.trt_builder_config.get_timing_cache(
)
serialize_engine(engine, args.output_dir / engine_name)
del engine
if rank == 0:
ok = builder.save_timing_cache(builder_config, timing_cache_file)
assert ok, "Failed to save timing cache."
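# run_build() is the entry point: when --parallel_build is given and enough GPUs
# are visible, it spawns one build process per rank via mp.spawn(); otherwise it
# falls back to building all ranks serially in this process.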
def run_build(args=None):
args = parse_arguments(args)
if args.random_seed is not None:
torch.manual_seed(args.random_seed)
logger.set_level(args.log_level)
tik = time.time()
if args.parallel_build and args.world_size > 1 and \
torch.cuda.device_count() >= args.world_size:
logger.warning(
f'Building TensorRT engines in parallel. Please make sure that all {args.world_size} GPUs are completely free.'
)
mp.spawn(build, nprocs=args.world_size, args=(args, ))
else:
args.parallel_build = False
logger.info('Building TensorRT engines serially.')
build(0, args)
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
logger.info(f'Total time of building all {args.world_size} engines: {t}')
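# Example invocation (illustrative; the checkpoint and output paths are placeholders):
#   python3 build.py --model_name chatglm3_6b \
#                    --model_dir ./chatglm3_6b \
#                    --dtype float16 \
#                    --use_gpt_attention_plugin float16 \
#                    --use_gemm_plugin float16 \
#                    --output_dir ./output_chatglm3_6b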
if __name__ == '__main__':
run_build()