# tensorrt_llm/auto_parallel/pipeline_graph.py

from dataclasses import dataclass
from typing import Dict, List, Optional
import numpy as np
import tensorrt as trt
import torch
from tensorrt_llm._utils import trt_dtype_to_str, trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.network import Network, get_plugin_info, set_plugin_info
from tensorrt_llm.plugin.plugin import PluginConfig
from tensorrt_llm.runtime.session import Session
from .utils import (current_flags, get_builder_flags, get_sorted_layer_ids,
get_strongly_typed, get_trt_network, set_trt_network,
to_base_class_layer, to_subclass_layer)
class Tensor:
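    '''Wrapper around a trt.ITensor that tracks its position in a PipelineGraph.

    Besides the underlying TensorRT tensor, it records the producer layer,
    consumer layers, graph input/output indices, and optional shape/value
    information used during shape inference.
    '''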
def __init__(self, graph: "PipelineGraph"):
self._graph = graph
self._trt = None
self._shape = None
self._max_shape = None
self._value = None
self.producer: Layer = None
self.output_index = None
self.consumers = []
self.graph_input_index = -1
self.graph_output_index = -1
self.attrs = {}
@staticmethod
def from_trt(graph: "PipelineGraph", trt_tensor: trt.ITensor):
tensor = Tensor(graph)
tensor._trt = trt_tensor
return tensor
def as_trt(self) -> trt.ITensor:
return self._trt
def copy(self) -> "Tensor":
tensor = Tensor(self._graph)
tensor._trt = self._trt
tensor._shape = self._shape
tensor._max_shape = self._max_shape
tensor._value = self._value
tensor.producer = self.producer
tensor.output_index = self.output_index
tensor.consumers = [*self.consumers]
tensor.graph_input_index = self.graph_input_index
tensor.graph_output_index = self.graph_output_index
tensor.attrs = self.attrs.copy()
return tensor
@property
def graph(self) -> "PipelineGraph":
return self._graph
@property
def name(self) -> str:
return self._trt.name
@name.setter
def name(self, name: str):
old_name = self._trt.name
if name != old_name:
self._trt.name = name
self.graph._tensors[name] = self
del self.graph._tensors[old_name]
if self.is_graph_input:
self.graph._inputs[name] = self
del self.graph._inputs[old_name]
elif self.is_graph_output:
self.graph._outputs[name] = self
del self.graph._outputs[old_name]
@property
def shape(self):
return self._shape
@property
def max_shape(self):
return self._max_shape
@property
def raw_shape(self):
assert isinstance(self._trt, trt.ITensor)
return self._trt.shape
@shape.setter
def shape(self, shape):
self._shape = shape
@max_shape.setter
def max_shape(self, max_shape):
self._max_shape = max_shape
@raw_shape.setter
def raw_shape(self, raw_shape):
assert isinstance(self._trt, trt.ITensor)
self._trt.shape = raw_shape
@property
def value(self):
return self._value
@value.setter
def value(self, value):
self._value = value
@property
def dtype(self):
return self._trt.dtype
@property
def broadcast_across_batch(self):
return self._trt.broadcast_across_batch
@property
def dtype_size(self):
return self.dtype.itemsize
@property
def dtype_str(self):
return trt_dtype_to_str(self.dtype)
@property
def dtype_str_size(self):
return [trt_dtype_to_str(self.dtype), self.dtype.itemsize]
@property
def is_graph_input(self) -> bool:
return self.graph_input_index != -1
@property
def is_graph_output(self) -> bool:
return self.graph_output_index != -1
@property
def is_graph_io(self) -> bool:
return self.is_graph_input or self.is_graph_output
class Layer:
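    '''Wrapper around a trt.ILayer that tracks its inputs, outputs, and index
    within a PipelineGraph.'''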
def __init__(self, graph):
self._graph = graph
self._trt = None
self._index = None
self._inputs = []
self._outputs = []
self._is_shape_io = False
self.attrs = {}
@staticmethod
def from_trt(graph, trt_layer, index):
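        '''Wrap an existing trt.ILayer and wire up producer/consumer links.

        The graph is expected to already contain wrapped tensors for all of
        the layer's inputs and outputs (looked up by name).
        '''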
layer = Layer(graph)
layer._trt = trt_layer
layer._index = index
for i in range(trt_layer.num_inputs):
input = trt_layer.get_input(i)
if input is not None:
layer._inputs.append(graph.get_tensor(input.name))
layer._inputs[i].consumers.append((layer, i))
else:
layer._inputs.append(None)
for i in range(trt_layer.num_outputs):
output = trt_layer.get_output(i)
layer._outputs.append(graph.get_tensor(output.name))
layer._outputs[i].producer = layer
layer._outputs[i].output_index = i
set_trt_network(trt_layer, graph.as_trt())
return layer
def as_trt(self) -> trt.ILayer:
return self._trt
@property
def graph(self) -> "PipelineGraph":
return self._graph
@property
def name(self) -> str:
return self._trt.name
@name.setter
def name(self, name: str):
old_name = self._trt.name
if name != old_name:
self._trt.name = name
self.graph._layers[name] = self
del self.graph._layers[old_name]
@property
def type(self) -> trt.LayerType:
return self._trt.type
@property
def index(self) -> int:
return self._index
@property
def inputs(self) -> List[Tensor]:
return self._inputs
@property
def outputs(self) -> List[Tensor]:
return self._outputs
def get_input(self, index: int) -> Tensor:
return self._inputs[index]
def get_output(self, index: int) -> Tensor:
return self._outputs[index]
@property
def num_inputs(self) -> int:
return self._trt.num_inputs
@property
def num_outputs(self) -> int:
return self._trt.num_outputs
@property
def is_shape_io(self) -> bool:
return self._is_shape_io
def to_subclass(self):
to_subclass_layer(self._trt)
def to_base_class(self):
to_base_class_layer(self._trt)
def assign_shapes(self, shapes, values):
for output in self.outputs:
output.shape = shapes[output.name]
output.value = values.get(output.name)
@dataclass
class GraphRunner:
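    '''Bundles a built TensorRT session with pre-allocated input/output
    buffers and the CUDA stream used to execute it.'''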
session: Session
inputs: Dict[str, torch.Tensor]
outputs: Dict[str, torch.Tensor]
stream: torch.Stream
def run(self):
cuda_stream = self.stream.cuda_stream
assert self.session.run(self.inputs, self.outputs, cuda_stream)
self.stream.synchronize()
return self.outputs
class PipelineGraph:
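    '''A mutable wrapper around a trt.INetworkDefinition.

    It keeps name-indexed wrappers for every tensor and layer so that the
    network can be inspected, duplicated, partially rebuilt, profiled, and
    rendered to graphviz by the auto-parallel passes.

    Illustrative usage sketch (assuming ``network`` is a tensorrt_llm Network
    and ``builder_config`` its builder config):

        graph = PipelineGraph.from_network(network, builder_config)
        copy = graph.duplicate_graph()
        dot_source = graph.to_dot()
    '''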
def __init__(self):
self._trt = None
self._inputs: Dict[str, Tensor] = {}
self._outputs: Dict[str, Tensor] = {}
self._layers: Dict[str, Layer] = {}
self._tensors: Dict[str, Tensor] = {}
self._io_buffer_mapping = {}
self._unfilled_weights = {}
self._auto_parallel_config = None
self._plugin_config: PluginConfig = None
@staticmethod
def create_graph():
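        '''Create an empty PipelineGraph backed by a fresh TensorRT network.

        The network is created with the explicit-batch flag (when the
        installed TensorRT version still exposes it) and, if requested by
        get_strongly_typed(), the strongly-typed flag.
        '''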
graph = PipelineGraph()
trt_builder = trt.Builder(logger.trt_logger)
explicit_batch_flag = 0
# Explicit batch flag will be deprecated in TRT 10
if "EXPLICIT_BATCH" in trt.NetworkDefinitionCreationFlag.__members__.keys(
):
explicit_batch_flag = 1 << int(
trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
if get_strongly_typed():
network = trt_builder.create_network(
explicit_batch_flag
| (1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)))
else:
network = trt_builder.create_network(explicit_batch_flag)
graph._trt = network
return graph
def _register_unfilled_weights(self, layer_name, weights, values):
self._unfilled_weights[layer_name] = (weights, values)
def _add_tensor(self, tensor, old_tensor, prefix):
if prefix is not None:
tensor.name = prefix + old_tensor.name
else:
tensor.name = old_tensor.name
tensor.location = old_tensor.location
if old_tensor.dynamic_range is not None:
tensor.dynamic_range = old_tensor.dynamic_range
if tensor.is_network_input:
tensor.shape = old_tensor.shape
for i in range(len(old_tensor.shape)):
name = old_tensor.get_dimension_name(i)
if name is not None:
tensor.set_dimension_name(i, name)
return self._register_tensor(tensor)
def _register_tensor(self, tensor):
wrapped_tensor = Tensor.from_trt(self, tensor)
assert tensor.name not in self._tensors
self._tensors[tensor.name] = wrapped_tensor
return wrapped_tensor
def add_input(self, tensor, prefix=None):
tensor_name = tensor.name
if prefix is not None:
tensor_name = prefix + tensor_name
input = self._trt.add_input(tensor_name, tensor.dtype, tensor.shape)
new_tensor = self._add_tensor(input, tensor, prefix)
new_tensor.graph_input_index = len(self._inputs)
self._inputs[tensor_name] = new_tensor
return new_tensor
def register_input(self, tensor, index=None):
if index is None:
index = self.num_inputs - 1
assert self._trt.get_input(index).name == tensor.name
wrapped_input = self._register_tensor(tensor)
wrapped_input.graph_input_index = index
self._inputs[tensor.name] = wrapped_input
return wrapped_input
def add_output(self, tensor, prefix=None):
tensor_name = tensor.name
if prefix is not None:
tensor_name = prefix + tensor_name
output = self.get_tensor(tensor_name)
output.graph_output_index = len(self._outputs)
trt_output = output.as_trt()
self._trt.mark_output(trt_output)
trt_output.dtype = tensor.dtype
self._outputs[tensor_name] = output
return output
def add_output_shape(self, tensor, prefix=None):
tensor_name = tensor.name
if prefix is not None:
tensor_name = prefix + tensor_name
output = self.get_tensor(tensor_name)
trt_output = output.as_trt()
self._trt.mark_output_for_shapes(trt_output)
trt_output.dtype = tensor.dtype
self._outputs[tensor_name] = output
return output
def add_layer(
self,
layer,
input_mapping=None,
prefix=None,
updated_attrs=None,
) -> Layer:
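        '''Clone a layer from another network into this graph.

        Input tensors are resolved by name, optionally rewritten with
        `prefix` and `input_mapping`; `updated_attrs` can override attributes
        on the new layer (including the plugin instance for PLUGIN_V2
        layers). Returns the wrapped new layer.
        '''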
def get_input(i):
name = layer.get_input(i).name
if prefix is not None:
name = prefix + name
if input_mapping is not None and name in input_mapping:
name = input_mapping[name]
return self.get_tensor(name).as_trt()
network = self._trt
layer_type = layer.type
to_subclass_layer(layer)
if layer_type == trt.LayerType.ACTIVATION:
trt_input = get_input(0)
new_layer = network.add_activation(trt_input, layer.type)
new_layer.alpha = layer.alpha
new_layer.beta = layer.beta
elif layer_type == trt.LayerType.CONCATENATION:
trt_inputs = [get_input(i) for i in range(layer.num_inputs)]
new_layer = network.add_concatenation(trt_inputs)
new_layer.axis = layer.axis
elif layer_type == trt.LayerType.CONSTANT:
new_layer = network.add_constant(layer.shape, layer.weights)
elif layer_type == trt.LayerType.ELEMENTWISE:
new_layer = network.add_elementwise(get_input(0), get_input(1),
layer.op)
elif layer_type == trt.LayerType.FILL:
if layer.num_inputs >= 1 and layer.get_input(0) is not None:
shape_input = get_input(0)
shape = [1]
else:
shape_input = None
shape = layer.shape
new_layer = network.add_fill(shape, layer.operation, layer.to_type)
if shape_input is not None:
new_layer.set_input(0, shape_input)
if layer.num_inputs >= 1 and layer.get_input(0) is not None:
new_layer.set_input(0, get_input(0))
if layer.num_inputs >= 2 and layer.get_input(1) is not None:
new_layer.set_input(1, get_input(1))
else:
new_layer.alpha = layer.alpha
if layer.num_inputs >= 3 and layer.get_input(2) is not None:
new_layer.set_input(2, get_input(2))
else:
new_layer.beta = layer.beta
elif layer_type == trt.LayerType.GATHER:
trt_input = get_input(0)
trt_indices = get_input(1)
new_layer = network.add_gather_v2(trt_input, trt_indices,
layer.mode)
new_layer.axis = layer.axis
new_layer.num_elementwise_dims = layer.num_elementwise_dims
new_layer.mode = layer.mode
elif layer_type == trt.LayerType.MATRIX_MULTIPLY:
new_layer = network.add_matrix_multiply(get_input(0), layer.op0,
get_input(1), layer.op1)
elif layer_type == trt.LayerType.REDUCE:
new_layer = network.add_reduce(get_input(0), layer.op, layer.axes,
layer.keep_dims)
elif layer_type == trt.LayerType.SELECT:
trt_condition = get_input(0)
trt_then = get_input(1)
trt_else = get_input(2)
new_layer = network.add_select(trt_condition, trt_then, trt_else)
elif layer_type == trt.LayerType.SHUFFLE:
new_layer = network.add_shuffle(get_input(0))
new_layer.first_transpose = layer.first_transpose
new_layer.second_transpose = layer.second_transpose
new_layer.zero_is_placeholder = layer.zero_is_placeholder
if layer.num_inputs >= 2:
trt_reshape_dims_tensor = get_input(1)
new_layer.set_input(1, trt_reshape_dims_tensor)
else:
new_layer.reshape_dims = layer.reshape_dims
elif layer_type == trt.LayerType.SLICE:
if layer.num_inputs >= 2 and layer.get_input(1) is not None:
trt_start = get_input(1)
start = []
else:
trt_start = None
start = layer.start
if layer.num_inputs >= 3 and layer.get_input(2) is not None:
trt_shape = get_input(2)
shape = []
else:
trt_shape = None
shape = layer.shape
if layer.num_inputs >= 4 and layer.get_input(3) is not None:
trt_stride = get_input(3)
stride = []
else:
trt_stride = None
stride = layer.stride
new_layer = network.add_slice(get_input(0), start, shape, stride)
new_layer.mode = layer.mode
if trt_start is not None:
new_layer.set_input(1, trt_start)
if trt_shape is not None:
new_layer.set_input(2, trt_shape)
if trt_stride is not None:
new_layer.set_input(3, trt_stride)
elif layer_type == trt.LayerType.SOFTMAX:
new_layer = network.add_softmax(get_input(0))
new_layer.axes = layer.axes
elif layer_type == trt.LayerType.UNARY:
new_layer = network.add_unary(get_input(0), layer.op)
elif layer_type == trt.LayerType.SHAPE:
new_layer = network.add_shape(get_input(0))
elif layer_type == trt.LayerType.ASSERTION:
new_layer = network.add_assertion(get_input(0), layer.message)
elif layer_type == trt.LayerType.CAST:
new_layer = network.add_cast(get_input(0), layer.to_type)
elif layer_type == trt.LayerType.NORMALIZATION:
trt_input = get_input(0)
trt_scale = get_input(1)
trt_bias = get_input(2)
new_layer = network.add_normalization(trt_input, trt_scale,
trt_bias, layer.axes)
new_layer.epsilon = layer.epsilon
new_layer.num_groups = layer.num_groups
new_layer.compute_precision = layer.compute_precision
elif layer_type == trt.LayerType.IDENTITY:
new_layer = network.add_identity(get_input(0))
elif layer_type == trt.LayerType.PLUGIN_V2:
plugin = layer.plugin
updated = False
if (updated_attrs is not None
and updated_attrs.get("plugin") is not None):
plugin = updated_attrs["plugin"]
updated = True
updated_attrs = None
new_layer = network.add_plugin_v2(
[get_input(i) for i in range(layer.num_inputs)],
plugin,
)
else:
raise NotImplementedError(
"Unsupported layer type: {}".format(layer_type))
if updated_attrs is not None:
for attr_name, attr_value in updated_attrs.items():
setattr(new_layer, attr_name, attr_value)
to_base_class_layer(layer)
to_base_class_layer(new_layer)
layer_index = network.num_layers - 1
layer_name = layer.name
if prefix is not None:
layer_name = prefix + layer_name
new_layer.name = layer_name
new_layer.metadata = new_layer.name
if layer.precision_is_set:
new_layer.precision = layer.precision
for i in range(layer.num_outputs):
if layer.output_type_is_set(i):
new_layer.set_output_type(i, layer.get_output_type(i))
output = new_layer.get_output(i)
self._add_tensor(output, layer.get_output(i), prefix)
wrapped_layer = Layer.from_trt(self, new_layer, layer_index)
assert layer_name not in self._layers
self._layers[layer_name] = wrapped_layer
if layer_type == trt.LayerType.PLUGIN_V2:
if not updated:
plugin_info = get_plugin_info(get_trt_network(layer),
layer.name)
set_plugin_info(self.as_trt(), new_layer.name, plugin_info)
return wrapped_layer
def register_layer(self, layer, index=None):
if index is None:
index = self.num_layers - 1
assert self._trt.get_layer(index).name == layer.name
to_base_class_layer(layer)
for i in range(layer.num_outputs):
output = layer.get_output(i)
self._register_tensor(output)
wrapped_layer = Layer.from_trt(self, layer, index)
assert layer.name not in self._layers
self._layers[layer.name] = wrapped_layer
to_subclass_layer(layer)
return wrapped_layer
def get_runner(
self,
shapes=None,
values=None,
profile=None,
timing_cache=None,
opt_level=None,
) -> GraphRunner:
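        '''Build an engine for this graph and prepare a GraphRunner.

        Input/output buffers are filled from `values` (or the tensors'
        recorded values) when available, otherwise allocated on the current
        CUDA device from `shapes` (or the tensors' recorded shapes), with
        floating-point inputs randomized.

        Illustrative usage sketch (names are placeholders):

            runner = graph.get_runner(shapes=my_shapes)
            outputs = runner.run()
        '''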
shapes = shapes or {}
values = values or {}
inputs = {}
outputs = {}
for input in self.inputs:
if input is not None:
value = values.get(input.name)
if value is None:
value = input.value
if value is not None:
if not isinstance(value, torch.Tensor):
value = torch.tensor(
value,
dtype=trt_dtype_to_torch(input.dtype),
device='cpu',
)
inputs[input.name] = value
else:
shape = shapes.get(input.name)
if shape is None:
shape = input.shape
assert shape is not None
inputs[input.name] = torch.empty(
tuple(shape),
dtype=trt_dtype_to_torch(input.dtype),
device=torch.cuda.current_device(),
)
if torch.is_floating_point(inputs[input.name]):
inputs[input.name].normal_()
# inputs[input.name][:] = random.choice([2, 3, 5, 7])
for output in self.outputs:
if output.as_trt().is_shape_tensor:
continue
if output.name in self._io_buffer_mapping:
input_name = self._io_buffer_mapping[output.name]
if input_name in inputs:
outputs[output.name] = inputs[input_name]
continue
value = values.get(output.name)
if value is not None and isinstance(value, torch.Tensor):
outputs[output.name] = value
else:
shape = shapes.get(output.name)
if shape is None:
shape = output.shape
assert shape is not None
outputs[output.name] = torch.empty(
tuple(shape),
dtype=trt_dtype_to_torch(output.dtype),
device=torch.cuda.current_device(),
)
network = self.as_trt()
config = network.builder.create_builder_config()
if opt_level is not None:
config.builder_optimization_level = opt_level
config.flags = get_builder_flags()
profile = profile or network.builder.create_optimization_profile()
profile_index = config.add_optimization_profile(profile)
if timing_cache is not None:
config.set_timing_cache(timing_cache, ignore_mismatch=False)
plan = network.builder.build_serialized_network(network, config)
if plan is None:
logger.error('Engine building failed, please check the error log.')
session = Session.from_serialized_engine(plan)
stream = torch.cuda.current_stream()
cuda_stream = stream.cuda_stream
context = session.context
context.set_optimization_profile_async(profile_index, cuda_stream)
runner = GraphRunner(session, inputs, outputs, stream)
return runner
def run(
self,
shapes=None,
values=None,
profile=None,
timing_cache=None,
opt_level=None,
):
return self.get_runner(
shapes,
values,
profile,
timing_cache,
opt_level,
).run()
def duplicate_graph(self):
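        '''Create a new PipelineGraph containing a copy of every input, layer
        (added in topologically sorted order), and output of this network.'''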
graph = PipelineGraph.create_graph()
network = self.as_trt()
for i in range(network.num_inputs):
input = network.get_input(i)
graph.add_input(input)
sorted_layer_ids = get_sorted_layer_ids(network)
for i in sorted_layer_ids:
layer = network.get_layer(i)
graph.add_layer(layer)
for i in range(network.num_outputs):
output = network.get_output(i)
if output.is_shape_tensor:
graph.add_output_shape(output)
else:
graph.add_output(output)
return graph
@staticmethod
def from_trt(trt_network):
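        '''Wrap an existing trt.INetworkDefinition without copying it,
        building Tensor/Layer wrappers for all of its inputs, layers,
        and outputs.'''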
graph = PipelineGraph()
graph._trt = trt_network
# construct inputs and tensors
for i in range(trt_network.num_inputs):
trt_input = trt_network.get_input(i)
tensor = Tensor.from_trt(graph, trt_input)
tensor.graph_input_index = i
graph._tensors[tensor.name] = tensor
graph._inputs[tensor.name] = tensor
for i in range(trt_network.num_layers):
trt_layer = trt_network.get_layer(i)
for i in range(trt_layer.num_outputs):
trt_output = trt_layer.get_output(i)
tensor = Tensor.from_trt(graph, trt_output)
graph._tensors[tensor.name] = tensor
# construct layers and outputs
for i in range(trt_network.num_layers):
layer = Layer.from_trt(graph, trt_network.get_layer(i), i)
graph._layers[layer.name] = layer
for i in range(trt_network.num_outputs):
tensor_name = trt_network.get_output(i).name
output_tensor = graph._tensors[tensor_name]
output_tensor.graph_output_index = i
graph._outputs[tensor_name] = output_tensor
return graph
@staticmethod
def from_network(network: Network, builder_config):
builder_flags = builder_config.trt_builder_config.flags
with current_flags(builder_flags, network.strongly_typed):
graph = PipelineGraph.from_trt(network.trt_network)
graph.infer_shapes(network._generate_optimization_profiles()[-1])
return graph
def assign_shapes(self, shape_info=None, is_partial=False):
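        '''Attach shapes, max shapes, and values from `shape_info` to the
        wrapped tensors, and mark shape-IO layers.

        With no `shape_info`, each tensor simply takes its raw TensorRT
        shape. Unless `is_partial` is set, every tensor must be covered.
        '''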
if shape_info is None:
for tensor in self.tensors:
tensor.shape = tensor.raw_shape
return
for tensor in self.tensors:
if tensor.name in shape_info.shapes:
tensor.shape = shape_info.shapes[tensor.name]
elif not is_partial:
raise ValueError(f"Cannot find shape for tensor: {tensor.name}")
if shape_info.max_shapes is not None:
if tensor.name in shape_info.max_shapes:
tensor.max_shape = shape_info.max_shapes[tensor.name]
elif not is_partial:
raise ValueError(
f"Cannot find max shape for tensor: {tensor.name}")
if tensor.name in shape_info.values:
tensor.value = shape_info.values[tensor.name]
for layer in self.layers:
if layer.name in shape_info.shape_layers:
layer._is_shape_io = True
def infer_shapes(self, profile=None):
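        '''Run shape inference over the network (optionally under an
        optimization profile) and assign the results to the graph.'''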
from .shape_info import get_shape_info
shape_info = get_shape_info(self._trt, profile)
self.assign_shapes(shape_info)
def as_trt(self) -> trt.INetworkDefinition:
return self._trt
def get_input(self, name: str) -> Tensor:
return self._inputs.get(name)
def is_input(self, name: str) -> bool:
return name in self._inputs
@property
def inputs(self) -> List[Tensor]:
return [*self._inputs.values()]
@property
def num_inputs(self) -> int:
return self._trt.num_inputs
def get_output(self, name: str) -> Tensor:
return self._outputs.get(name)
def is_output(self, name: str) -> bool:
return name in self._outputs
@property
def outputs(self) -> List[Tensor]:
return [*self._outputs.values()]
@property
def num_outputs(self) -> int:
return self._trt.num_outputs
def get_tensor(self, name: str) -> Tensor:
return self._tensors.get(name)
@property
def tensors(self) -> List[Tensor]:
return [*self._tensors.values()]
def get_layer(self, name: str) -> Layer:
return self._layers.get(name)
@property
def layers(self) -> List[Layer]:
return [*self._layers.values()]
@property
def sorted_layers(self) -> List[Layer]:
sorted_layer_ids = get_sorted_layer_ids(self.as_trt())
return [
self.get_layer(self.as_trt().get_layer(layer_id).name)
for layer_id in sorted_layer_ids
]
@property
def num_layers(self) -> int:
return self._trt.num_layers
def to_dot(self,
path=None,
per_device=False,
per_block=False,
ignore_shape_io=False,
no_style=False,
extra_attrs=None) -> Optional[str]:
        '''
        Get a graphviz representation of the graph.

        Parameters:
            path: path to save the graphviz file; if not provided, the
                graphviz source code is returned instead.
            per_device: group layers whose names start with "device" into
                one cluster per device.
            per_block: group layers that carry a "block_id" attribute into
                one cluster per block.
            ignore_shape_io: skip shape-IO layers.
            no_style: emit plain nodes and clusters without styling.
            extra_attrs: names of tensor/layer attributes to include in the
                node labels.
        '''
try:
import graphviz
except ImportError:
logger.error(
"Failed to import graphviz, please install graphviz to enable PipelineGraph.to_dot()"
)
return
extra_attrs = extra_attrs or []
graph = graphviz.Digraph()
input_block_graph = graphviz.Digraph(name='cluster_inputs')
output_block_graph = graphviz.Digraph(name='cluster_outputs')
device_graphs = {}
block_graphs = {}
block_graph_mapping = []
tensor_names = set()
layer_names = set()
common_style = dict(fontname='Arial', )
node_style = dict(
**common_style,
style='rounded,filled,bold',
)
tensor_style = dict(
**node_style,
shape='ellipse',
fillcolor='white',
)
input_tensor_style = {**tensor_style, 'fillcolor': 'green'}
output_tensor_style = {**tensor_style, 'fillcolor': 'lightgreen'}
layer_style = dict(
**node_style,
shape='box',
fillcolor='white',
)
shape_layer_style = {**layer_style, 'fillcolor': 'grey'}
helper_layer_style = {**layer_style, 'fillcolor': 'lightgrey'}
graph_style = dict(
**common_style,
style='rounded',
penwidth='5',
fontsize='28',
)
device_graph_style = dict(
**graph_style,
color='cornflowerblue',
)
block_graph_style = dict(
**graph_style,
color='darkcyan',
)
input_block_style = dict(
**graph_style,
color='green',
)
output_block_style = dict(
**graph_style,
color='lightgreen',
)
if no_style:
device_graph_style = {}
block_graph_style = {}
input_block_style = {}
output_block_style = {}
input_block_graph.attr(label='inputs', **input_block_style)
output_block_graph.attr(label='outputs', **output_block_style)
def get_tensor_labels(tensor):
labels = []
if tensor.value is not None:
labels.append(f"value={tensor.value}")
else:
labels.append(f"dtype={tensor.dtype.name}{tensor.shape}")
for attr_name in extra_attrs:
if attr_name in tensor.attrs:
labels.append(f"{attr_name}={tensor.attrs[attr_name]}")
return labels
def get_device_graph(name):
if per_device and name.startswith('device'):
device_name = name.split('_')[0]
if device_name not in device_graphs:
device_graph = graphviz.Digraph(name='cluster_' +
device_name)
device_graph.attr(label=device_name, **device_graph_style)
device_graphs[device_name] = device_graph
return device_graphs[device_name]
return None
def get_block_graph(layer, current_graph):
if per_block and 'block_id' in layer.attrs:
block_label = f"block{layer.attrs['block_id']}"
if current_graph.name is not None:
graph_label = current_graph.name[len('cluster_'):]
else:
graph_label = ''
block_name = f"{graph_label}{block_label}"
if block_name not in block_graphs:
block_graph = graphviz.Digraph(name='cluster_' + block_name)
block_graph.attr(label=block_label, **block_graph_style)
block_graphs[block_name] = block_graph
block_graph_mapping.append((current_graph, block_graph))
return block_graphs[block_name]
return current_graph
for name, tensor in self._tensors.items():
style = tensor_style
if tensor.is_graph_input:
style = input_tensor_style
current_graph = input_block_graph
elif tensor.is_graph_output:
style = output_tensor_style
current_graph = output_block_graph
elif tensor.producer.num_outputs == 1:
continue
else:
current_graph = get_device_graph(name) or graph
current_graph = get_block_graph(tensor.producer, current_graph)
if no_style:
style = {}
labels = [name, *get_tensor_labels(tensor)]
content = "\n".join(labels)
current_graph.node(name, content, **style)
tensor_names.add(name)
for layer in self.sorted_layers:
name = layer.name
style = layer_style
if layer.is_shape_io:
if ignore_shape_io:
continue
style = shape_layer_style
elif layer.attrs.get("role", None) == "helper":
style = helper_layer_style
fillcolor = None
plugin_type = None
if layer.type == trt.LayerType.PLUGIN_V2:
fillcolor = 'yellow'
layer.to_subclass()
plugin_type = layer.as_trt().plugin.plugin_type
layer.to_base_class()
if layer.type == trt.LayerType.MATRIX_MULTIPLY or plugin_type == 'Gemm':
fillcolor = 'orange'
if fillcolor is not None:
style = {**style, 'fillcolor': fillcolor}
if no_style:
style = {}
layer_attrs = {}
layer_type = layer.type
layer.to_subclass()
if layer_type == trt.LayerType.CONSTANT:
if not layer.is_shape_io:
if trt.volume(layer.get_output(0).shape) <= 8:
weights = layer.as_trt().weights
if isinstance(weights, trt.Weights):
weights = weights.numpy()
value = np.array2string(
weights,
formatter={'float_kind': lambda x: f"{x:.2e}"})
layer_attrs['value'] = value
elif layer_type == trt.LayerType.SHUFFLE:
for attr_name in ['first_transpose', 'second_transpose']:
attr_value = getattr(layer.as_trt(), attr_name)
if tuple(attr_value) != (0, 1, 2, 3, 4, 5, 6, 7):
tensor = layer.get_input(
0
) if attr_name == 'first_transpose' else layer.get_output(
0)
layer_attrs[attr_name] = tuple(
attr_value)[:len(tensor.shape)]
if layer.num_inputs < 2:
attr_value = layer.as_trt().reshape_dims
layer_attrs['reshape_dims'] = attr_value
elif layer_type == trt.LayerType.SLICE:
if layer.num_inputs < 2 or layer.get_input(1) is None:
layer_attrs['start'] = layer.as_trt().start
if layer.num_inputs < 4 or layer.get_input(3) is None:
attr_value = layer.as_trt().stride
if attr_value != tuple(
[1] * len(layer.get_output(0).shape)):
layer_attrs['stride'] = attr_value
layer.to_base_class()
if layer.is_shape_io:
labels = [layer.type.name]
else:
labels = [name, layer.type.name]
for key, value in layer_attrs.items():
labels.append(f"{key}={value}")
for attr_name in extra_attrs:
if attr_name in layer.attrs:
labels.append(f"{attr_name}={layer.attrs[attr_name]}")
if layer.num_outputs == 1:
output = layer.get_output(0)
if output.name != f'{layer.name}_output_0':
labels.append(f"output={output.name}")
labels.extend(get_tensor_labels(output))
content = "\n".join(labels)
current_graph = get_device_graph(name) or graph
current_graph = get_block_graph(layer, current_graph)
current_graph.node(name, content, **style)
layer_names.add(name)
for index, input in enumerate(layer.inputs):
if input is not None:
if input.is_graph_input or input.producer.num_outputs > 1:
if input.name in tensor_names:
graph.edge(input.name, name, str(index))
else:
if input.producer.name in layer_names:
graph.edge(input.producer.name, name, str(index))
if layer.num_outputs > 1 or (layer.num_outputs == 1 and
layer.get_output(0).is_graph_output):
for index, output in enumerate(layer.outputs):
graph.edge(name, output.name, str(index))
graph.subgraph(input_block_graph)
graph.subgraph(output_block_graph)
for parent_graph, block_graph in block_graph_mapping:
parent_graph.subgraph(block_graph)
for device_graph in device_graphs.values():
graph.subgraph(device_graph)
if not path:
return graph.source
graph.save(path)
@staticmethod
def trt_to_dot(trt_network, path=None):
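        '''Convenience helper: wrap a raw TensorRT network and dump an
        unstyled graphviz representation, returning the source when no
        path is given.'''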
graph = PipelineGraph.from_trt(trt_network)
graph.assign_shapes()
dot = graph.to_dot(no_style=True)
if path is not None:
with open(path, "w") as f:
f.write(dot)
else:
return dot