# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from functools import wraps
from pathlib import Path
from typing import Optional, Union

import tensorrt as trt
from packaging import version

from ._utils import to_dict, to_json_file, trt_version
from .logger import logger
from .network import Network


class _BuildingFlag:

    def __enter__(self):
        os.environ['IS_BUILDING'] = '1'

    def __exit__(self, type, value, tb):
        del os.environ['IS_BUILDING']


def _is_building(f):
    '''Use this to decorate functions that are called during the engine building/refitting
    process; otherwise, the plugin registration will fail.
    '''

    @wraps(f)
    def decorated(*args, **kwargs):
        with _BuildingFlag():
            return f(*args, **kwargs)

    return decorated
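

# A minimal usage sketch (not part of the original module; the function name below is
# hypothetical): any helper that calls into the TensorRT builder or refitter can be
# wrapped so that plugin registration sees IS_BUILDING=1 while it runs.
#
#   @_is_building
#   def my_build_step(trt_builder, trt_network):
#       ...  # TensorRT calls made here run inside the _BuildingFlag context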


class BuilderConfig(object):

    def __init__(self, **kwargs):
        # intentionally use **kwargs; users should never call this ctor directly,
        # use Builder.create_builder_config() instead
        pass

    def _init(self, trt_builder_config, **kwargs):
        self._trt_builder_config = trt_builder_config
        for key, value in kwargs.items():
            setattr(self, key, value)
        return self

    @property
    def trt_builder_config(self) -> trt.IBuilderConfig:
        return self._trt_builder_config


class Builder():

    _ALLOWED_PRECISIONS = ['float32', 'float16', 'bfloat16']

    def __init__(self):
        super().__init__()
        self._trt_builder = trt.Builder(logger.trt_logger)
        self.strongly_typed = False

    @property
    def trt_builder(self) -> trt.Builder:
        return self._trt_builder

    def create_network(self) -> Network:
        if version.parse(trt_version()) >= version.parse(
                "9.1.0") and self.strongly_typed:
            return Network()._init(
                self.trt_builder.create_network(
                    (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
                    | (1 << int(
                        trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED))))
        else:
            return Network()._init(
                self.trt_builder.create_network(
                    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)))
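
    # Usage sketch (an assumption about call order, not from the original source): the
    # strongly_typed attribute is normally set by create_builder_config() before
    # create_network() is called, e.g.
    #
    #   builder = Builder()
    #   builder_config = builder.create_builder_config(precision='float16',
    #                                                   strongly_typed=True)
    #   network = builder.create_network()  # STRONGLY_TYPED flag added on TRT >= 9.1
    #
    # The bit arithmetic above composes the TensorRT creation-flags bitmask: each
    # NetworkDefinitionCreationFlag contributes one bit, so flags are OR-ed together.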

    def create_builder_config(self,
                              precision: str,
                              timing_cache: Union[str, Path,
                                                  trt.ITimingCache] = None,
                              tensor_parallel: int = 1,
                              use_refit: bool = False,
                              int8: bool = False,
                              fp8: bool = False,
                              strongly_typed: bool = False,
                              opt_level: Optional[int] = None,
                              **kwargs) -> BuilderConfig:
        ''' @brief Create a builder config with the given precisions and timing cache
            @param precision: one of the allowed precisions, defined in Builder._ALLOWED_PRECISIONS
            @param timing_cache: a timing cache object or a path to a timing cache file
            @param tensor_parallel: number of GPUs used for tensor parallelism
            @param use_refit: set to accelerate multi-GPU building; build the engine for one GPU
                and refit it for the others. Can't be used together with the int8 option.
            @param int8: whether to build with int8 enabled or not. Can't be used together with
                the refit option.
            @param fp8: whether to build with fp8 enabled or not
            @param strongly_typed: whether to create the network with strong typing enabled
            @param opt_level: the TensorRT builder optimization level, if given
            @param kwargs: any other arguments users would like to attach to the config object as attributes
            @return: A BuilderConfig object, or None on failure
        '''
        self.strongly_typed = strongly_typed

        if not strongly_typed and precision not in self._ALLOWED_PRECISIONS:
            logger.error(
                f"precision should be one of {self._ALLOWED_PRECISIONS}")

        if use_refit and int8:
            # TRT folds weights into the Myelin graph when the network contains int8 tensors
            # or Q/DQ nodes; these folded weights cannot be refitted.
            logger.error("can't use refit and int8 mode at the same time")

        config = self.trt_builder.create_builder_config()
        if not strongly_typed:
            if precision == 'float16':
                config.set_flag(trt.BuilderFlag.FP16)
                config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
            elif precision == 'bfloat16':
                config.set_flag(trt.BuilderFlag.BF16)
                config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
            if int8:
                config.set_flag(trt.BuilderFlag.INT8)

        if fp8:
            config.set_flag(trt.BuilderFlag.FP8)
            config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)

        config.set_preview_feature(trt.PreviewFeature.PROFILE_SHARING_0806,
                                   True)

        if use_refit:
            config.set_flag(trt.BuilderFlag.REFIT)

        if opt_level is not None:
            config.builder_optimization_level = opt_level

        # set timing cache
        cache = None
        if timing_cache is not None:
            # use the given cache object
            if isinstance(timing_cache, trt.ITimingCache):
                cache = timing_cache
            # read the cache from file
            elif isinstance(timing_cache,
                            (str, Path)) and os.path.exists(timing_cache):
                with open(timing_cache, "rb") as f:
                    cache = config.create_timing_cache(f.read())
            else:
                logger.warning(
                    "Invalid timing cache, using freshly created one")
        if cache is None:
            cache = config.create_timing_cache(b"")
        # When the user does not provide an existing cache, one is always created internally,
        # so the cache should never be None here.
        assert cache is not None and isinstance(cache, trt.ITimingCache)
        config.set_timing_cache(cache, ignore_mismatch=False)

        return BuilderConfig()._init(config,
                                     precision=precision,
                                     tensor_parallel=tensor_parallel,
                                     use_refit=use_refit,
                                     int8=int8,
                                     fp8=fp8,
                                     **kwargs)
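
    # A hedged usage sketch (file names and values are illustrative, not from the original
    # source):
    #
    #   builder = Builder()
    #   builder_config = builder.create_builder_config(
    #       precision='float16',
    #       timing_cache='model.cache',  # reused if the file exists, else a fresh cache is made
    #       tensor_parallel=2,
    #       use_refit=True,
    #       opt_level=3)
    #
    # Any extra keyword arguments are attached to the returned BuilderConfig as attributes
    # and are later serialized by Builder.save_config().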

    def _add_optimization_profile(self, network: Network,
                                  builder_config: BuilderConfig):
        assert isinstance(builder_config, BuilderConfig)
        assert isinstance(network, Network)
        input_tensors = network._inputs
        num_profiles = len(list(input_tensors.items())[0][1].profiles)
        for i in range(num_profiles):
            logger.debug(f'Adding optimization profile {i+1}/{num_profiles}')
            profile = self.trt_builder.create_optimization_profile()
            for input_name in input_tensors.keys():
                shape_profile = input_tensors[input_name].profiles[i]
                profile.set_shape(input_name, shape_profile.min,
                                  shape_profile.opt, shape_profile.max)
                logger.debug(
                    f'{input_name}, min: {shape_profile.min}, opt: {shape_profile.opt}, max: {shape_profile.max}, dimension names: {shape_profile.dimension_names}'
                )
            builder_config.trt_builder_config.add_optimization_profile(profile)
        assert self._validate_named_dimensions(
            network, builder_config
        ), "Validation of the tensor dimension ranges failed, please check the dimension ranges and find the offending tensor and dimension name in the error log above"
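
    # Illustrative sketch (an assumption about the expected shape metadata, with made-up
    # values): every input tensor is expected to carry the same number of shape profiles,
    # each exposing min/opt/max shapes and dimension names. Profile i of every input is
    # gathered into TRT optimization profile i via IOptimizationProfile.set_shape(), so
    # the first profile of all inputs forms one TRT profile, the second the next, and so on.
    #
    #   input_ids.profiles[0]: min=(1, 1), opt=(8, 512), max=(64, 2048)
    #   input_ids.profiles[1]: min=(1, 1), opt=(16, 1024), max=(128, 2048)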

    def _validate_named_dimensions(self, network: Network,
                                   builder_config) -> bool:
        '''
        For each profile, validate that the named dimensions of different input tensors in this profile all have the same range.
        TRT will validate the same condition; validating it earlier makes sure the modeling in TensorRT-LLM is correct and
        makes the error message more user friendly.
        '''
        valid = True
        for profile_idx in range(
                builder_config.trt_builder_config.num_optimization_profiles):
            dimension_to_range = {}
            for input_name, input_tensor in network._inputs.items():
                # it is legal for a Tensor to have no dim_range
                if len(input_tensor.profiles) != 0:
                    profile = input_tensor.profiles[profile_idx]
                    for dim_idx, dim_name in enumerate(profile.dimension_names):
                        if dim_name not in dimension_to_range:
                            dimension_to_range[dim_name] = []
                        min, opt, max = profile.min[dim_idx], profile.opt[
                            dim_idx], profile.max[dim_idx]
                        dimension_to_range[dim_name].append(
                            (input_name, (min, opt, max)))
            for dim, ranges in dimension_to_range.items():
                unique_ranges = set([r[1] for r in ranges])
                logger.debug(
                    f"Validating dimension:{dim}, ranges for this dim are:{unique_ranges}"
                )
                if len(unique_ranges) != 1:
                    logger.error(
                        f"Found illegal dimension setting for profile {profile_idx}, dimension name is: {dim}"
                    )
                    logger.error(
                        "Offending tensors which have this dimension are:\n" +
                        "\n".join([f"{r[1]} {dim} {r[0]}" for r in ranges]))
                    valid = False
        return valid
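
    # Example of the rule being checked (tensor and dimension names plus values are
    # illustrative): if two inputs name a dimension "batch_size" in the same profile,
    # their (min, opt, max) ranges must match.
    #
    #   input_ids    batch_size: (1, 8, 64)
    #   position_ids batch_size: (1, 8, 64)   -> valid
    #   position_ids batch_size: (1, 4, 32)   -> rejected, ranges differ for the same name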

    @_is_building
    def refit_engine(self, network: Network, engine_buffer) -> trt.IHostMemory:
        '''
        @brief: Refit one TensorRT engine using weights from the network.
            The user should guarantee that the engine was built with the REFIT flag and that
            the network has the same structure as the engine.
        @param engine_buffer: A serialized TensorRT engine.
        @param network: Network object.
        @return: A serialized TRT engine if refitted successfully, None otherwise
        '''
        assert isinstance(network, Network)
        logger.info('Refit TRT engine')
        runtime = trt.Runtime(logger.trt_logger)
        engine = runtime.deserialize_cuda_engine(engine_buffer)

        tik = time.time()

        # Refit engine
        refitter = trt.Refitter(engine, logger.trt_logger)
        if network.named_parameters is not None:
            for name, param in network.named_parameters:
                if param._get_weights(
                ) is None or not refitter.set_named_weights(
                        name, param._get_weights()):
                    logger.error(f'Failed to refit weight: {name}')
                    return None
        else:
            logger.error(
                'Please set named parameters before building multiple engines.'
            )
            return None

        if not refitter.refit_cuda_engine():
            logger.error('Failed to refit engine.')
            return None

        tok = time.time()
        t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
        logger.info(f'Total time of refitting {engine.name}: {t}')
        serialized_engine = engine.serialize()
        return serialized_engine
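
    # Hedged usage sketch (variable names are illustrative): refitting assumes an engine
    # built from a config created with use_refit=True and a network whose named parameters
    # already hold the replacement weights.
    #
    #   engine_rank0 = builder.build_engine(network, builder_config)  # built with REFIT
    #   ...  # load the next rank's weights into the network's named parameters
    #   engine_rank1 = builder.refit_engine(network, engine_rank0)    # serialized engine or None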

    @_is_building
    def build_engine(self, network: Network,
                     builder_config: BuilderConfig) -> trt.IHostMemory:
        '''
        @brief: Build one TensorRT engine from the network.
        @param network: Network object.
        @param builder_config: BuilderConfig object.
        @return: A serialized TRT engine.
        '''
        assert isinstance(network, Network)
        builder_config.plugin_config = network.plugin_config
        self._add_optimization_profile(network, builder_config)
        engine = None
        logger.info(f'Build TensorRT engine {network.trt_network.name}')
        tik = time.time()

        # Rename weights
        if network.named_parameters is not None:
            for name, param in network.named_parameters:
                if param._get_weights(
                ) is None or not network.trt_network.set_weights_name(
                        param._get_weights(), name):
                    raise RuntimeError(f'Failed to set weight: {name}')

        # Build engine
        engine = self.trt_builder.build_serialized_network(
            network.trt_network, builder_config.trt_builder_config)
        if engine is None:
            logger.error('Engine building failed, please check the error log.')
            return None

        tok = time.time()
        t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
        logger.info(f'Total time of building {network.trt_network.name}: {t}')

        return engine
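
    # End-to-end sketch under the assumptions above (file names are illustrative, and the
    # network-population step is elided):
    #
    #   builder = Builder()
    #   builder_config = builder.create_builder_config(precision='float16',
    #                                                   timing_cache='model.cache')
    #   network = builder.create_network()
    #   ...  # populate the network: inputs with shape profiles, named parameters
    #   engine = builder.build_engine(network, builder_config)
    #   if engine is not None:
    #       with open('model.engine', 'wb') as f:
    #           f.write(engine)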

    @staticmethod
    def save_timing_cache(builder_config: BuilderConfig, out_path: str) -> bool:
        '''Serialize the timing cache of the given builder config to the file specified by out_path.
        Return True if the cache is successfully serialized, False otherwise.
        '''
        cache = builder_config.trt_builder_config.get_timing_cache()
        if cache is None:
            logger.warning(
                'No timing cache found in the given builder config, skip saving.'
            )
            return False
        with cache.serialize() as buffer:
            with open(out_path, "wb") as f:
                f.write(buffer)
                f.flush()
                os.fsync(f)
        logger.info(f'Timing cache serialized to {out_path}')
        return True
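
    # Usage sketch (illustrative file name): persist the cache after a build, then feed the
    # same path back into create_builder_config() on the next build to skip kernel re-timing.
    #
    #   Builder.save_timing_cache(builder_config, 'model.cache')
    #   ...
    #   builder_config = builder.create_builder_config(precision='float16',
    #                                                   timing_cache='model.cache')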

    @staticmethod
    def save_config(builder_config: BuilderConfig, config_path: str):
        config = {'builder_config': {}}
        for k in builder_config.__dict__.keys():
            if k != '_trt_builder_config' and k != 'plugin_config':
                config['builder_config'][k] = builder_config.__getattribute__(k)
        config['plugin_config'] = to_dict(builder_config.plugin_config)
        to_json_file(config, config_path)
        logger.info(f'Config saved to {config_path}.')
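
    # Usage sketch (illustrative path): everything attached to the BuilderConfig (precision,
    # tensor_parallel, any extra kwargs) plus the plugin config is dumped as JSON.
    #
    #   Builder.save_config(builder_config, 'config.json')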