# TensorRT-LLM/tensorrt_llm/auto_parallel/utils.py

import contextlib
import threading
from enum import Enum, EnumMeta
from types import NoneType
from typing import ByteString, Iterable, MutableMapping

import tensorrt as trt
import torch

from tensorrt_llm._utils import get_extra_attr, np_dtype_to_trt, set_extra_attr
from tensorrt_llm.logger import logger
from tensorrt_llm.network import PluginInfo, get_plugin_info


class BaseEnumMeta(EnumMeta):
def __contains__(cls, item):
try:
cls(item)
except ValueError:
return False
return True


class BaseEnum(Enum, metaclass=BaseEnumMeta):
    pass
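
# Usage sketch (illustrative, not part of the original module): the metaclass
# lets raw values be tested for membership directly against an enum class.
#
#   class Color(BaseEnum):
#       RED = "red"
#
#   assert "red" in Color
#   assert "blue" not in Color
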
LAYER_TYPE_2_CLASS = {
trt.LayerType.ACTIVATION: trt.IActivationLayer,
trt.LayerType.CONCATENATION: trt.IConcatenationLayer,
trt.LayerType.CONSTANT: trt.IConstantLayer,
trt.LayerType.ELEMENTWISE: trt.IElementWiseLayer,
trt.LayerType.FILL: trt.IFillLayer,
trt.LayerType.GATHER: trt.IGatherLayer,
trt.LayerType.MATRIX_MULTIPLY: trt.IMatrixMultiplyLayer,
trt.LayerType.REDUCE: trt.IReduceLayer,
trt.LayerType.SELECT: trt.ISelectLayer,
trt.LayerType.SHUFFLE: trt.IShuffleLayer,
trt.LayerType.SLICE: trt.ISliceLayer,
trt.LayerType.SOFTMAX: trt.ISoftMaxLayer,
trt.LayerType.UNARY: trt.IUnaryLayer,
trt.LayerType.SHAPE: trt.IShapeLayer,
trt.LayerType.ASSERTION: trt.IAssertionLayer,
trt.LayerType.CAST: trt.ICastLayer,
trt.LayerType.NORMALIZATION: trt.INormalizationLayer,
trt.LayerType.IDENTITY: trt.IIdentityLayer,
trt.LayerType.PLUGIN_V2: trt.IPluginV2Layer,
}


def to_subclass_layer(trt_layer):
    trt_layer.__class__ = LAYER_TYPE_2_CLASS[trt_layer.type]


def to_base_class_layer(trt_layer):
    trt_layer.__class__ = trt.ILayer
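
# Typical round-trip (illustrative sketch): cast a layer to its concrete
# subclass to read type-specific attributes, then cast back so generic
# trt.ILayer code keeps working.
#
#   to_subclass_layer(layer)  # e.g. now a trt.IShuffleLayer
#   dims = layer.reshape_dims if layer.type == trt.LayerType.SHUFFLE else None
#   to_base_class_layer(layer)
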
def to_trt_weights(ndarray):
weight = trt.Weights(
np_dtype_to_trt(ndarray.dtype),
ndarray.ctypes.data,
ndarray.size,
)
# Prevent numpy array from going out of weight's lifetime scope
set_extra_attr(weight, "numpy", ndarray)
return weight
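
# Illustrative sketch (hypothetical array): wraps a numpy buffer zero-copy;
# the extra attribute keeps the array alive for the lifetime of the weights.
#
#   import numpy as np
#   w = to_trt_weights(np.zeros((16, 16), dtype=np.float32))
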
@contextlib.contextmanager
def silent_trt_logger():
    min_severity = logger.trt_logger.min_severity
    logger.trt_logger.min_severity = trt.Logger.ERROR
    try:
        yield
    finally:
        # Restore the previous severity even if the body raises.
        logger.trt_logger.min_severity = min_severity
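
# Usage sketch: temporarily silence TensorRT warnings around noisy calls,
# e.g. a deserialization done purely for inspection ('engine_buffer' is
# illustrative).
#
#   with silent_trt_logger():
#       runtime = trt.Runtime(logger.trt_logger)
#       engine = runtime.deserialize_cuda_engine(engine_buffer)
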
def compare_tensor(trt_tensor, new_trt_tensor):
assert trt_tensor.name == new_trt_tensor.name
assert trt_tensor.dtype == new_trt_tensor.dtype
assert tuple(trt_tensor.shape) == tuple(new_trt_tensor.shape)
assert trt_tensor.broadcast_across_batch == new_trt_tensor.broadcast_across_batch
assert trt_tensor.location == new_trt_tensor.location
assert trt_tensor.is_network_input == new_trt_tensor.is_network_input
assert trt_tensor.is_network_output == new_trt_tensor.is_network_output
assert trt_tensor.dynamic_range == new_trt_tensor.dynamic_range
assert trt_tensor.is_shape_tensor == new_trt_tensor.is_shape_tensor
assert trt_tensor.is_execution_tensor == new_trt_tensor.is_execution_tensor
assert trt_tensor.allowed_formats == new_trt_tensor.allowed_formats


def compare_network(trt_network, new_trt_network):
assert trt_network.num_inputs == new_trt_network.num_inputs
for i in range(trt_network.num_inputs):
input = trt_network.get_input(i)
new_input = new_trt_network.get_input(i)
compare_tensor(input, new_input)
assert trt_network.num_outputs == new_trt_network.num_outputs
for i in range(trt_network.num_outputs):
output = trt_network.get_output(i)
new_output = new_trt_network.get_output(i)
compare_tensor(output, new_output)
assert trt_network.num_layers == new_trt_network.num_layers
for index, new_index in zip(get_sorted_layer_ids(trt_network),
get_sorted_layer_ids(new_trt_network)):
layer = trt_network.get_layer(index)
new_layer = new_trt_network.get_layer(new_index)
assert layer.name == new_layer.name
assert layer.type == new_layer.type
assert layer.precision_is_set == new_layer.precision_is_set
assert layer.precision == new_layer.precision
assert layer.num_inputs == new_layer.num_inputs
for j in range(layer.num_inputs):
input = layer.get_input(j)
new_input = new_layer.get_input(j)
if input is None:
assert new_input is None
else:
assert new_input is not None
compare_tensor(input, new_input)
assert layer.num_outputs == new_layer.num_outputs
for j in range(layer.num_outputs):
output = layer.get_output(j)
new_output = new_layer.get_output(j)
compare_tensor(output, new_output)
assert layer.output_type_is_set(j) == new_layer.output_type_is_set(
j)
if layer.output_type_is_set(j):
assert layer.get_output_type(j) == new_layer.get_output_type(j)


def get_sorted_layer_ids(trt_network):
inputs = set()
for i in range(trt_network.num_inputs):
inputs.add(trt_network.get_input(i).name)
layer_ids = [*range(trt_network.num_layers)]
sorted_layer_ids = []
walked_tensors = set(inputs)
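    # Kahn-style topological sort keyed on tensor names: a layer is emitted
    # once every non-None input it consumes has been produced by an earlier
    # layer (or is a network input); otherwise it is re-queued.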
while len(layer_ids) > 0:
layer_id = layer_ids.pop(0)
layer = trt_network.get_layer(layer_id)
no_dependencies = True
for j in range(layer.num_inputs):
input = layer.get_input(j)
if input is None:
continue
if input.name in walked_tensors:
continue
else:
no_dependencies = False
break
if no_dependencies:
sorted_layer_ids.append(layer_id)
for j in range(layer.num_outputs):
output = layer.get_output(j)
if output is None:
continue
walked_tensors.add(output.name)
else:
layer_ids.append(layer_id)
assert len(sorted_layer_ids) == trt_network.num_layers
return sorted_layer_ids


def to_tuple(values):
    if isinstance(values, (int, float, str, bool, NoneType, ByteString)):
        return values
    elif isinstance(values, (trt.Dims, trt.Permutation)):
        # trt.Dims reports a negative length for unknown ranks; __len__ is
        # called directly because len() rejects negative results.
        if values.__len__() < 0:
            return None
        else:
            return tuple(values)
    # Mappings must be checked before the generic Iterable branch: dicts are
    # iterable, and iterating them would keep only the keys.
    elif isinstance(values, MutableMapping):
        return tuple((k, to_tuple(v)) for k, v in values.items())
    elif isinstance(values, Iterable):
        return tuple(to_tuple(v) for v in values)
    else:
        return values
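
# Illustrative sketch: to_tuple normalizes TensorRT attribute values into
# hashable tuples so they can be used in a cache key.
#
#   to_tuple(trt.Dims([2, 3]))           # -> (2, 3)
#   to_tuple({"axis": 1, "keep": True})  # -> (("axis", 1), ("keep", True))
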
_base_layer_attr_names = set(dir(trt.ILayer))


def get_cache_key(layer, shapes, values, dtypes=None, updated_attrs=None):
updated_attrs = updated_attrs or {}
layer_type = layer.type
to_subclass_layer(layer)
attr_names = set(dir(layer)) - _base_layer_attr_names
if layer_type == trt.LayerType.CONSTANT:
attr_names.remove("weights")
elif layer_type == trt.LayerType.SHUFFLE:
if layer.num_inputs >= 2:
attr_names.remove("reshape_dims")
elif layer_type == trt.LayerType.SLICE:
if layer.num_inputs >= 2 and layer.get_input(1) is not None:
attr_names.remove("start")
if layer.num_inputs >= 3 and layer.get_input(2) is not None:
attr_names.remove("shape")
if layer.num_inputs >= 4 and layer.get_input(3) is not None:
attr_names.remove("stride")
elif layer_type == trt.LayerType.FILL:
attr_names.remove("is_alpha_beta_int64")
if layer.num_inputs >= 1 and layer.get_input(0) is not None:
attr_names.remove("shape")
if layer.num_inputs >= 2 and layer.get_input(1) is not None:
attr_names.remove("alpha")
if layer.num_inputs >= 3 and layer.get_input(2) is not None:
attr_names.remove("beta")
if layer_type != trt.LayerType.PLUGIN_V2:
        # Use explicit membership rather than `or` so that falsy overrides
        # (e.g. 0 or False) are not silently replaced by the layer attribute.
        attr_key = tuple(
            (name,
             to_tuple(updated_attrs[name] if name in
                      updated_attrs else getattr(layer, name)))
            for name in sorted(attr_names))
else:
network = get_trt_network(layer)
plugin_info = get_plugin_info(network, layer.name)
        attr_key = tuple(
            (name,
             tuple(updated_attrs[name] if name in updated_attrs else data))
            for name, data in sorted(plugin_info.pfc_as_list.items()))
to_base_class_layer(layer)
shape_key = ()
value_key = ()
dtype_key = ()
for i in range(layer.num_inputs):
input = layer.get_input(i)
if input is not None:
shape_key += (tuple(shapes[input.name]), )
if input.name in values:
value = values[input.name]
# All torch tensors are derived from input shapes and pfc,
# thus we ignore them in cache key
if isinstance(value, torch.Tensor):
value = None
else:
value = tuple(value)
value_key += (value, )
else:
value_key += (None, )
if dtypes is not None:
dtype_key += (dtypes[input.name], )
else:
shape_key += (None, )
value_key += (None, )
dtype_key += (None, )
if dtypes is not None:
for i in range(layer.num_outputs):
output = layer.get_output(i)
dtype_key += (dtypes[output.name], )
cache_key = (layer.type, attr_key, shape_key, value_key)
if dtypes is not None:
cache_key += (dtype_key, )
return cache_key
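
# Usage sketch (hypothetical names): shapes/values/dtypes are per-tensor dicts
# keyed by tensor name; the returned key is hashable, so it can index a
# memoization dict such as the illustrative 'profiling_cache' below.
#
#   key = get_cache_key(layer, shapes, values, dtypes=dtypes)
#   if key in profiling_cache:
#       return profiling_cache[key]
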
def get_trt_network(layer: trt.ILayer):
    network = get_extra_attr(layer, "network")
    assert network is not None
    return network


def set_trt_network(layer: trt.ILayer, network: trt.INetworkDefinition):
    set_extra_attr(layer, "network", network)


def get_updated_plugin(plugin_info: PluginInfo, updated_attrs):
fields = []
for field in plugin_info.pfc:
name = field.name
if name in updated_attrs:
field = trt.PluginField(name, updated_attrs[name], field.type)
else:
field = trt.PluginField(name, plugin_info.pfc_as_ndarray[name],
field.type)
fields.append(field)
pfc = trt.PluginFieldCollection(fields)
plugin = plugin_info.plugin_creator.create_plugin(plugin_info.plugin_name,
pfc)
new_plugin_info = PluginInfo(plugin_info.plugin_creator,
plugin_info.plugin_name, pfc)
return plugin, new_plugin_info
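
# Usage sketch (illustrative field name and value): override a single plugin
# field while reusing the rest from the recorded PluginInfo.
#
#   import numpy as np
#   overrides = {"num_kv_heads": np.array([kv_heads], dtype=np.int32)}
#   plugin, new_info = get_updated_plugin(plugin_info, overrides)
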
_builder_flags = threading.local()
_strongly_typed = threading.local()


def get_builder_flags():
    return getattr(_builder_flags, 'value', 0)


def get_strongly_typed():
    return getattr(_strongly_typed, 'value', False)


@contextlib.contextmanager
def current_flags(builder_flags, strongly_typed):
    previous_builder_flags = get_builder_flags()
    previous_strongly_typed = get_strongly_typed()
    _builder_flags.value = builder_flags
    _strongly_typed.value = strongly_typed
    try:
        yield
    finally:
        # Restore the previous thread-local state even if the body raises.
        _builder_flags.value = previous_builder_flags
        _strongly_typed.value = previous_strongly_typed
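
# Usage sketch: scope the thread-local builder state for code that consults
# get_builder_flags() / get_strongly_typed(), e.g. while profiling a layer.
#
#   flags = 1 << int(trt.BuilderFlag.FP16)
#   with current_flags(flags, strongly_typed=False):
#       ...
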
def get_engine_information(engine_file) -> str:
with open(engine_file, "rb") as f:
engine_buffer = f.read()
runtime = trt.Runtime(logger.trt_logger)
engine = runtime.deserialize_cuda_engine(engine_buffer)
inspector = engine.create_engine_inspector()
return inspector.get_engine_information(trt.LayerInformationFormat.JSON)
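
# Usage sketch ('model.engine' is an illustrative path): the returned string
# is JSON, one entry per layer.
#
#   import json
#   layers = json.loads(get_engine_information("model.engine"))
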
def print_engine_info(engine_file) -> None:
    with open(engine_file, "rb") as f:
        engine_buffer = f.read()
    # Imported here to avoid a circular import at module load time.
    from tensorrt_llm.runtime.session import Session
    Session.from_serialized_engine(engine_buffer)._print_engine_info()