# TensorRT-LLMs/tensorrt_llm/builder.py
# Kaiyu Xie d37b507f41
# Update TensorRT-LLM main branch (#754)
# * Update TensorRT-LLM
#
# ---------
#
# Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
# 2023-12-27 17:41:24 +08:00
#
# 654 lines
# 26 KiB
# Python

# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import math
import os
import time
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
from typing import Dict, Optional, Union

import tensorrt as trt
from packaging import version

from ._utils import to_dict, to_json_file, trt_version
from .graph_rewriting import optimize
from .logger import logger
from .models import MODEL_MAP, PretrainedConfig, PretrainedModel
from .network import Network, net_guard
from .plugin import PluginConfig
from .plugin.plugin import ContextFMHAType
from .quantization import QuantMode
from .version import __version__
class _BuildingFlag:
def __enter__(self):
os.environ['IS_BUILDING'] = '1'
def __exit__(self, type, value, tb):
del os.environ['IS_BUILDING']
def _is_building(f):
'''Use this to decorate functions which are called during engine building/refiting process,
otherwise, the plugin registration will fail.
'''
@wraps(f)
def decorated(*args, **kwargs):
with _BuildingFlag():
return f(*args, **kwargs)
return decorated
class BuilderConfig(object):
    '''Bundle of a ``trt.IBuilderConfig`` and arbitrary build metadata.

    Instances are produced by ``Builder.create_builder_config()``, which calls
    ``_init`` to attach the TensorRT config and store every extra keyword as a
    plain attribute.
    '''

    def __init__(self, **kwargs):
        # kwargs are accepted only so the signature is permissive; they are
        # ignored. Users should never call this ctor directly -- use
        # Builder.create_builder_config() instead.
        pass

    def _init(self, trt_builder_config, **kwargs):
        '''Store the TRT builder config and set each keyword as an attribute.

        @return: self, to allow ``BuilderConfig()._init(...)`` chaining.
        '''
        self._trt_builder_config = trt_builder_config
        for name in kwargs:
            setattr(self, name, kwargs[name])
        return self

    @property
    def trt_builder_config(self) -> trt.IBuilderConfig:
        return self._trt_builder_config

    def to_dict(self) -> Dict:
        '''Serialize this object into a plain dict of the form
        {
            "builder_config": {
                # every instance attribute except the TRT config and plugin_config
            },
            "plugin_config": {
                # present only if a plugin_config attribute was attached
                # (Builder.build_engine attaches the network's plugin_config)
            }
        }
        Raises AssertionError if plugin_config is not a PluginConfig.
        '''
        excluded = ('_trt_builder_config', 'plugin_config')
        result = {
            'builder_config': {
                name: getattr(self, name)
                for name in self.__dict__ if name not in excluded
            }
        }
        if hasattr(self, 'plugin_config'):
            plugin_config = self.plugin_config
            assert isinstance(plugin_config, PluginConfig), \
                f"Found unexpected plugin_config object with type: {type(plugin_config)}"
            result['plugin_config'] = to_dict(plugin_config)
        return result
class Builder():
    '''Create TensorRT-LLM networks, builder configs and serialized engines on
    top of a single ``trt.Builder``.
    '''

    _ALLOWED_PRECISIONS = ['float32', 'float16', 'bfloat16']

    def __init__(self):
        super().__init__()
        self._trt_builder = trt.Builder(logger.trt_logger)
        # Overwritten by create_builder_config() and read by create_network(),
        # so create_builder_config(strongly_typed=True) must run first for the
        # network to be created strongly typed.
        self.strongly_typed = False

    @property
    def trt_builder(self) -> trt.Builder:
        return self._trt_builder

    def create_network(self) -> Network:
        '''Create an empty Network wrapping a new TensorRT network definition.

        EXPLICIT_BATCH is requested only if the installed TensorRT still
        defines that flag; STRONGLY_TYPED is added on TensorRT >= 9.1.0 when
        ``self.strongly_typed`` is True.
        '''
        explicit_batch_flag = 0
        if "EXPLICIT_BATCH" in trt.NetworkDefinitionCreationFlag.__members__.keys(
        ):
            # Explicit batch flag will be deprecated in TRT 10
            explicit_batch_flag = 1 << int(
                trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        if version.parse(trt_version()) >= version.parse(
                "9.1.0") and self.strongly_typed:
            return Network()._init(
                self.trt_builder.create_network(
                    explicit_batch_flag
                    | (1 << int(
                        trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED))))
        else:
            return Network()._init(
                self.trt_builder.create_network(explicit_batch_flag))

    def create_builder_config(self,
                              precision: str,
                              timing_cache: Union[str, Path,
                                                  trt.ITimingCache] = None,
                              tensor_parallel: int = 1,
                              use_refit: bool = False,
                              int8: bool = False,
                              strongly_typed: bool = False,
                              opt_level: Optional[int] = None,
                              **kwargs) -> BuilderConfig:
        ''' @brief Create a builder config with given precisions and timing cache
            @param precision: one of allowed precisions, defined in Builder._ALLOWED_PRECISIONS;
                only used to pick FP16/BF16 flags when strongly_typed is False
            @param timing_cache: a timing cache object or a path to a timing cache file;
                a fresh cache is created when None or when the path does not exist
            @param tensor_parallel: number of GPUs used for tensor parallel
            @param use_refit: set to accelerate multi-gpu building, build engine for 1 gpu and refit for the others
            @param int8: whether to build with int8 enabled or not. Can't be used together with use_refit
            @param strongly_typed: build a strongly typed network (skips the precision flags);
                also stored on this Builder so that a later create_network() picks it up
            @param opt_level: optional TensorRT builder optimization level
            @param kwargs: any other arguments users would like to attach to the config object as attributes;
                kwargs['quant_mode'] (default QuantMode(0)) decides whether the FP8 flag is set
            @return: A BuilderConfig object. An unsupported precision, or use_refit combined with int8,
                is only reported via logger.error; a config is still returned.
        '''
        self.strongly_typed = strongly_typed
        quant_mode = kwargs.get("quant_mode", QuantMode(0))
        if not strongly_typed and precision not in self._ALLOWED_PRECISIONS:
            logger.error(
                f"precision should be one of {self._ALLOWED_PRECISIONS}")

        if use_refit and int8:
            # TRT folds weights into Myelin graph because network contains int8 tensor or Q/DQ nodes
            # These folded weights can not be refitted
            logger.error("can't use refit and int8 mode at the same time")

        config = self.trt_builder.create_builder_config()
        if not strongly_typed:
            fp8 = quant_mode.has_fp8_qdq() or quant_mode.has_fp8_kv_cache()
            if precision == 'float16' or precision == trt.DataType.HALF:
                config.set_flag(trt.BuilderFlag.FP16)
                config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
            elif precision == 'bfloat16' or precision == trt.DataType.BF16:
                config.set_flag(trt.BuilderFlag.BF16)
                config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
            if int8:
                config.set_flag(trt.BuilderFlag.INT8)
            if fp8:
                config.set_flag(trt.BuilderFlag.FP8)
                config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)

        config.set_preview_feature(trt.PreviewFeature.PROFILE_SHARING_0806,
                                   True)

        if use_refit:
            config.set_flag(trt.BuilderFlag.REFIT)

        if opt_level is not None:
            config.builder_optimization_level = opt_level

        # set timing cache
        cache = None
        if timing_cache is not None:
            # use given cache
            if isinstance(timing_cache, trt.ITimingCache):
                cache = timing_cache
            # read cache from file
            elif isinstance(timing_cache,
                            (str, Path)) and os.path.exists(timing_cache):
                with open(timing_cache, "rb") as f:
                    cache = config.create_timing_cache(f.read())
            else:
                logger.warning(
                    "Invalid timing cache, using freshly created one")
        if cache is None:
            cache = config.create_timing_cache(b"")
        # When the user gives no usable cache, one is always created above,
        # so the cache should never be None here.
        assert cache is not None and isinstance(cache, trt.ITimingCache)
        config.set_timing_cache(cache, ignore_mismatch=False)

        return BuilderConfig()._init(config,
                                     precision=precision,
                                     tensor_parallel=tensor_parallel,
                                     use_refit=use_refit,
                                     int8=int8,
                                     **kwargs)

    def _add_optimization_profile(self, network: Network,
                                  builder_config: BuilderConfig):
        '''Add one TRT optimization profile per profile of the network inputs.

        The profile count is taken from the first input tensor. When
        ``network._autopp_config`` is set, dimensions listed in its
        ``io_shards`` are divided by their shard count (floor for min, round
        for opt, ceil for max). Afterwards named dimensions are validated.
        '''
        assert isinstance(builder_config, BuilderConfig)
        assert isinstance(network, Network)
        input_tensors = network._inputs
        num_profiles = len(list(input_tensors.values())[0].profiles)
        for i in range(num_profiles):
            logger.debug(f'Adding optimization profile {i+1}/{num_profiles}')
            profile = self.trt_builder.create_optimization_profile()
            for input_name in input_tensors.keys():
                shape_profile = input_tensors[input_name].profiles[i]
                min_shape = [*shape_profile.min]
                opt_shape = [*shape_profile.opt]
                max_shape = [*shape_profile.max]
                if network._autopp_config is not None:
                    io_shards = network._autopp_config["io_shards"]
                    if input_name in io_shards:
                        shards = io_shards[input_name]
                        for dim, shard_num in shards.items():
                            min_shape[dim] = int(
                                math.floor(min_shape[dim] / shard_num))
                            opt_shape[dim] = int(
                                round(opt_shape[dim] / shard_num))
                            max_shape[dim] = int(
                                math.ceil(max_shape[dim] / shard_num))
                profile.set_shape(input_name, min_shape, opt_shape, max_shape)
                logger.debug(
                    f'{input_name}, min: {min_shape}, opt: {opt_shape}, max: {max_shape}, dimension names: {shape_profile.dimension_names}'
                )
            builder_config.trt_builder_config.add_optimization_profile(profile)
        assert self._validate_named_dimensions(
            network, builder_config
        ), "Validation of the tensor dimension ranges failed, please check the dimension ranges, find the offensive tensor and dimension name in above the error log"

    def _validate_named_dimensions(self, network: Network,
                                   builder_config) -> bool:
        '''
        For each profile, validate that the named dimensions of different input tensors in this profile all have same range.
        TRT will validate the same condition, validate it earlier to make sure the modeling in TensorRT-LLM are correct and
        makes the error msg more user friendly.

        @return: True if every named dimension has exactly one (min, opt, max)
            range within each profile; otherwise False, after logging every
            offending dimension and the tensors that use it.
        '''
        valid = True
        for profile_idx in range(
                builder_config.trt_builder_config.num_optimization_profiles):
            # dimension name -> [(input tensor name, (min, opt, max)), ...]
            dimension_to_range = {}
            for input_name, input_tensor in network._inputs.items():
                # it's legal that a Tensor does not have dim_range?
                if len(input_tensor.profiles) == 0:
                    continue
                profile = input_tensor.profiles[profile_idx]
                for dim_idx, dim_name in enumerate(profile.dimension_names):
                    # lo/hi instead of min/max so the builtins are not shadowed
                    lo, opt, hi = (profile.min[dim_idx], profile.opt[dim_idx],
                                   profile.max[dim_idx])
                    dimension_to_range.setdefault(dim_name, []).append(
                        (input_name, (lo, opt, hi)))
            for dim, ranges in dimension_to_range.items():
                unique_ranges = {r[1] for r in ranges}
                logger.debug(
                    f"Validating dimension:{dim}, ranges for this dim are:{unique_ranges}"
                )
                if len(unique_ranges) != 1:
                    logger.error(
                        f"Found illegal dimension setting for profile {profile_idx}, dimension name is: {dim}"
                    )
                    logger.error(
                        "Offensive tensors which have this dimension are:\n" +
                        "\n".join([f"{r[1]} {dim} {r[0]}" for r in ranges]))
                    valid = False
        return valid

    @_is_building
    def refit_engine(self, network: Network, engine_buffer) -> trt.IHostMemory:
        '''
        @brief: Refit one TensorRT engine using weights from the network,
            user should guarantee that the engine is built with REFIT flag, and the network has the same structure with the engine.
        @param engine_buffer: A serialized TensorRT engine.
        @param network: Network object.
        @return: A serialized TRT engine if refit successfully, None otherwise
            (the failure is logged: a missing/rejected weight, no named
            parameters on the network, or refit_cuda_engine() failing)
        '''
        assert isinstance(network, Network)
        logger.info('Refit TRT engine')
        runtime = trt.Runtime(logger.trt_logger)
        engine = runtime.deserialize_cuda_engine(engine_buffer)

        tik = time.time()

        # Refit engine
        refitter = trt.Refitter(engine, logger.trt_logger)
        if network.named_parameters is not None:
            for name, param in network.named_parameters:
                if param._get_weights(
                ) is None or not refitter.set_named_weights(
                        name, param._get_weights()):
                    logger.error(f'Failed to refit weight: {name}')
                    return None
        else:
            logger.error(
                'Please set named parameters before building multiple engines.'
            )
            return None

        if not refitter.refit_cuda_engine():
            logger.error('Failed to refit engine.')
            return None

        tok = time.time()
        t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
        logger.info(f'Total time of refitting {engine.name}: {t}')
        serialized_engine = engine.serialize()
        return serialized_engine

    @_is_building
    def build_engine(self, network: Network,
                     builder_config: BuilderConfig) -> trt.IHostMemory:
        '''
        @brief: Build one TensorRT engine from the network.
        @param network: Network object.
        @param builder_config: BuilderConfig object; its plugin_config and
            autopp_config attributes are overwritten from the network, and
            optimization profiles are added if it has none yet.
        @return: A serialized TRT engine, or None if TensorRT fails to build.
        @raises RuntimeError: if a named parameter's weights cannot be named
            in the TRT network.
        '''
        assert isinstance(network, Network)
        builder_config.plugin_config = network.plugin_config
        builder_config.autopp_config = network.autopp_config
        if builder_config.trt_builder_config.num_optimization_profiles == 0:
            self._add_optimization_profile(network, builder_config)
        logger.info(f'Build TensorRT engine {network.trt_network.name}')
        tik = time.time()

        # Rename weights
        if network.named_parameters is not None:
            for name, param in network.named_parameters:
                if param._get_weights(
                ) is None or not network.trt_network.set_weights_name(
                        param._get_weights(), name):
                    raise RuntimeError(f'Failed to set weight: {name}')

        # Build engine
        network._fill_weights()
        engine = self.trt_builder.build_serialized_network(
            network.trt_network, builder_config.trt_builder_config)
        if engine is None:
            logger.error('Engine building failed, please check the error log.')
            return None

        tok = time.time()
        t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
        logger.info(f'Total time of building {network.trt_network.name}: {t}')

        return engine

    @staticmethod
    def save_timing_cache(builder_config: BuilderConfig, out_path: str) -> bool:
        '''Serialize timing cache of given builder config to file specified by out_path
            return True if the cache is successfully serialized, False otherwise
        '''
        cache = builder_config.trt_builder_config.get_timing_cache()
        if cache is None:
            logger.warning(
                'No timing cache found in the given builder config, skip saving.'
            )
            return False
        with cache.serialize() as buffer:
            with open(out_path, "wb") as f:
                f.write(buffer)
                f.flush()
                # Force the bytes to disk before reporting success.
                os.fsync(f)
        logger.info(f'Timing cache serialized to {out_path}')
        return True

    @staticmethod
    def save_config(builder_config: BuilderConfig, config_path: str):
        '''Write ``builder_config.to_dict()`` as JSON to ``config_path``.'''
        config = builder_config.to_dict()
        to_json_file(config, config_path)
        logger.info(f'Config saved to {config_path}.')
@dataclass
class BuildConfig:
    '''Engine build limits plus the PluginConfig to build with.'''
    max_input_len: int = 256
    max_output_len: int = 256
    max_batch_size: int = 8
    max_beam_width: int = 1
    max_num_tokens: Optional[int] = None
    max_prompt_embedding_table_size: int = 0
    gather_all_token_logits: bool = False
    # default_factory gives every instance its own PluginConfig. A plain
    # `= PluginConfig()` default was a single object shared by all instances,
    # and build_shard_model mutates it in place, leaking plugin settings from
    # one build into the next.
    plugin_config: PluginConfig = field(default_factory=PluginConfig)

    @classmethod
    def from_dict(cls, config):
        '''Create a BuildConfig from a dict, e.g. the 'build_config' section
        of an engine config.json.

        The caller's dict is not modified. 'max_input_len', 'max_output_len',
        'max_batch_size' and 'max_beam_width' are required (KeyError
        otherwise); 'max_num_tokens' defaults to None. An optional nested
        'plugin_config' dict is translated into PluginConfig setter calls;
        keys that are not recognized are ignored.
        '''
        # Work on a deep copy: the keys below are consumed with pop(),
        # including those of the nested plugin_config dict.
        config = copy.deepcopy(config)
        max_input_len = config.pop('max_input_len')
        max_output_len = config.pop('max_output_len')
        max_batch_size = config.pop('max_batch_size')
        max_beam_width = config.pop('max_beam_width')
        max_num_tokens = config.pop('max_num_tokens', None)
        max_prompt_embedding_table_size = config.pop(
            'max_prompt_embedding_table_size', 0)
        gather_all_token_logits = config.pop('gather_all_token_logits', False)

        plugin_config = PluginConfig()
        if 'plugin_config' in config:
            plugin_dict = config['plugin_config']

            gpt_attention_plugin = plugin_dict.pop('gpt_attention_plugin',
                                                   False)
            if gpt_attention_plugin:
                plugin_config.set_gpt_attention_plugin(
                    dtype=gpt_attention_plugin)

            gemm_plugin = plugin_dict.pop('gemm_plugin', False)
            if gemm_plugin:
                plugin_config.set_gemm_plugin(dtype=gemm_plugin)

            lookup_plugin = plugin_dict.pop('lookup_plugin', False)
            if lookup_plugin:
                plugin_config.set_lookup_plugin(dtype=lookup_plugin)

            enable_context_fmha = plugin_dict.pop('enable_context_fmha', False)
            enable_context_fmha_fp32_acc = plugin_dict.pop(
                'enable_context_fmha_fp32_acc', False)
            # The two FMHA modes are mutually exclusive.
            assert not (enable_context_fmha and enable_context_fmha_fp32_acc)
            if enable_context_fmha:
                plugin_config.set_context_fmha(ContextFMHAType.enabled)
            if enable_context_fmha_fp32_acc:
                plugin_config.set_context_fmha(
                    ContextFMHAType.enabled_with_fp32_acc)

            remove_input_padding = plugin_dict.pop('remove_input_padding',
                                                   False)
            if remove_input_padding:
                plugin_config.enable_remove_input_padding()

            paged_kv_cache = plugin_dict.pop('paged_kv_cache', False)
            tokens_per_block = plugin_dict.pop('tokens_per_block', 64)
            if paged_kv_cache:
                plugin_config.enable_paged_kv_cache(tokens_per_block)

            use_custom_all_reduce = plugin_dict.pop('use_custom_all_reduce',
                                                    False)
            plugin_config.use_custom_all_reduce = use_custom_all_reduce

        return cls(
            max_input_len=max_input_len,
            max_output_len=max_output_len,
            max_batch_size=max_batch_size,
            max_beam_width=max_beam_width,
            max_num_tokens=max_num_tokens,
            max_prompt_embedding_table_size=max_prompt_embedding_table_size,
            gather_all_token_logits=gather_all_token_logits,
            plugin_config=plugin_config)

    @classmethod
    def from_json_file(cls, config_file):
        '''Load a BuildConfig (or subclass) from a JSON file via from_dict.'''
        with open(config_file) as f:
            config = json.load(f)
        return cls.from_dict(config)

    def to_dict(self):
        '''Return a deep-copied dict of the fields; plugin_config becomes the
        dict of its instance attributes.'''
        output = copy.deepcopy(self.__dict__)
        plugin_config = output.pop('plugin_config')
        plugin_config_dict = copy.deepcopy(plugin_config.__dict__)
        output['plugin_config'] = plugin_config_dict
        return output
def serialize_engine(engine, path):
    '''Write serialized engine bytes to ``path`` (overwriting it) and log the
    elapsed time.

    @param engine: the engine payload, e.g. the trt.IHostMemory returned by
        Builder.build_engine() or the bytes read by Engine.from_dir().
    @param path: destination file path.
    '''
    logger.info(f'Serializing engine to {path}...')
    tik = time.time()
    try:
        # Write directly from the object's buffer. The former bytearray(engine)
        # made a full in-memory copy first, doubling peak host memory for
        # multi-GB engines.
        payload = memoryview(engine)
    except TypeError:
        # Objects without the buffer protocol keep the original conversion.
        payload = bytearray(engine)
    with open(path, 'wb') as f:
        f.write(payload)
    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Engine serialized. Total time: {t}')
class EngineConfig:
    '''Contents of an engine directory's config.json: the model's
    PretrainedConfig, the BuildConfig used, and the TensorRT-LLM version.
    '''

    def __init__(self, pretrained_config: PretrainedConfig,
                 build_config: BuildConfig, version: str):
        self.pretrained_config = pretrained_config
        self.build_config = build_config
        self.version = version

    @classmethod
    def from_json_file(cls, config_file):
        '''Load from a JSON file holding 'pretrained_config', 'build_config'
        and 'version' keys (KeyError if any is missing).'''
        with open(config_file) as f:
            raw = json.load(f)
        pretrained_config = PretrainedConfig.from_dict(raw['pretrained_config'])
        build_config = BuildConfig.from_dict(raw['build_config'])
        return cls(pretrained_config, build_config, raw['version'])

    def to_dict(self):
        '''Return a dict with 'version', 'pretrained_config' and
        'build_config' keys, in that order.'''
        result = {'version': self.version}
        result['pretrained_config'] = self.pretrained_config.to_dict()
        result['build_config'] = self.build_config.to_dict()
        return result
class Engine:
    '''A serialized TensorRT engine for one rank plus its EngineConfig.'''

    def __init__(self, config: EngineConfig, engine: trt.IHostMemory):
        self.config = config
        # trt.IHostMemory when freshly built; bytes when loaded by from_dir().
        self.engine = engine

    def save(self, engine_dir: str):
        '''Write this rank's engine to ``<engine_dir>/rank<N>.engine``; rank 0
        also writes ``config.json``.

        ``engine_dir`` is created if it does not exist (exist_ok=True keeps
        concurrent ranks from racing). An empty string means the current
        directory, as before.
        '''
        if engine_dir:
            os.makedirs(engine_dir, exist_ok=True)
        rank = self.config.pretrained_config.mapping.rank
        if rank == 0:
            with open(os.path.join(engine_dir, 'config.json'),
                      "w",
                      encoding="utf-8") as f:
                json.dump(self.config.to_dict(), f, indent=4)
        serialize_engine(self.engine,
                         os.path.join(engine_dir, f'rank{rank}.engine'))

    @classmethod
    def from_dir(cls, engine_dir: str, rank: int = 0):
        '''Load ``rank<rank>.engine`` (as bytes) and ``config.json`` from
        ``engine_dir``; the loaded config's rank is set to ``rank``.'''
        with open(os.path.join(engine_dir, f'rank{rank}.engine'), 'rb') as f:
            engine_buffer = f.read()

        config = EngineConfig.from_json_file(
            os.path.join(engine_dir, 'config.json'))
        config.pretrained_config.set_rank(rank)

        return cls(config, engine_buffer)
def get_engine_version(engine_dir: str) -> Optional[str]:
    '''Return the 'version' entry of ``<engine_dir>/config.json``.

    Returns None when the file has no 'version' key; raises
    FileNotFoundError if config.json does not exist.
    '''
    config_path = os.path.join(engine_dir, "config.json")
    with open(config_path, 'r') as f:
        engine_config = json.load(f)
    return engine_config.get('version')
def build_shard_model(model: PretrainedModel,
                      build_config: BuildConfig) -> Engine:
    '''Build the TensorRT engine for one rank's shard of ``model``.

    The network shares ``build_config.plugin_config``, so the quantization
    and NCCL plugin setup below mutates that object in place. The model's
    forward pass is then traced into the network, the graph is optimized,
    and the engine is built.

    NOTE(review): the quantization plugins are always configured with
    dtype='float16', whatever model.config.dtype is -- confirm intended.
    NOTE(review): quant_mode is not forwarded to create_builder_config, so it
    falls back to QuantMode(0) and never sets the FP8 flag here -- confirm.

    @return: an Engine pairing the build result (None if TensorRT failed)
        with an EngineConfig recording model.config, build_config and the
        package version.
    '''
    builder = Builder()
    network = builder.create_network()
    network._plugin_config = build_config.plugin_config

    quant_mode = model.config.quant_mode
    use_weight_only = quant_mode.is_weight_only()
    per_group = quant_mode.has_per_group_scaling()
    use_smooth_quant = quant_mode.has_act_and_weight_quant()

    if use_weight_only:
        if per_group:
            set_weight_only_plugin = network.plugin_config.set_weight_only_groupwise_quant_matmul_plugin
        else:
            set_weight_only_plugin = network.plugin_config.set_weight_only_quant_matmul_plugin
        set_weight_only_plugin(dtype='float16')

    if use_smooth_quant:
        network.plugin_config.set_smooth_quant_gemm_plugin(dtype='float16')
        network.plugin_config.set_rmsnorm_quantization_plugin(dtype='float16')
        network.plugin_config.set_layernorm_quantization_plugin(
            dtype='float16')
        network.plugin_config.set_quantize_tensor_plugin()
        network.plugin_config.set_quantize_per_token_plugin()

    # NCCL is only needed when the model is split across several ranks.
    if model.config.mapping.world_size > 1:
        nccl_dtype = model.config.dtype
        if nccl_dtype:
            network.plugin_config.set_nccl_plugin(
                nccl_dtype, network.plugin_config.use_custom_all_reduce)

    with net_guard(network):
        # Register weights by name, then trace the forward pass.
        network.set_named_parameters(model.named_parameters())
        inputs = model.prepare_inputs(
            build_config.max_batch_size,
            build_config.max_input_len,
            build_config.max_output_len,
            True,
            build_config.max_beam_width,
            build_config.max_num_tokens,
            build_config.max_prompt_embedding_table_size,
        )
        model(**inputs)

    optimize(network)

    # Network -> Engine
    needs_int8 = (quant_mode.has_act_or_weight_quant()
                  or quant_mode.has_int8_kv_cache())
    builder_config = builder.create_builder_config(
        precision=model.config.dtype, int8=needs_int8)
    engine = builder.build_engine(network, builder_config)
    engine_config = EngineConfig(model.config, build_config, __version__)
    return Engine(engine_config, engine)
def build(build_config: Union[str, BuildConfig],
          rank: int = 0,
          ckpt_dir: str = None,
          model_config: Union[str, PretrainedConfig] = None,
          weights=None,
          model_cls=None) -> Engine:
    '''Build the engine for a single rank.

    @param build_config: a BuildConfig, or the path of its JSON file.
    @param rank: rank to build; must be < model_config.mapping.world_size.
    @param ckpt_dir: checkpoint directory. When given, its config.json is
        used (model_config is ignored) and the model is created with
        model_cls.from_checkpoint().
    @param model_config: a PretrainedConfig or the path of its JSON file;
        required when ckpt_dir is None.
    @param weights: optional weights passed to model.load() when the model
        is created from model_config.
    @param model_cls: model class; looked up in MODEL_MAP by architecture
        when None (RuntimeError if the architecture is unknown).
    @return: the built Engine.
    '''
    if ckpt_dir is None:
        assert model_config is not None
        if not isinstance(model_config, PretrainedConfig):
            model_config = PretrainedConfig.from_json_file(model_config)
    else:
        model_config = PretrainedConfig.from_json_file(
            os.path.join(ckpt_dir, 'config.json'))

    if isinstance(build_config, str):
        build_config = BuildConfig.from_json_file(build_config)

    assert rank < model_config.mapping.world_size
    architecture = model_config.architecture

    if model_cls is None:
        if architecture not in MODEL_MAP:
            raise RuntimeError(
                f'Unsupported model architecture: {architecture}')
        model_cls = MODEL_MAP[architecture]

    if ckpt_dir is None:
        # Each rank gets its own copy of the config so set_rank() does not
        # alter the caller's object.
        rank_config = copy.deepcopy(model_config)
        rank_config.set_rank(rank)
        model = model_cls.from_config(rank_config)
        if weights is not None:
            model.load(weights)
    else:
        model = model_cls.from_checkpoint(ckpt_dir, rank=rank)

    return build_shard_model(model, build_config)