from dataclasses import dataclass from enum import Enum from typing import Dict, List, Set import numpy as np import tensorrt as trt import torch from tensorrt_llm._common import _is_building from tensorrt_llm._utils import (trt_dtype_to_np, trt_dtype_to_str, trt_dtype_to_torch) from tensorrt_llm.logger import logger from .pipeline_graph import PipelineGraph from .utils import (get_builder_flags, get_cache_key, get_sorted_layer_ids, set_trt_network, to_base_class_layer, to_subclass_layer, to_trt_weights) class ShapeType(Enum): MIN = 0 OPT = 1 MAX = 2 _trt_to_type_dict = { trt.int64: int, trt.bool: bool, } def get_shape_layers(trt_network): shape_layers = set() for i in range(trt_network.num_layers): layer = trt_network.get_layer(i) if (layer.num_inputs > 0 and np.all([ layer.get_input(j).is_shape_tensor for j in range(layer.num_inputs) if layer.get_input(j) is not None ])) or (layer.num_outputs > 0 and np.all([ layer.get_output(j).is_shape_tensor for j in range(layer.num_outputs) ])): shape_layers.add(layer.name) return shape_layers def get_layers_in_shape_network(trt_network, shape_layers, sorted_layer_ids): layers = set() shape_tensors = set() for layer_id in reversed(sorted_layer_ids): layer = trt_network.get_layer(layer_id) in_shape_network = False if layer.name in shape_layers: in_shape_network = True else: for j in range(layer.num_outputs): output = layer.get_output(j) if output.name in shape_tensors: in_shape_network = True break if in_shape_network: layers.add(layer.name) for j in range(layer.num_inputs): input = layer.get_input(j) if input is not None: shape_tensors.add(input.name) return layers def get_shape_network(trt_network, shapes, values, sorted_layer_ids, profile=None, shape_type: ShapeType = ShapeType.OPT): shape_layers = get_shape_layers(trt_network) layers_in_shape_network = get_layers_in_shape_network( trt_network, shape_layers, sorted_layer_ids) shape_graph = PipelineGraph.create_graph() shape_network = shape_graph.as_trt() shape_builder = shape_network.builder shape_profile = shape_builder.create_optimization_profile() for i in range(trt_network.num_inputs): input = trt_network.get_input(i) shapes[input.name] = input.shape new_input = shape_graph.add_input(input) if profile is not None: if -1 in input.shape: shape = profile.get_shape(input.name) shape = shape[shape_type.value] shapes[input.name] = shape new_input.raw_shape = shape if input.is_shape_tensor: shape_values = profile.get_shape_input(input.name) value = shape_values[shape_type.value] values[input.name] = value shape_profile.set_shape_input(input.name, value, value, value) output_mapping = {} for layer_id in sorted_layer_ids: layer = trt_network.get_layer(layer_id) if layer.name in shape_layers: new_layer = shape_graph.add_layer(layer) for i in range(layer.num_outputs): output = layer.get_output(i) if output.dtype == trt.DataType.BOOL: proxy_layer = shape_network.add_cast( new_layer.as_trt().get_output(i), trt.DataType.INT32, ) proxy_output = proxy_layer.get_output(0) shape_graph.register_layer(proxy_layer) shape_graph.add_output_shape(proxy_output) output_mapping[proxy_output.name] = (output.name, output.dtype) else: shape_graph.add_output_shape(output) elif layer.name in layers_in_shape_network: if layer.type == trt.LayerType.CONSTANT: shape_graph.add_input(layer.get_output(0)) else: shape_graph.add_layer(layer) return shape_network, shape_profile, shape_layers, output_mapping def get_per_layer_graph( layer, shapes, values, updated_attrs=None, is_shape_io: bool = None, ): graph = PipelineGraph.create_graph() network = graph.as_trt() is_shape_layer = layer.num_inputs != 0 for i in range(layer.num_inputs): input = layer.get_input(i) if input is not None: shape = shapes[input.name] if (values.get(input.name) is not None and not isinstance(values[input.name], torch.Tensor)): value = values[input.name] weights = np.asarray(value, dtype=trt_dtype_to_np(input.dtype)) weights = to_trt_weights(weights) input_layer = network.add_constant(shape, weights) new_input = input_layer.get_output(0) new_input.name = input.name graph.register_layer(input_layer) elif graph.get_input(input.name) is None: new_input = graph.add_input(input) new_input.raw_shape = shapes[input.name] is_shape_layer = False new_layer = graph.add_layer( layer, updated_attrs=updated_attrs, ) output_mapping = {} if layer.type == trt.LayerType.SHAPE: is_shape_layer = True if layer.num_inputs == 0: is_shape_layer = False if is_shape_io is not None: is_shape_layer = is_shape_io for i in range(layer.num_outputs): output = layer.get_output(i) value = values.get(output.name) if value is not None and isinstance(value, torch.Tensor): is_output_shape = False elif is_shape_layer: is_output_shape = True else: is_output_shape = False if is_output_shape: if output.dtype == trt.DataType.BOOL: proxy_layer = network.add_cast( new_layer.as_trt().get_output(i), trt.DataType.INT32, ) proxy_output = proxy_layer.get_output(0) graph.register_layer(proxy_layer) output_mapping[proxy_output.name] = (output.name, output.dtype) output = proxy_output graph.add_output_shape(output) else: graph.add_output(output) return graph, output_mapping @_is_building def infer_shapes(network, shapes, values, profile=None): if network.num_outputs == 0: return builder = network.builder config = builder.create_builder_config() config.builder_optimization_level = 0 config.flags = get_builder_flags() profile = profile or builder.create_optimization_profile() config.add_optimization_profile(profile) plan = builder.build_serialized_network(network, config) if plan is None: raise RuntimeError( 'Engine building failed when inferring shapes, please check the error log.' ) runtime = trt.Runtime(logger.trt_logger) engine = runtime.deserialize_cuda_engine(plan) context = engine.create_execution_context() for i in range(network.num_inputs): input = network.get_input(i) if input.is_shape_tensor: value = values[input.name] context.set_shape_input(engine[input.name], value) for i in range(network.num_outputs): output = network.get_output(i) shape = context.get_tensor_shape(output.name) shapes[output.name] = shape if output.is_shape_tensor: if shape == [0]: values[output.name] = [] else: if shape == []: shape = [1] value = torch.empty( list(shape), dtype=trt_dtype_to_torch(output.dtype), device="cpu", ) values[output.name] = value context.set_tensor_address(output.name, value.data_ptr()) context.infer_shapes() assert context.all_binding_shapes_specified for i in range(network.num_outputs): output = network.get_output(i) if isinstance(values.get(output.name), torch.Tensor): values[output.name] = values[output.name].tolist() @dataclass class ShapeInfo: shapes: Dict[str, trt.Dims] values: Dict[str, List[int]] shape_layers: Set[str] max_shapes: Dict[str, trt.Dims] = None def set_constant_value(layer, values): to_subclass_layer(layer) output_name = layer.get_output(0).name weights = layer.weights if isinstance(weights, trt.Weights): weights = weights.numpy() values[output_name] = list(weights) to_base_class_layer(layer) def infer_per_layer_shapes( layer: trt.ILayer, shapes, values, cache=None, is_shape_io=False, ): if layer.type == trt.LayerType.CONSTANT: to_subclass_layer(layer) output_name = layer.get_output(0).name shape = layer.shape shapes[output_name] = shape if is_shape_io: set_constant_value(layer, values) to_base_class_layer(layer) return elif layer.type == trt.LayerType.SHAPE: input_name = layer.get_input(0).name output_name = layer.get_output(0).name shape = [*shapes[input_name]] shapes[output_name] = trt.Dims([len(shape)]) values[output_name] = shape return if cache is not None: cache_key = get_cache_key(layer, shapes, values) if cache_key in cache: output_shapes, output_values = cache[cache_key] for i in range(layer.num_outputs): output = layer.get_output(i) shapes[output.name] = output_shapes[i] if output_values[i] is not None: values[output.name] = output_values[i] return graph, output_mapping = get_per_layer_graph(layer, shapes, values) dtypes = [ trt_dtype_to_str(layer.get_input(i).dtype) for i in range(layer.num_inputs) ] layer_info = (f"type={cache_key[0]}, " f"attrs={dict(cache_key[1])}, " f"dtypes={dtypes}, " f"shapes={list(cache_key[2])}, " f"values={list(cache_key[3])}") logger.debug(f"infer shapes for layer {layer.name} ({layer_info})") try: infer_shapes(graph.as_trt(), shapes, values) except RuntimeError as e: raise RuntimeError( f"infer shapes failed for layer {layer.name} ({layer_info})") from e for proxy_output, (output, dtype) in output_mapping.items(): shapes[output] = shapes[proxy_output] del shapes[proxy_output] if proxy_output in values: values[output] = [ *map(_trt_to_type_dict[dtype], values[proxy_output]) ] del values[proxy_output] if cache is not None: logger.debug( f"shape inference cache miss, layer: {layer.name}, cache key: {cache_key}" ) output_shapes = [] output_values = [] for i in range(layer.num_outputs): output = layer.get_output(i) output_shapes.append(shapes[output.name]) output_values.append(values.get(output.name)) cache[cache_key] = (output_shapes, output_values) def get_shape_info(trt_network, profile, shape_type: ShapeType = ShapeType.OPT): shapes = {} values = {} sorted_layer_ids = get_sorted_layer_ids(trt_network) infer_shape_layers = False shape_network, shape_profile, shape_layers, output_mapping = get_shape_network( trt_network, shapes, values, sorted_layer_ids, profile=profile, shape_type=shape_type) try: infer_shapes(shape_network, shapes, values, shape_profile) for proxy_output, (output, dtype) in output_mapping.items(): shapes[output] = shapes[proxy_output] values[output] = [ *map(_trt_to_type_dict[dtype], values[proxy_output]) ] del shapes[proxy_output] del values[proxy_output] except RuntimeError: infer_shape_layers = True cache = {} for layer_id in sorted_layer_ids: layer = trt_network.get_layer(layer_id) is_shape_io = layer.name in shape_layers if is_shape_io and not infer_shape_layers: continue set_trt_network(layer, trt_network) infer_per_layer_shapes(layer, shapes, values, cache, is_shape_io=is_shape_io) return ShapeInfo(shapes, values, shape_layers)