# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import contextlib
import hashlib
from collections import defaultdict
from dataclasses import dataclass, field
from functools import lru_cache
from typing import (Any, Dict, Iterable, List, Optional, OrderedDict, Set,
                    Union)

import numpy as np
import tensorrt as trt

from ._common import set_network
from .logger import logger
from .plugin import PluginConfig


class _UniqueNameGenerator(object):
    '''
    Generate unique names of the form ``{prefix}{key}_{n}``.

    ``n`` counts how many names have already been produced for the same
    (module-scoped) key, starting at 0.
    '''

    def __init__(self, prefix=''):
        # key -> number of names already handed out for that key
        self.ids = collections.defaultdict(int)
        self.prefix = prefix

    def __call__(self, key, module_name=''):
        '''
        Return the next unique name for ``key``.

        Parameters:
            key: base name, e.g. a layer type string.
            module_name: optional dotted module path; dots are replaced by
                '/' and the result is prepended to ``key`` as a scope.
        '''
        if module_name != '':
            module_name = module_name.replace(".", "/")
            key = module_name + '/' + key
        tmp = self.ids[key]
        self.ids[key] += 1
        return f"{self.prefix}{key}_{tmp}"


class Network(object):
    '''
    Python-side wrapper around a TensorRT ``INetworkDefinition``.

    Tracks network inputs, unique layer naming, the module call stack used
    for scoping names, and the bookkeeping needed by graph rewriting
    (removed layers and a cached producer/consumer graph).
    '''

    def __init__(self, **kwargs):
        # intentionally use **kwargs, user should never call this ctor directly
        # use Builder.create_network() instead

        # Holds the removed layers and disables them in graph rewritings and other phases.
        # This is a hacky way since INetwork python API doesn't provide a way to remove a layer.
        # TODO: remove this when TensorRT provides a better way to remove a layer
        self._removed_layers: Set[str] = set()

        # When True, the next _get_network_hash() call computes a full hash
        # instead of returning only the layer count, so the cached graph
        # state is rebuilt even if the layer count did not change.
        self.is_graph_altered = False

        from .graph_rewriting import FLayerInfoMemo
        self.flayer_memo = FLayerInfoMemo()  # holds the functional metadata

    def _init(self, trt_network):
        '''
        Bind this wrapper to ``trt_network`` and reset all per-network state.

        Returns:
            self, so the call can be chained.
        '''
        self._trt_network = trt_network
        # input name -> Tensor, filled by _add_input()
        self._inputs = {}
        self._named_parameters = None
        # layer precision of a given scope, this is used together with precision(dtype) context manager
        self._dtype = None
        self._name_generator = _UniqueNameGenerator()
        self._plugin_config = PluginConfig()
        self._module_call_stack = _TrtLlmModuleCallStack()
        self._registered_ndarrays = []
        self._strongly_typed = trt.INetworkDefinition.get_flag(
            self._trt_network, trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
        return self

    @property
    def dtype(self) -> trt.DataType:
        return self._dtype

    @dtype.setter
    def dtype(self, dtype: trt.DataType):
        assert isinstance(dtype, trt.DataType) or dtype is None
        self._dtype = dtype

    @property
    def trt_network(self) -> trt.INetworkDefinition:
        return self._trt_network

    @property
    def plugin_config(self) -> PluginConfig:
        return self._plugin_config

    @property
    def strongly_typed(self) -> bool:
        '''Whether the TRT network was created with the STRONGLY_TYPED flag.'''
        return self._strongly_typed

    def _add_input(self,
                   tensor,
                   name,
                   dtype,
                   shape,
                   dim_range: OrderedDict = None):
        '''
        Add a network input to the TRT network and attach it to ``tensor``.

        Parameters:
            tensor: wrapper whose ``trt_tensor`` is set to the new TRT input.
            name: input name; also the key used in ``self._inputs``.
            dtype: trt.DataType of the input.
            shape: shape passed to ``INetworkDefinition.add_input``.
            dim_range: optional ordered mapping whose keys are used, in
                order, as the names of the input's dimensions.
        '''
        assert isinstance(dtype, trt.DataType)
        tensor.trt_tensor = self.trt_network.add_input(
            name=name,
            shape=shape,
            dtype=dtype,
        )
        if dim_range is not None:
            logger.debug(
                f'Add input: {name}, shape: {shape}, dtype: {dtype}, dimension names:{list(dim_range.keys())}'
            )
            for i, dim_name in enumerate(dim_range.keys()):
                tensor.trt_tensor.set_dimension_name(i, str(dim_name))
        else:
            logger.debug(f'Add input: {name}, shape: {shape}, dtype: {dtype}')
        self._inputs[name] = tensor

    def _mark_output(self, tensor, name, dtype):
        '''
        Mark ``tensor`` as a network output named ``name`` with type ``dtype``.

        In strongly-typed mode the output dtype must not be assigned
        directly, so when the inferred dtype differs from ``dtype`` a cast
        layer is inserted and the cast result becomes the output instead.
        '''
        from .functional import cast
        if self.strongly_typed:
            if tensor.trt_tensor.dtype != dtype:
                # If stronglyTyped mode is enabled and inferred output dtype does not match desired dtype, add a cast.
                cast_output = cast(tensor, dtype)
                self.trt_network.mark_output(cast_output.trt_tensor)
                cast_output.trt_tensor.name = name
            else:
                # Otherwise, mark the tensor as network output. We should not set tensor dtype in stronglyTyped mode.
                self.trt_network.mark_output(tensor.trt_tensor)
                tensor.trt_tensor.name = name
        else:
            self.trt_network.mark_output(tensor.trt_tensor)
            tensor.trt_tensor.name = name
            tensor.trt_tensor.dtype = dtype
        logger.debug(f'Mark output: {name}, dtype: {dtype}')

    def set_named_parameters(self, named_parameters):
        self._named_parameters = named_parameters

    @property
    def named_parameters(self):
        return self._named_parameters

    def _set_layer_name(self, layer):
        '''
        Give ``layer`` a unique, readable name and rename its output tensors.

        The name is the layer type (plus the plugin type for PLUGIN_V2
        layers, or the op for UNARY/REDUCE/ELEMENTWISE layers), scoped by
        the module currently on top of the module call stack.
        '''
        layer_name = str(layer.type).split('.')[-1]
        current_module = self._module_call_stack.get_current_module()

        if layer.type == trt.LayerType.PLUGIN_V2:
            layer_name = '_'.join(
                [layer_name,
                 str(layer.plugin.plugin_type).split('.')[-1]])
        elif layer.type in [
                trt.LayerType.UNARY, trt.LayerType.REDUCE,
                trt.LayerType.ELEMENTWISE
        ]:
            layer_name = '_'.join([layer_name, str(layer.op).split('.')[-1]])

        layer.name = self._name_generator(layer_name, current_module)
        for idx in range(layer.num_outputs):
            # TRT initializes tensor names from the initial layer's name when the layer is created,
            # and does not update tensor names when layer name changed by application, needs to
            # change the tensor name to align with the new layer name for better debugging
            layer.get_output(idx).name = f"{layer.name}_output_{idx}"

    def register_ndarray(self, ndarray: np.ndarray) -> None:
        '''
        Keep a reference to ``ndarray`` for the lifetime of this network.

        NOTE(review): presumably this keeps host buffers that back TRT
        weights alive until the engine is built; confirm against callers.
        '''
        self._registered_ndarrays.append(ndarray)

    def get_inputs(self):
        '''
        Get the inputs of the network.

        Returns:
            Iterable[Tensor]: a live view over the registered inputs.
        '''
        return self._inputs.values()

    def get_outputs(self):
        '''
        Get the outputs of the network.

        Returns:
            Iterable[Tensor]: a generator that wraps each TRT output in a
            new Tensor each time it is iterated.
        '''
        from .functional import Tensor
        for i in range(self._trt_network.num_outputs):
            tensor = self._trt_network.get_output(i)
            yield Tensor(trt_tensor=tensor,
                         network=self,
                         is_network_input=False)

    def is_input(self, tensor) -> bool:
        '''
        Tell if a tensor is an input of the network.

        Parameters:
            tensor: Union[Tensor, str, trt.ITensor]

        Returns:
            Despite the ``bool`` annotation, the registered input object is
            returned when the name matches, otherwise False. Callers should
            only use the result in a boolean context.

        Raises:
            ValueError: if ``tensor`` is not a Tensor, str or ITensor.
        '''
        from .functional import Tensor
        if isinstance(tensor, str):
            tensor_name = tensor
        elif isinstance(tensor, (trt.ITensor, Tensor)):
            tensor_name = tensor.name
        else:
            raise ValueError(
                f"tensor should be Tensor, str or ITensor, got {tensor}")

        return self._inputs.get(tensor_name, False)

    def is_output(self, tensor) -> bool:
        '''
        Tell if a tensor is an output of the network.

        Parameters:
            tensor: Tensor; compared by identity of its ``trt_tensor``.
        '''
        for i in range(self._trt_network.num_outputs):
            if tensor.trt_tensor is self._trt_network.get_output(i):
                return True
        return False

    def get_layers(self) -> Iterable["Layer"]:
        '''
        Get all the layers of the network, including ones marked removed.

        Returns:
            Iterable[Layer]
        '''
        from .graph_rewriting import Layer
        for i in range(self._trt_network.num_layers):
            layer = Layer(network=self,
                          trt_layer=self._trt_network.get_layer(i))
            yield layer

    def get_layer_by_name(self, name: str) -> Optional["Layer"]:
        '''Return the layer called ``name``, or None if there is none.'''
        state = self._get_graph()
        return state.name_to_layer.get(name, None)

    def get_tensor_users(self, tensor) -> Iterable["Layer"]:
        '''
        Get the layers that consume this tensor.
        '''
        state = self._get_graph()
        # Use .get() rather than [] so that querying a tensor with no users
        # does not insert an empty entry into the cached defaultdict (which
        # would otherwise show up later, e.g. as stray nodes in to_dot()).
        yield from state.tensor_to_consumers.get(tensor, [])

    def get_tensor_parent(self, tensor) -> Optional["Layer"]:
        '''
        Get the layer that produces this tensor, or None.
        '''
        state = self._get_graph()
        return state.tensor_to_producer.get(tensor, None)

    def mark_removed_layer(self, layer: "Layer"):
        '''
        Record ``layer`` as removed (TRT cannot delete layers) and drop its
        entry from the memo returned by ``FLayerInfoMemo.instance()``.
        '''
        from .graph_rewriting import FLayerInfoMemo
        self._removed_layers.add(layer.name)
        # Try to delete the layer if it is a Plugin
        FLayerInfoMemo.instance().remove(layer.name)

    def is_removed_layer(self, layer: "Layer") -> bool:
        return layer.name in self._removed_layers

    @property
    def removed_layers(self) -> Iterable["Layer"]:
        '''Yield the Layer objects previously passed to mark_removed_layer().'''
        for layer_name in self._removed_layers:
            layer = self.get_layer_by_name(layer_name)
            assert layer, "Invalid layer name"
            yield layer

    def to_dot(self, path=None) -> Optional[str]:
        '''
        Get a graphviz representation of the network.

        NOTE, the graph might be redundant since TRT's INetwork won't clean the unused inputs and layers automatically.
        TODO: add a flag to hide all the removed layers and their output tensors
        TODO: replace this when TensorRT provides a better way to get the graph of INetworkDefinition
        TODO: a little feature, add blocks in the figure to highlight the subgraphs of Modules

        Parameters:
            path: the path to save the graphviz file, if not provided, will return the graphviz source code.
                The text after the last '.' is used as the render format.

        Returns:
            The DOT source when ``path`` is not given; otherwise None (also
            None when graphviz cannot be imported).
        '''
        fmt = 'text' if not path else path.split('.')[-1]

        try:
            import graphviz
        except ImportError:
            logger.error(
                "Failed to import graphviz, please install graphviz to enable Network.to_dot()"
            )
            return

        dot = graphviz.Digraph(comment='TensorRT Graph',
                               format=fmt if fmt != 'text' else None)

        inputs_names = {x.name for x in self.get_inputs()}
        output_names = {x.name for x in self.get_outputs()}

        node_style = dict(
            shape='box',
            style='rounded,filled,bold',
            fontname='Arial',
            fillcolor='#ffffff',
            color='#303A3A',
            width='1.3',
            height='0.84',
        )

        hl_node_style = dict(
            shape='box',
            style='rounded,filled,bold',
            fontname='Arial',
            fillcolor='lightblue',
            color='#303A3A',
            width='1.3',
            height='0.84',
        )

        state = self._get_graph()

        nodes = set()
        tensor_to_alias = {}
        # Single-element list so the nested helper can mutate the counter.
        tensor_id = [0]

        def get_alias(tensor, tensor_id):
            # Internal tensors get short "tN" aliases; network inputs and
            # outputs keep their real names.
            if tensor not in tensor_to_alias:
                if (tensor not in inputs_names) and (tensor
                                                     not in output_names):
                    tensor_to_alias[tensor] = f"t{tensor_id[0]}"
                    tensor_id[0] += 1
                else:
                    tensor_to_alias[tensor] = tensor

            return tensor_to_alias[tensor]

        def create_tensor_node(tensor: str):
            tensor_alias = get_alias(tensor, tensor_id)
            if tensor_alias not in nodes:
                dot.node(tensor_alias, tensor_alias, **node_style)
                nodes.add(tensor_alias)
            return tensor_alias

        def create_layer_node(layer: str):
            if layer not in nodes:
                dot.node(layer, layer, **hl_node_style)
                nodes.add(layer)

        for tensor, layer in state.tensor_to_producer.items():
            tensor_alias = create_tensor_node(tensor.name)
            create_layer_node(layer.name)
            dot.edge(layer.name, tensor_alias)
        for tensor, layers in state.tensor_to_consumers.items():
            tensor_alias = create_tensor_node(tensor.name)
            for layer in layers:
                create_layer_node(layer.name)
                dot.edge(tensor_alias, layer.name)

        if fmt == "text":
            return dot.source
        # NOTE(review): graphviz's render() writes the DOT source to `path`
        # and the rendered output to `path + '.' + format`; confirm this
        # naming is what callers expect.
        dot.render(path)

    def _get_graph(self) -> "Network._GraphState":
        '''
        Get the graph of the network.

        Returns:
            Network._GraphState, rebuilt only when the network hash changes.
        '''
        return self._get_graph_impl(self._get_network_hash())

    # NOTE(review): lru_cache on a method keys on `self`, so the cache (shared
    # by all Network instances, size 1) keeps the last queried network alive.
    @lru_cache(maxsize=1)
    def _get_graph_impl(self,
                        network_hash: Union[int, str]) -> "Network._GraphState":
        graph = Network._GraphState()
        graph.build(self)
        return graph

    @dataclass
    class _GraphState:
        '''Snapshot of producer/consumer relations between layers and tensors.'''
        # Tensor to Layers
        tensor_to_consumers: Dict[Any, List["Layer"]] = field(
            default_factory=lambda: defaultdict(list))
        # Tensor to Layer
        tensor_to_producer: Dict[Any, "Layer"] = field(default_factory=dict)
        # Network inputs/outputs, materialized at build time so they can be
        # iterated more than once.
        inputs: List[Any] = field(default_factory=list)
        outputs: List[Any] = field(default_factory=list)
        name_to_layer: Dict[str, "Layer"] = field(default_factory=dict)

        def build(self, network: "Network") -> None:
            '''Populate this state by walking every layer of ``network``.'''
            from .graph_rewriting import Layer

            # get_outputs() is a one-shot generator and get_inputs() a live
            # view; copy both so the snapshot is stable and re-iterable.
            self.inputs = list(network.get_inputs())
            self.outputs = list(network.get_outputs())

            for layer in network.get_layers():
                self.name_to_layer[layer.name] = Layer(
                    network=network, trt_layer=layer.trt_layer)
                for i in range(layer.num_inputs):
                    input_tensor = layer.get_inputs(i)[0]
                    # Skip inputs that don't wrap a real TRT tensor.
                    if input_tensor.is_trt_wrapper():
                        self.tensor_to_consumers[input_tensor].append(layer)
                for i in range(layer.num_outputs):
                    output_tensor = layer.get_outputs(i)[0]
                    if output_tensor.is_trt_wrapper():
                        self.tensor_to_producer[output_tensor] = layer

    def _get_network_hash(self, lightweight=True) -> Union[int, str]:
        '''
        Return a cache key describing the current graph.

        In lightweight mode, and unless ``is_graph_altered`` is set, this is
        just the layer count (an int). Otherwise it is a SHA-256 hex digest
        (a str) over the IO tensor names, layer names and layer connectivity;
        computing it clears ``is_graph_altered``.
        '''
        # TODO: Ask TensorRT team to add a hash function for INetworkDefinition instead of using this hacky way
        num_layers = self.trt_network.num_layers

        # Some special layers, such as slice, may be associated with tensors that do not have the `trt_tensor` member.
        def get_tensor_tag(tensor):
            return tensor.trt_tensor.name if tensor.is_trt_wrapper(
            ) else 'None'

        if lightweight and not self.is_graph_altered:
            return num_layers
        self.is_graph_altered = False

        data = hashlib.sha256()
        # network layer count
        data.update(str(num_layers).encode())
        # network inputs
        data.update(','.join(
            [get_tensor_tag(tensor) for tensor in self.get_inputs()]).encode())
        # network outputs
        data.update(','.join(
            [get_tensor_tag(tensor)
             for tensor in self.get_outputs()]).encode())
        # layer names
        data.update(','.join(
            [layer.trt_layer.name for layer in self.get_layers()]).encode())

        # layer -> output
        data.update(','.join([
            f'{layer.trt_layer.name}->{get_tensor_tag(tensor)}'
            for layer in self.get_layers() for tensor in layer.get_outputs()
        ]).encode())

        # input -> layer
        data.update(','.join([
            f'{get_tensor_tag(tensor)}->{layer.trt_layer.name}'
            for layer in self.get_layers() for tensor in layer.get_inputs()
        ]).encode())

        return data.hexdigest()


@contextlib.contextmanager
def net_guard(network):
    '''
    Temporarily make ``network`` the default network.

    The previous default is restored on exit, including when the body of
    the ``with`` statement raises.
    '''
    from ._common import net
    assert isinstance(
        network, Network
    ), f"Invalid network, can only guard Network instance, got: {network}"
    old_net = net
    set_network(network)
    try:
        yield
    finally:
        set_network(old_net)


class _TrtLlmModuleCallStack(object):
    '''
    Tracks module names and the stack of modules currently being called.

    NOTE: ``call_stack`` and ``module_name_map`` are class attributes and are
    therefore shared by every instance (instances never rebind them).
    '''
    call_stack = []
    module_name_map = {}

    def __init__(self):
        super().__init__()
        self.mod_names_set = False

    def module_names_set(self):
        return self.mod_names_set

    def set_module_names(self, top_level_module):
        '''
        Record a dotted name for every module under ``top_level_module``.

        Names are prefixed with ``top_level_module._get_name()``; a module
        that already has a recorded name keeps it.
        '''
        assert top_level_module, "Expected a top level module"
        for name, mod in top_level_module.named_modules(
                prefix=top_level_module._get_name()):
            if mod not in self.module_name_map:
                self.module_name_map[mod] = name
        self.mod_names_set = True
        return

    def get_current_module(self):
        '''Return the entry on top of the call stack, or '' when empty.'''
        return self.call_stack[-1] if self.call_stack else ''

    def get_mod_name(self, mod_obj):
        '''Return the recorded name of ``mod_obj``, or '' if unknown.'''
        return self.module_name_map.get(mod_obj, '')

    def get_stack(self):
        return self.call_stack

    @contextlib.contextmanager
    def call_stack_mgr(self):
        '''
        Yield the call stack and pop its top entry on exit.

        The caller is expected to push exactly one entry inside the ``with``
        body; the pop happens even if the body raises.
        '''
        call_stack = self.get_stack()
        try:
            yield call_stack
        finally:
            call_stack.pop()