TensorRT-LLMs/tensorrt_llm/builder.py
Kaiyu Xie e06f537e08
Update TensorRT-LLM (#1019)
* Update TensorRT-LLM

---------

Co-authored-by: erenup <ping.nie@pku.edu.cn>
Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
2024-01-31 21:55:32 +08:00

457 lines
20 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import math
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Union
import tensorrt as trt
from ._common import _is_building
from ._utils import support_strongly_type, to_dict, to_json_file
from .logger import logger
from .network import Network
from .plugin import PluginConfig
from .quantization import QuantMode
class BuilderConfig(object):
def __init__(self, **kwargs):
# intentionally use **kwargs, user should never call this ctor directly,
# use Builder.create_builder_config() instead
pass
def _init(self, trt_builder_config, **kwargs):
self._trt_builder_config = trt_builder_config
for key, value in kwargs.items():
setattr(self, key, value)
return self
@property
def trt_builder_config(self) -> trt.IBuilderConfig:
return self._trt_builder_config
def to_dict(self) -> Dict:
'''return a dict with keys
{
"builder_config": {
# all key values set by the _init function
},
"plugin_config": {
# the network plugin_config (if any) attached to this BuilderConfig object
# inside the Builder.build_engine
},
"auto_parallel_config": {
# the network auto_parallel_config (if any) attached to this BuilderConfig object
# inside the Builder.build_engine
}
}
'''
config = {'builder_config': {}}
for k in self.__dict__.keys():
if k not in [
'_trt_builder_config', 'plugin_config',
'auto_parallel_config'
]:
config['builder_config'][k] = self.__getattribute__(k)
if hasattr(self, 'plugin_config'):
assert isinstance(self.plugin_config, PluginConfig), \
f"Found unexpected plugin_config object with type: {type(self.plugin_config)}"
config['plugin_config'] = to_dict(self.plugin_config)
if self.auto_parallel_config is not None:
config['auto_parallel_config'] = self.auto_parallel_config
config['builder_config']['tensor_parallel'] = math.prod(
self.auto_parallel_config['mesh'])
return config
class Builder():
_ALLOWED_PRECISIONS = ['float32', 'float16', 'bfloat16']
def __init__(self):
super().__init__()
self._trt_builder = trt.Builder(logger.trt_logger)
# TODO: Enable strongly_typed on by default in TRT 10.0
self.strongly_typed = False
@property
def trt_builder(self) -> trt.Builder:
return self._trt_builder
def create_network(self) -> Network:
explicit_batch_flag = 0
# Explicit batch flag will be deprecated in TRT 10
if "EXPLICIT_BATCH" in trt.NetworkDefinitionCreationFlag.__members__.keys(
):
explicit_batch_flag = 1 << int(
trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
if support_strongly_type() and self.strongly_typed:
return Network()._init(
self.trt_builder.create_network(
explicit_batch_flag
| (1 << int(
trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED))))
else:
return Network()._init(
self.trt_builder.create_network(explicit_batch_flag))
def create_builder_config(self,
precision: str,
timing_cache: Union[str, Path,
trt.ITimingCache] = None,
tensor_parallel: int = 1,
use_refit: bool = False,
int8: bool = False,
strongly_typed: bool = False,
opt_level: Optional[int] = None,
profiling_verbosity: str = "layer_names_only",
**kwargs) -> BuilderConfig:
''' @brief Create a builder config with given precisions and timing cache
@param precision: one of allowed precisions, defined in Builder._ALLOWED_PRECISIONS
@param timing_cache: a timing cache object or a path to a timing cache file
@param tensor_parallel: number of GPUs used for tensor parallel
@param kwargs: any other arguments users would like to attach to the config object as attributes
@param refit: set to accelerate multi-gpu building, build engine for 1 gpu and refit for the others
@param int8: whether to build with int8 enabled or not. Can't be used together with refit option
@return: A BuilderConfig object, return None if failed
'''
if strongly_typed and not support_strongly_type():
logger.warning(
"TRT version does not support strongly_type. strongly_typed flag is ignored."
)
# In TRT 10.0, enable strongly_typed by default.
self.strongly_typed = self.strongly_typed or (strongly_typed and
support_strongly_type())
quant_mode = kwargs.get("quant_mode", QuantMode(0))
if not strongly_typed and precision not in self._ALLOWED_PRECISIONS:
logger.error(
f"precision should be one of {self._ALLOWED_PRECISIONS}")
if use_refit and int8:
# TRT folds weights into Myelin graph because network contains int8 tensor or Q/DQ nodes
# These folded weights can not be refitted
logger.error(f"can't use refit and int8 mode at the same time")
config = self.trt_builder.create_builder_config()
if not self.strongly_typed:
fp8 = quant_mode.has_fp8_qdq() or quant_mode.has_fp8_kv_cache()
if precision == 'float16' or precision == trt.DataType.HALF:
config.set_flag(trt.BuilderFlag.FP16)
config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
elif precision == 'bfloat16' or precision == trt.DataType.BF16:
config.set_flag(trt.BuilderFlag.BF16)
config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
if int8:
config.set_flag(trt.BuilderFlag.INT8)
if fp8:
config.set_flag(trt.BuilderFlag.FP8)
config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
config.set_preview_feature(trt.PreviewFeature.PROFILE_SHARING_0806,
True)
if use_refit:
config.set_flag(trt.BuilderFlag.REFIT)
if opt_level is not None:
config.builder_optimization_level = opt_level
# Set TRT Engine profiling verbosity
if profiling_verbosity == "detailed":
config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
elif profiling_verbosity == "none":
config.profiling_verbosity = trt.ProfilingVerbosity.NONE
else:
config.profiling_verbosity = trt.ProfilingVerbosity.LAYER_NAMES_ONLY
# set timing cache
cache = None
if timing_cache is not None:
# use given cache
if isinstance(timing_cache, trt.ITimingCache):
cache = timing_cache
# read cache from file
elif isinstance(timing_cache,
(str, Path)) and os.path.exists(timing_cache):
with open(timing_cache, "rb") as f:
cache = config.create_timing_cache(f.read())
else:
logger.warning(
"Invalid timing cache, using freshly created one")
if cache is None:
cache = config.create_timing_cache(b"")
# When user does not given any existing cache, internally always created one
# so the cache should never None here
assert cache is not None and isinstance(cache, trt.ITimingCache)
config.set_timing_cache(cache, ignore_mismatch=False)
return BuilderConfig()._init(config,
precision=precision,
tensor_parallel=tensor_parallel,
use_refit=use_refit,
int8=int8,
strongly_typed=self.strongly_typed,
**kwargs)
def _add_optimization_profile(self, network: Network,
builder_config: BuilderConfig):
assert isinstance(builder_config, BuilderConfig)
assert isinstance(network, Network)
input_tensors = network._inputs
num_profiles = len(list(input_tensors.values())[0].profiles)
for i in range(num_profiles):
logger.debug(f'Adding optimization profile {i+1}/{num_profiles}')
profile = self.trt_builder.create_optimization_profile()
for input_name in input_tensors.keys():
shape_profile = input_tensors[input_name].profiles[i]
min_shape = [*shape_profile.min]
opt_shape = [*shape_profile.opt]
max_shape = [*shape_profile.max]
if network._auto_parallel_config is not None:
io_shards = network._auto_parallel_config["io_shards"]
if input_name in io_shards:
shards = io_shards[input_name]
for dim, shard_num in shards.items():
min_shape[dim] = int(
math.floor(min_shape[dim] / shard_num))
opt_shape[dim] = int(
round(opt_shape[dim] / shard_num))
max_shape[dim] = int(
math.ceil(max_shape[dim] / shard_num))
profile.set_shape(input_name, min_shape, opt_shape, max_shape)
logger.debug(
f'{input_name}, min: {min_shape}, opt: {opt_shape}, max: {max_shape}, dimension names: {shape_profile.dimension_names}'
)
builder_config.trt_builder_config.add_optimization_profile(profile)
assert self._validate_named_dimensions(
network, builder_config
), "Validation of the tensor dimension ranges failed, please check the dimension ranges, find the offensive tensor and dimension name in above the error log"
def _validate_named_dimensions(self, network: Network,
builder_config) -> bool:
'''
For each profile, validate that the named dimensions of different input tensors in this profile all have same range.
TRT will validate the same condition, validate it earlier to make sure the modeling in TensorRT-LLM are correct and
makes the error msg more user friendly.
'''
valid = True
for profile_idx in range(
builder_config.trt_builder_config.num_optimization_profiles):
dimension_to_range = {}
for input_name, input_tensor in network._inputs.items():
# it's legal that a Tensor does not have dim_range?
if len(input_tensor.profiles) != 0:
profile = input_tensor.profiles[profile_idx]
for dim_idx, dim_name in enumerate(profile.dimension_names):
if dim_name not in dimension_to_range:
dimension_to_range[dim_name] = []
min, opt, max = profile.min[dim_idx], profile.opt[
dim_idx], profile.max[dim_idx]
dimension_to_range[dim_name].append(
(input_name, (min, opt, max)))
for dim, ranges in dimension_to_range.items():
unique_ranges = set([r[1] for r in ranges])
logger.debug(
f"Validating dimension:{dim}, ranges for this dim are:{unique_ranges}"
)
if len(unique_ranges) != 1:
logger.error(
f"Found illegal dimension setting for profile {profile_idx}, dimension name is: {dim}"
)
logger.error(
f"Offensive tensors which have this dimension are:\n" +
"\n".join([f"{r[1]} {dim} {r[0]}" for r in ranges]))
valid = False
return valid
@_is_building
def refit_engine(self, network: Network, engine_buffer) -> trt.IHostMemory:
'''
@brief: Refit one TensorRT engine using weights from the network,
user should guarantee that the engine is built with REFIT flag, and the network has the same structure with the engine.
@param engine_buffer: A serialized TensorRT engine.
@param network: Network object.
@return: A serialized TRT engine if refit successfully, None otherwise
'''
assert isinstance(network, Network)
logger.info(f'Refit TRT engine')
runtime = trt.Runtime(logger.trt_logger)
engine = runtime.deserialize_cuda_engine(engine_buffer)
tik = time.time()
# Refit engine
refitter = trt.Refitter(engine, logger.trt_logger)
if network.named_parameters is not None:
for name, param in network.named_parameters:
if param._get_weights(
) is None or not refitter.set_named_weights(
name, param._get_weights()):
logger.error(f'Failed to refit weight: {name}')
return None
else:
logger.error(
f'Please set named parameters before building multiple engines.'
)
return None
if not refitter.refit_cuda_engine():
logger.error(f'Failed to refit engine.')
return None
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
logger.info(f'Total time of refitting {engine.name}: {t}')
serialized_engine = engine.serialize()
return serialized_engine
@_is_building
def build_engine(self, network: Network,
builder_config: BuilderConfig) -> trt.IHostMemory:
'''
@brief: Build one TensorRT engine from the network.
@param network: Network object.
@param builder_config: BuilderConfig object.
@return: A serialized TRT engine.
'''
assert isinstance(network, Network)
builder_config.plugin_config = network.plugin_config
builder_config.auto_parallel_config = network.auto_parallel_config
if builder_config.trt_builder_config.num_optimization_profiles == 0:
self._add_optimization_profile(network, builder_config)
engine = None
# Rename weights
if network.named_parameters is not None:
for name, param in network.named_parameters:
if param._get_weights(
) is None or not network.trt_network.set_weights_name(
param._get_weights(), name):
raise RuntimeError(f'Failed to set weight: {name}')
network._fill_weights()
# Build engine
logger.info(f'Build TensorRT engine {network.trt_network.name}')
tik = time.time()
engine = self.trt_builder.build_serialized_network(
network.trt_network, builder_config.trt_builder_config)
if engine is None:
logger.error('Engine building failed, please check the error log.')
return None
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
logger.info(f'Total time of building {network.trt_network.name}: {t}')
return engine
@staticmethod
def save_timing_cache(builder_config: BuilderConfig, out_path: str) -> bool:
'''Serialize timing cache of given builder config to file specified by out_path
return True if the cache is successfully serialized, False otherwise
'''
cache = builder_config.trt_builder_config.get_timing_cache()
if cache is None:
logger.warning(
'No timing cache found in the given builder config, skip saving.'
)
return False
with cache.serialize() as buffer:
with open(out_path, "wb") as f:
f.write(buffer)
f.flush()
os.fsync(f)
logger.info(f'Timing cache serialized to {out_path}')
return True
@staticmethod
def save_config(builder_config: BuilderConfig, config_path: str):
config = builder_config.to_dict()
to_json_file(config, config_path)
logger.info(f'Config saved to {config_path}.')
@dataclass
class BuildConfig:
max_input_len: int = 256
max_output_len: int = 256
max_batch_size: int = 8
max_beam_width: int = 1
max_num_tokens: Optional[int] = None
max_prompt_embedding_table_size: int = 0
gather_context_logits: int = False
gather_generation_logits: int = False
strongly_typed: bool = False
builder_opt: Optional[int] = None
profiling_verbosity: str = "layer_names_only"
plugin_config: PluginConfig = PluginConfig()
@classmethod
def from_dict(cls, config, plugin_config=None):
max_input_len = config.pop('max_input_len')
max_output_len = config.pop('max_output_len')
max_batch_size = config.pop('max_batch_size')
max_beam_width = config.pop('max_beam_width')
max_num_tokens = config.pop('max_num_tokens')
max_prompt_embedding_table_size = config.pop(
'max_prompt_embedding_table_size', 0)
gather_context_logits = config.pop('gather_context_logits', False)
gather_generation_logits = config.pop('gather_generation_logits', False)
strongly_typed = config.pop('strongly_typed', False)
builder_opt = config.pop('builder_opt', None)
profiling_verbosity = config.pop('profiling_verbosity',
"layer_names_only")
if plugin_config is None:
plugin_config = PluginConfig()
if "plugin_config" in config.keys():
plugin_config.update_from_dict(config["plugin_config"])
return cls(
max_input_len=max_input_len,
max_output_len=max_output_len,
max_batch_size=max_batch_size,
max_beam_width=max_beam_width,
max_num_tokens=max_num_tokens,
max_prompt_embedding_table_size=max_prompt_embedding_table_size,
gather_context_logits=gather_context_logits,
gather_generation_logits=gather_generation_logits,
strongly_typed=strongly_typed,
builder_opt=builder_opt,
profiling_verbosity=profiling_verbosity,
plugin_config=plugin_config)
@classmethod
def from_json_file(cls, config_file, plugin_config=None):
with open(config_file) as f:
config = json.load(f)
return BuildConfig.from_dict(config, plugin_config=plugin_config)
def to_dict(self):
output = copy.deepcopy(self.__dict__)
plugin_config = output.pop('plugin_config')
plugin_config_dict = copy.deepcopy(plugin_config.__dict__)
output['plugin_config'] = plugin_config_dict
return output