TensorRT-LLMs/tensorrt_llm/auto_parallel/node_graph.py
Kaiyu Xie 4bb65f216f
Update TensorRT-LLM (#1274)
* Update TensorRT-LLM

---------

Co-authored-by: meghagarwal <16129366+megha95@users.noreply.github.com>
Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
2024-03-12 18:15:52 +08:00

348 lines
16 KiB
Python

from typing import List
import pandas as pd
import tensorrt as trt
from .pipeline_graph import PipelineGraph
from .runtime_profiling import RuntimeProfiler
from .simplifier import GraphConfig, StageType
from .solver import CostGraph, Solver
from .tensor_parallel.activation_node import Activation
from .tensor_parallel.assertion_node import Assertion
from .tensor_parallel.cast_node import Cast
from .tensor_parallel.concatenation_node import Concatenation
from .tensor_parallel.constant_node import Constant
from .tensor_parallel.elementwise_node import ElementWise
from .tensor_parallel.fill_node import Fill
from .tensor_parallel.gather_node import Gather
from .tensor_parallel.identity_node import Identity
from .tensor_parallel.input_node import InputNode
from .tensor_parallel.matmul_node import MatrixMultiply
from .tensor_parallel.node import Node
from .tensor_parallel.normalization_node import Normalization
from .tensor_parallel.output_node import OuputNode
from .tensor_parallel.p2p_node import P2PNode, P2PType
from .tensor_parallel.plugin_node import PluginNode
from .tensor_parallel.plugin_nodes.gemm_node import GemmPlugin
from .tensor_parallel.plugin_nodes.gpt_attention_node import GPTAttentionPlugin
from .tensor_parallel.plugin_nodes.identity_node import IdentityPlugin
from .tensor_parallel.plugin_nodes.look_up_node import LookupPlugin
from .tensor_parallel.plugin_nodes.normalization_node import (LayernormPlugin,
RMSnormPlugin)
from .tensor_parallel.reduce_node import Reduce
from .tensor_parallel.select_node import Select
from .tensor_parallel.shape_node import Shape
from .tensor_parallel.shuffle_node import Shuffle
from .tensor_parallel.slice_node import Slice
from .tensor_parallel.softmax_node import SoftMax
from .tensor_parallel.unary_node import Unary
# Dispatch table: TensorRT layer type -> auto-parallel node class wrapping it.
# PLUGIN_V2 maps to the generic PluginNode; NodeGraph first consults
# PLUGIN_LAYER_TYPE_2_NODE_TYPE for plugins that have a dedicated node class.
LAYER_TYPE_2_NODE_TYPE = {
    trt.LayerType.ACTIVATION: Activation,
    trt.LayerType.ASSERTION: Assertion,
    trt.LayerType.CAST: Cast,
    trt.LayerType.CONCATENATION: Concatenation,
    trt.LayerType.CONSTANT: Constant,
    trt.LayerType.ELEMENTWISE: ElementWise,
    trt.LayerType.FILL: Fill,
    trt.LayerType.GATHER: Gather,
    trt.LayerType.IDENTITY: Identity,
    trt.LayerType.MATRIX_MULTIPLY: MatrixMultiply,
    trt.LayerType.NORMALIZATION: Normalization,
    trt.LayerType.PLUGIN_V2: PluginNode,  # generic fallback for plugins
    trt.LayerType.REDUCE: Reduce,
    trt.LayerType.SELECT: Select,
    trt.LayerType.SHAPE: Shape,
    trt.LayerType.SHUFFLE: Shuffle,
    trt.LayerType.SLICE: Slice,
    trt.LayerType.SOFTMAX: SoftMax,
    trt.LayerType.UNARY: Unary,
}

# Dispatch table: plugin ``plugin_type`` string -> dedicated plugin node class.
# TODO: BertAttention/All Quant plugins
PLUGIN_LAYER_TYPE_2_NODE_TYPE = dict(
    GPTAttention=GPTAttentionPlugin,
    Gemm=GemmPlugin,
    Layernorm=LayernormPlugin,
    Rmsnorm=RMSnormPlugin,
    Lookup=LookupPlugin,
    Identity=IdentityPlugin,
)
class NodeGraph:
    """Auto-parallel view of a ``PipelineGraph``: one ``Node`` per input/layer/output.

    Nodes are kept in a dict keyed by the graph input, layer or graph output
    name, in construction order (inputs, then layers, then outputs). The
    graph builds a sharding ``CostGraph`` (``get_cost_graph``), solves it
    (``find_solution``) and can dump the graph or the solved plan for
    inspection (``visualize`` / ``visualize_solution``).
    """

    def __init__(self, graph: PipelineGraph):
        """Create a node for every input, layer and output of ``graph``.

        Args:
            graph: Pipeline graph whose inputs, layers and outputs are wrapped.

        Raises:
            KeyError: If a non-P2P layer's type has no entry in
                ``LAYER_TYPE_2_NODE_TYPE``.
        """
        self._nodes = {}

        # construct nodes
        for graph_input in graph.inputs:
            self._nodes[graph_input.name] = InputNode(graph_input)
        for layer in graph.layers:
            self._nodes[layer.name] = self._make_layer_node(layer)
        for graph_output in graph.outputs:
            self._nodes[graph_output.name] = OuputNode(graph_output)

        # Second pass, run only once every node exists: post_init receives the
        # whole graph (presumably so nodes can resolve their neighbours).
        for node in self.nodes:
            node.post_init(self)
            node.node_runtime_profiler = RuntimeProfiler()

    @staticmethod
    def _make_layer_node(layer) -> Node:
        """Return the node wrapping ``layer``: P2P, dedicated plugin, or generic."""
        layer.to_base_class()
        if "p2p_type" in layer.attrs:
            return P2PNode(layer)
        if layer.type == trt.LayerType.PLUGIN_V2:
            # The plugin type is read through the subclass view; always switch
            # the layer back to its base-class view, even if the lookup raises.
            layer.to_subclass()
            try:
                plugin_type = layer.as_trt().plugin.plugin_type
            finally:
                layer.to_base_class()
            plugin_node_cls = PLUGIN_LAYER_TYPE_2_NODE_TYPE.get(plugin_type)
            if plugin_node_cls is not None:
                return plugin_node_cls(layer)
        # Regular layers, and plugins without a dedicated node class (these
        # fall back to the generic PLUGIN_V2 entry).
        return LAYER_TYPE_2_NODE_TYPE[layer.type](layer)

    def get_node(self, name):
        """Return the node registered under ``name`` (KeyError if absent)."""
        return self._nodes[name]

    @property
    def nodes(self) -> List[Node]:
        """All nodes as a new list, in insertion order."""
        return list(self._nodes.values())

    def assign_cost_weights(self, graph_config: GraphConfig):
        """Bump the cost weights of mapped layers and mirror them onto same-spec layers.

        Each layer name appearing as a value of
        ``graph_config.graph_mapping.layer_mapping`` has its
        ``sharding_weight`` and ``resharding_weight`` incremented by one per
        occurrence. Then every key of ``same_spec_layer_mapping`` copies both
        weights from the layer it maps to.
        """
        layer_mapping = graph_config.graph_mapping.layer_mapping
        for layer_name in layer_mapping.values():
            node = self.get_node(layer_name)
            node.sharding_weight += 1
            node.resharding_weight += 1
        same_spec_layer_mapping = graph_config.graph_mapping.same_spec_layer_mapping
        for same_spec_layer_name, layer_name in same_spec_layer_mapping.items():
            node = self.get_node(layer_name)
            same_spec_node = self.get_node(same_spec_layer_name)
            same_spec_node.sharding_weight = node.sharding_weight
            same_spec_node.resharding_weight = node.resharding_weight

    def set_slowest_stage(self, stage_type: StageType,
                          graph_config: GraphConfig):
        """Assign ``pipeline_weight`` and ``cost_level`` assuming ``stage_type`` is slowest.

        Every node is reset to ``pipeline_weight = 0`` and ``cost_level = -1``.
        START/END-stage nodes get weight ``num_micro_batches - 1`` and level 1
        when their stage equals ``stage_type``, otherwise level 0. Nodes in
        the start/end/slowest block get ``block_pipeline_weight`` when the
        corresponding stage (START/END/BLOCK) is selected. P2P nodes are
        weighted according to the cross-host / cross-device flags (see below).

        Args:
            stage_type: Stage (START, END or BLOCK) treated as the slowest.
            graph_config: Supplies micro-batch, block and stage counts and the
                ``has_cross_host`` / ``has_cross_device`` flags.
        """
        num_micro_batches = graph_config.num_micro_batches
        block_per_stage = graph_config.num_blocks // graph_config.num_stages
        block_pipeline_weight = block_per_stage * (num_micro_batches - 1)
        for node in self.nodes:
            node.pipeline_weight = 0
            node.cost_level = -1
            if node.stage_type == StageType.START:
                if stage_type == StageType.START:
                    node.pipeline_weight = num_micro_batches - 1
                    node.cost_level = 1
                else:
                    node.cost_level = 0
            if stage_type == StageType.START and node.in_start_block:
                node.pipeline_weight = block_pipeline_weight
            if node.stage_type == StageType.END:
                if stage_type == StageType.END:
                    node.pipeline_weight = num_micro_batches - 1
                    node.cost_level = 1
                else:
                    node.cost_level = 0
            if stage_type == StageType.END and node.in_end_block:
                node.pipeline_weight = block_pipeline_weight
            if isinstance(node, P2PNode):
                # Cross-host P2P (or cross-device P2P when there is no
                # cross-host link) is weighted, with cost_level 1, only when
                # BLOCK is the slowest stage; otherwise its cost_level is 0.
                if (graph_config.has_cross_host
                        and node.p2p_type == P2PType.CROSS_HOST) or (
                            not graph_config.has_cross_host
                            and node.p2p_type == P2PType.CROSS_DEVICE):
                    if stage_type == StageType.BLOCK:
                        node.pipeline_weight += num_micro_batches - 1
                        node.cost_level = 1
                    else:
                        node.cost_level = 0
                # Remaining P2P nodes matching this condition are weighted
                # regardless of stage_type; their cost_level is left as is.
                elif (graph_config.has_cross_device
                      and node.p2p_type == P2PType.CROSS_DEVICE) or (
                          not graph_config.has_cross_device
                          and node.p2p_type == P2PType.CROSS_HOST):
                    node.pipeline_weight += num_micro_batches - 1
            if stage_type == StageType.BLOCK and node.in_slowest_block:
                node.pipeline_weight = block_pipeline_weight

    def get_cost_graph(self, lmesh):
        """Collect sharding strategies on ``lmesh`` and build a ``CostGraph``.

        Replicated nodes get ``set_strategy(None, lmesh)``; the others
        enumerate candidates via ``collect_strategies``. After all nodes have
        strategies, each node's non-empty ``update_resharding_cost()`` result
        becomes one leaf of the cost graph.
        """
        leaf_strategies = []
        for node in self.nodes:
            if node.is_replicated:
                node.set_strategy(None, lmesh)
            else:
                node.collect_strategies(lmesh)
        # Resharding costs depend on neighbours' strategies, so compute them
        # only after every node has collected its own.
        for node in self.nodes:
            strategies_vector = node.update_resharding_cost()
            if len(strategies_vector) != 0:
                leaf_strategies.append(strategies_vector)
        cost_graph = CostGraph(leaf_strategies)
        return cost_graph

    def find_solution(self, cost_graph, memory_budget):
        """Solve ``cost_graph`` under ``memory_budget`` and normalise the result.

        After solving, each best strategy's ``best_resharding_cost`` is
        re-keyed from predecessor node names to input indices
        (``node_names[idx]`` records which predecessor an index refers to),
        and the name-keyed entries are removed.

        Returns:
            The second element of ``Solver.find_solution()``; its
            ``node_best_strategy`` strategies are updated in place.
        """
        solver = Solver(cost_graph, memory_budget=memory_budget)
        solution = solver.find_solution()[1]

        for node_name, strategy in solution.node_best_strategy.items():
            resharding_costs = strategy.best_resharding_cost
            node = self._nodes[node_name]
            for idx, pre_node in enumerate(node.predecessor_nodes):
                if pre_node is None or pre_node.node_name not in resharding_costs:
                    continue
                resharding_costs[idx] = resharding_costs[pre_node.node_name]
                strategy.node_names[idx] = pre_node.node_name
            # Drop the name-keyed entries so only index keys remain (the list
            # is materialised first because the dict is mutated).
            for key in [k for k in resharding_costs if isinstance(k, str)]:
                del resharding_costs[key]

        return solution

    def visualize(self, name='pp_graph'):
        """Write the node graph in Graphviz DOT format to ``<name>.dot``.

        Each node is a filled box labelled with its name and, if it has
        outputs, the shape of its first output; names containing
        ``MATRIX_MULTIPLY`` are filled green. Edges follow ``successor_nodes``.
        """
        with open(name + '.dot', 'w', encoding='utf-8') as f:
            f.write("digraph {\n")
            f.write(" // Operation Nodes\n")
            for node_key, node in self._nodes.items():
                fillcolor = 'green' if 'MATRIX_MULTIPLY' in node_key else 'white'
                label = node_key
                if len(node.outputs) > 0:
                    # '\\n' is a literal backslash-n: a line break inside a DOT label.
                    label = node_key + '\\n' + str(node.outputs[0].shape)
                f.write(
                    " \"{}\" [fillcolor = \"{}\", label = \"{}\", shape = \"box\", style = \"filled\"];\n"
                    .format(node_key, fillcolor, label))
            f.write(" // Edges\n")
            for node_key, node in self._nodes.items():
                for successor_node in node.successor_nodes:
                    if successor_node:
                        f.write(" \"{}\" ->\"{}\";\n".format(
                            node_key, successor_node.node_name))
            f.write(" }\n")

    @staticmethod
    def _is_shape_io(node) -> bool:
        """Return whether ``node`` wraps a layer flagged ``is_shape_io``."""
        return node.layer is not None and node.layer.is_shape_io

    @staticmethod
    def _comm_spec_summary(commspec) -> list:
        """Return the communication-spec fields shown in solution labels."""
        return [
            commspec.comm_pattern, commspec.gather_dim, commspec.shard_dim,
            commspec.logical_process_axis
        ]

    def visualize_solution(self,
                           solution,
                           fname='pp_graph_solution',
                           ignore_shape_io=True):
        """Write the solved sharding plan to ``<fname>.dot`` and ``<fname>.csv``.

        Each DOT node label lists the node's inputs and outputs (shape and
        dtype). When the node has a best strategy in ``solution``, the label
        also shows the sharding sequences, communication / resharding actions
        and their costs.

        Args:
            solution: Result of ``find_solution``; ``best_resharding_cost`` is
                looked up by input index.
            fname: Path prefix for the ``.dot`` and ``.csv`` outputs.
            ignore_shape_io: Skip nodes, and edges into nodes, whose layer is
                flagged ``is_shape_io``.

        The CSV has one row per drawn node with its accumulated ``cost``,
        ``block_id``, ``cost_level``, ``sharding_weight`` (this column holds
        ``sharding_weight + pipeline_weight``), ``same_spec_id`` and
        ``weight_cost = sharding_weight * cost``.
        """
        with open(fname + '.dot', 'w', encoding='utf-8') as f:
            names, costs, block_ids = [], [], []
            f.write("digraph {\n")
            f.write(" // Operation Nodes\n")
            for name, node in self._nodes.items():
                if ignore_shape_io and self._is_shape_io(node):
                    continue
                cost = 0.0
                fillcolor = 'white'
                if 'MATRIX_MULTIPLY' in name or 'PLUGIN_V2_Gemm' in name:
                    fillcolor = 'orange'
                elif '_same_spec' in name:
                    fillcolor = 'gray'
                elif 'p2p_block' in name:
                    fillcolor = 'blue'
                elif 'PLUGIN' in name:
                    fillcolor = 'yellow'
                shape = 'box'
                if node.node_type in ('output_node', 'input_node'):
                    shape = 'ellipse'
                    fillcolor = 'green'

                label = name + f'_block{node.building_block_id}_weight{node.sharding_weight}'

                # Inputs: shape/dtype, plus sharding and resharding details.
                for idx, node_input in enumerate(node.inputs):
                    if not node_input:
                        continue
                    label = label + f'\\ninput{idx}_' + str(
                        node_input.shape) + f'_{node_input.dtype_str_size[0]}_'
                    if node.node_name not in solution.node_best_strategy:
                        continue
                    best_strategy = solution.node_best_strategy[node.node_name]
                    label = label + str(best_strategy.sharding_specs[
                        f'input{idx}'].sharding_sequence)
                    if idx not in best_strategy.best_resharding_cost:
                        continue
                    rcosts = best_strategy.best_resharding_cost[idx][0]
                    comm_action_sequence, resharding_cost = rcosts[1], rcosts[2]
                    if len(comm_action_sequence) > 0:
                        label = label + '|'
                    for commspec in comm_action_sequence:
                        label = label + '->' + str(
                            self._comm_spec_summary(commspec))
                    if resharding_cost > 0:
                        label = label + '_rcost{:.2}'.format(resharding_cost)
                        cost = cost + resharding_cost

                # Outputs: shape/dtype, sharding and output communication,
                # then the strategy's sharding/communication costs.
                if len(node.outputs) > 0:
                    best_strategy = None
                    for idx, output in enumerate(node.outputs):
                        label = label + f'\\noutput{idx}_' + str(
                            output.shape) + f'_{output.dtype_str_size[0]}'
                        if node.node_name not in solution.node_best_strategy:
                            continue
                        best_strategy = solution.node_best_strategy[
                            node.node_name]
                        shard_seq = str(best_strategy.sharding_specs[
                            f'output{idx}'].sharding_sequence)
                        comm = None
                        if f'output{idx}' in best_strategy.communication_actions:
                            comm = self._comm_spec_summary(
                                best_strategy.communication_actions[
                                    f'output{idx}'])
                        label = label + '_' + shard_seq
                        if comm:
                            label = label + f' | {comm}'
                    if best_strategy:
                        cost = cost + best_strategy.sharding_cost + best_strategy.communication_cost
                        label = label + '| scost{:.2}'.format(
                            best_strategy.sharding_cost)
                        if best_strategy.communication_cost > 0:
                            label = label + ' | ccost{:.2}'.format(
                                best_strategy.communication_cost)

                names.append(name)
                costs.append(cost)
                block_ids.append([
                    node.building_block_id, node.cost_level,
                    node.sharding_weight + node.pipeline_weight,
                    node.same_spec_id
                ])
                f.write(
                    " \"{}\" [fillcolor = \"{}\", label = \"{}\", shape = \"{}\", style = \"filled\"];\n"
                    .format(name, fillcolor, label, shape))
            f.write(" // Edges\n")
            for name, node in self._nodes.items():
                if ignore_shape_io and self._is_shape_io(node):
                    continue
                for successor_node in node.successor_nodes:
                    if not successor_node:
                        continue
                    if ignore_shape_io and self._is_shape_io(successor_node):
                        continue
                    f.write(" \"{}\" ->\"{}\";\n".format(
                        name, successor_node.node_name))
            f.write(" }\n")
        df = pd.DataFrame.from_dict({
            'node': names,
            'cost': costs,
            'block_id': [block[0] for block in block_ids],
            'cost_level': [block[1] for block in block_ids],
            # NOTE: holds sharding_weight + pipeline_weight (see block_ids).
            'sharding_weight': [block[2] for block in block_ids],
            'same_spec_id': [block[3] for block in block_ids]
        })
        df['weight_cost'] = df['sharding_weight'] * df['cost']
        df.to_csv(fname + '.csv')