import contextlib
import threading

try:
    from types import NoneType
except ImportError:
    NoneType = type(None)
from typing import ByteString, Iterable, MutableMapping

import tensorrt as trt
import torch

from tensorrt_llm._utils import get_extra_attr, np_dtype_to_trt, set_extra_attr
from tensorrt_llm.logger import logger
from tensorrt_llm.network import PluginInfo, get_plugin_info

LAYER_TYPE_2_CLASS = {
    trt.LayerType.ACTIVATION: trt.IActivationLayer,
    trt.LayerType.CONCATENATION: trt.IConcatenationLayer,
    trt.LayerType.CONSTANT: trt.IConstantLayer,
    trt.LayerType.ELEMENTWISE: trt.IElementWiseLayer,
    trt.LayerType.FILL: trt.IFillLayer,
    trt.LayerType.GATHER: trt.IGatherLayer,
    trt.LayerType.MATRIX_MULTIPLY: trt.IMatrixMultiplyLayer,
    trt.LayerType.REDUCE: trt.IReduceLayer,
    trt.LayerType.SELECT: trt.ISelectLayer,
    trt.LayerType.SHUFFLE: trt.IShuffleLayer,
    trt.LayerType.SLICE: trt.ISliceLayer,
    trt.LayerType.SOFTMAX: trt.ISoftMaxLayer,
    trt.LayerType.UNARY: trt.IUnaryLayer,
    trt.LayerType.SHAPE: trt.IShapeLayer,
    trt.LayerType.ASSERTION: trt.IAssertionLayer,
    trt.LayerType.CAST: trt.ICastLayer,
    trt.LayerType.NORMALIZATION: trt.INormalizationLayer,
    trt.LayerType.IDENTITY: trt.IIdentityLayer,
    trt.LayerType.PLUGIN_V2: trt.IPluginV2Layer,
}


def to_subclass_layer(trt_layer):
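    """Downcast ``trt_layer`` in place to its concrete layer subclass.

    TensorRT's Python API hands out plain ``trt.ILayer`` objects; swapping
    ``__class__`` exposes the type-specific attributes (e.g.
    ``IShuffleLayer.reshape_dims``). Undo with ``to_base_class_layer``.
    """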
    trt_layer.__class__ = LAYER_TYPE_2_CLASS[trt_layer.type]


def to_base_class_layer(trt_layer):
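    """Undo ``to_subclass_layer`` by restoring the base ``trt.ILayer`` class."""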
    trt_layer.__class__ = trt.ILayer


def to_trt_weights(ndarray):
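    """Wrap a numpy array in a ``trt.Weights`` object without copying."""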
    weight = trt.Weights(
        np_dtype_to_trt(ndarray.dtype),
        ndarray.ctypes.data,
        ndarray.size,
    )
    # Prevent numpy array from going out of weight's lifetime scope
    set_extra_attr(weight, "numpy", ndarray)
    return weight


@contextlib.contextmanager
def silent_trt_logger():
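    """Suppress TRT log messages below ERROR severity within the scope."""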
    min_severity = logger.trt_logger.min_severity
    logger.trt_logger.min_severity = trt.Logger.ERROR
    try:
        yield
    finally:
        # Restore the previous severity even if the body raises.
        logger.trt_logger.min_severity = min_severity


def compare_tensor(trt_tensor, new_trt_tensor):
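    """Assert that two ITensors agree on name, dtype, shape, and metadata."""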
    assert trt_tensor.name == new_trt_tensor.name
    assert trt_tensor.dtype == new_trt_tensor.dtype
    assert tuple(trt_tensor.shape) == tuple(new_trt_tensor.shape)
    assert trt_tensor.broadcast_across_batch == new_trt_tensor.broadcast_across_batch
    assert trt_tensor.location == new_trt_tensor.location
    assert trt_tensor.is_network_input == new_trt_tensor.is_network_input
    assert trt_tensor.is_network_output == new_trt_tensor.is_network_output
    assert trt_tensor.dynamic_range == new_trt_tensor.dynamic_range
    assert trt_tensor.is_shape_tensor == new_trt_tensor.is_shape_tensor
    assert trt_tensor.is_execution_tensor == new_trt_tensor.is_execution_tensor
    assert trt_tensor.allowed_formats == new_trt_tensor.allowed_formats


def compare_network(trt_network, new_trt_network):
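    """Assert that two networks have identical inputs, outputs, and layers.

    Layers are paired up in topological order (via get_sorted_layer_ids) so
    the comparison tolerates differing layer insertion orders.
    """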
    assert trt_network.num_inputs == new_trt_network.num_inputs
    for i in range(trt_network.num_inputs):
        input = trt_network.get_input(i)
        new_input = new_trt_network.get_input(i)
        compare_tensor(input, new_input)
    assert trt_network.num_outputs == new_trt_network.num_outputs
    for i in range(trt_network.num_outputs):
        output = trt_network.get_output(i)
        new_output = new_trt_network.get_output(i)
        compare_tensor(output, new_output)
    assert trt_network.num_layers == new_trt_network.num_layers
    for index, new_index in zip(get_sorted_layer_ids(trt_network),
                                get_sorted_layer_ids(new_trt_network)):
        layer = trt_network.get_layer(index)
        new_layer = new_trt_network.get_layer(new_index)
        assert layer.name == new_layer.name
        assert layer.type == new_layer.type
        assert layer.precision_is_set == new_layer.precision_is_set
        assert layer.precision == new_layer.precision
        assert layer.num_inputs == new_layer.num_inputs
        for j in range(layer.num_inputs):
            input = layer.get_input(j)
            new_input = new_layer.get_input(j)
            if input is None:
                assert new_input is None
            else:
                assert new_input is not None
                compare_tensor(input, new_input)
        assert layer.num_outputs == new_layer.num_outputs
        for j in range(layer.num_outputs):
            output = layer.get_output(j)
            new_output = new_layer.get_output(j)
            compare_tensor(output, new_output)
            assert layer.output_type_is_set(j) == new_layer.output_type_is_set(j)
            if layer.output_type_is_set(j):
                assert layer.get_output_type(j) == new_layer.get_output_type(j)


def get_sorted_layer_ids(trt_network):
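    """Return the network's layer indices in topological order.

    Kahn-style traversal: a layer is emitted once every one of its input
    tensors has been produced (or is a network input); otherwise it is
    pushed back to the end of the queue and retried later.
    """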
    inputs = set()
    for i in range(trt_network.num_inputs):
        inputs.add(trt_network.get_input(i).name)
    layer_ids = [*range(trt_network.num_layers)]
    sorted_layer_ids = []
    walked_tensors = set(inputs)
    while len(layer_ids) > 0:
        layer_id = layer_ids.pop(0)
        layer = trt_network.get_layer(layer_id)
        no_dependencies = True
        for j in range(layer.num_inputs):
            input = layer.get_input(j)
            if input is None:
                continue
            if input.name in walked_tensors:
                continue
            else:
                no_dependencies = False
                break
        if no_dependencies:
            sorted_layer_ids.append(layer_id)
            for j in range(layer.num_outputs):
                output = layer.get_output(j)
                if output is None:
                    continue
                walked_tensors.add(output.name)
        else:
            layer_ids.append(layer_id)
    assert len(sorted_layer_ids) == trt_network.num_layers
    return sorted_layer_ids


def to_tuple(values):
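    """Recursively convert ``values`` to a hashable, tuple-based form.

    Used to build cache keys from layer attributes of assorted types.
    """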
    if isinstance(values, (int, float, str, bool, NoneType, ByteString)):
        return values
    elif isinstance(values, (trt.Dims, trt.Permutation)):
        # trt.Dims reports a negative length when the dims are unset; call
        # __len__ directly because len() rejects negative return values.
        if values.__len__() < 0:
            return None
        else:
            return tuple(values)
    elif isinstance(values, MutableMapping):
        # Mappings must be checked before the generic Iterable branch,
        # otherwise a dict would be flattened to a tuple of its keys only.
        return tuple((k, to_tuple(v)) for k, v in values.items())
    elif isinstance(values, Iterable):
        return tuple(to_tuple(v) for v in values)
    else:
        return values


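# Attributes shared by every layer; get_cache_key subtracts these from
# dir(layer) to find the attributes specific to a layer subclass.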
_base_layer_attr_names = set(dir(trt.ILayer))


def get_cache_key(layer, shapes, values, dtypes=None, updated_attrs=None):
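    """Build a hashable cache key identifying a layer's computation.

    The key combines the layer type, its subclass-specific attributes (or
    the plugin field collection for PLUGIN_V2 layers), the input shapes and
    values, and, when ``dtypes`` is given, the input/output dtypes.
    Attributes that are superseded by optional input tensors (e.g. a slice
    layer's ``start``) are excluded from the key.
    """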
    updated_attrs = updated_attrs or {}
    layer_type = layer.type
    to_subclass_layer(layer)
    attr_names = set(dir(layer)) - _base_layer_attr_names
    if layer_type == trt.LayerType.CONSTANT:
        attr_names.remove("weights")
    elif layer_type == trt.LayerType.SHUFFLE:
        if layer.num_inputs >= 2:
            attr_names.remove("reshape_dims")
    elif layer_type == trt.LayerType.SLICE:
        if layer.num_inputs >= 2 and layer.get_input(1) is not None:
            attr_names.remove("start")
        if layer.num_inputs >= 3 and layer.get_input(2) is not None:
            attr_names.remove("shape")
        if layer.num_inputs >= 4 and layer.get_input(3) is not None:
            attr_names.remove("stride")
    elif layer_type == trt.LayerType.FILL:
        attr_names.remove("is_alpha_beta_int64")
        if layer.num_inputs >= 1 and layer.get_input(0) is not None:
            attr_names.remove("shape")
        if layer.num_inputs >= 2 and layer.get_input(1) is not None:
            attr_names.remove("alpha")
        if layer.num_inputs >= 3 and layer.get_input(2) is not None:
            attr_names.remove("beta")
    if layer_type != trt.LayerType.PLUGIN_V2:
        attr_key = tuple(
            (name, to_tuple(updated_attrs.get(name) or getattr(layer, name)))
            for name in sorted(attr_names))
    else:
        network = get_trt_network(layer)
        plugin_info = get_plugin_info(network, layer.name)
        assert plugin_info is not None, f"layer {layer.name} does not register plugin info"
        attr_key = tuple(
            (name, tuple(updated_attrs.get(name) or data))
            for name, data in sorted(plugin_info.pfc_as_list.items()))
    to_base_class_layer(layer)
    shape_key = ()
    value_key = ()
    dtype_key = ()
    for i in range(layer.num_inputs):
        input = layer.get_input(i)
        if input is not None:
            shape_key += (tuple(shapes[input.name]), )
            if input.name in values:
                value = values[input.name]
                # All torch tensors are derived from input shapes and pfc,
                # thus we ignore them in cache key
                if isinstance(value, torch.Tensor):
                    value = None
                else:
                    value = tuple(value)
                value_key += (value, )
            else:
                value_key += (None, )
            if dtypes is not None:
                dtype_key += (dtypes[input.name], )
        else:
            shape_key += (None, )
            value_key += (None, )
            dtype_key += (None, )
    if dtypes is not None:
        for i in range(layer.num_outputs):
            output = layer.get_output(i)
            dtype_key += (dtypes[output.name], )
    cache_key = (layer.type, attr_key, shape_key, value_key)
    if dtypes is not None:
        cache_key += (dtype_key, )
    return cache_key


def get_trt_network(layer: trt.ILayer):
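    """Return the INetworkDefinition attached to ``layer`` by set_trt_network."""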
    network = get_extra_attr(layer, "network")
    assert network is not None
    return network


def set_trt_network(layer: trt.ILayer, network: trt.INetworkDefinition):
    set_extra_attr(layer, "network", network)


def get_updated_plugin(plugin_info: PluginInfo, updated_attrs):
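    """Re-create a plugin with selected fields replaced by ``updated_attrs``.

    Fields absent from ``updated_attrs`` keep their original values. Returns
    the new plugin and a PluginInfo describing the updated field collection.
    """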
    fields = []
    for field in plugin_info.pfc:
        name = field.name
        if name in updated_attrs:
            field = trt.PluginField(name, updated_attrs[name], field.type)
        else:
            field = trt.PluginField(name, plugin_info.pfc_as_ndarray[name],
                                    field.type)
        fields.append(field)
    pfc = trt.PluginFieldCollection(fields)
    plugin = plugin_info.plugin_creator.create_plugin(plugin_info.plugin_name,
                                                      pfc)
    new_plugin_info = PluginInfo(plugin_info.plugin_creator,
                                 plugin_info.plugin_name, pfc)
    return plugin, new_plugin_info


_builder_flags = threading.local()
_strongly_typed = threading.local()


def get_builder_flags():
    return getattr(_builder_flags, 'value', 0)


def get_strongly_typed():
    return getattr(_strongly_typed, 'value', False)


@contextlib.contextmanager
def current_flags(builder_flags, strongly_typed):
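    """Temporarily set the per-thread builder flags and strongly-typed mode."""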
    previous_builder_flags = get_builder_flags()
    _builder_flags.value = builder_flags
    previous_strongly_typed = get_strongly_typed()
    _strongly_typed.value = strongly_typed
    try:
        yield
    finally:
        # Restore the previous per-thread values even if the body raises.
        _builder_flags.value = previous_builder_flags
        _strongly_typed.value = previous_strongly_typed


def get_engine_information(engine_file) -> str:
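    """Deserialize an engine file and return its layer information as JSON."""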
    with open(engine_file, "rb") as f:
        engine_buffer = f.read()
    runtime = trt.Runtime(logger.trt_logger)
    engine = runtime.deserialize_cuda_engine(engine_buffer)
    inspector = engine.create_engine_inspector()
    return inspector.get_engine_information(trt.LayerInformationFormat.JSON)


def print_engine_info(engine_file) -> None:
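    """Deserialize an engine file and print its engine information."""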
    with open(engine_file, "rb") as f:
        engine_buffer = f.read()
    from tensorrt_llm.runtime.session import Session
    Session.from_serialized_engine(engine_buffer)._print_engine_info()