mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
* Update TensorRT-LLM --------- Co-authored-by: meghagarwal <16129366+megha95@users.noreply.github.com> Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
348 lines
16 KiB
Python
348 lines
16 KiB
Python
from typing import List
|
|
|
|
import pandas as pd
|
|
import tensorrt as trt
|
|
|
|
from .pipeline_graph import PipelineGraph
|
|
from .runtime_profiling import RuntimeProfiler
|
|
from .simplifier import GraphConfig, StageType
|
|
from .solver import CostGraph, Solver
|
|
from .tensor_parallel.activation_node import Activation
|
|
from .tensor_parallel.assertion_node import Assertion
|
|
from .tensor_parallel.cast_node import Cast
|
|
from .tensor_parallel.concatenation_node import Concatenation
|
|
from .tensor_parallel.constant_node import Constant
|
|
from .tensor_parallel.elementwise_node import ElementWise
|
|
from .tensor_parallel.fill_node import Fill
|
|
from .tensor_parallel.gather_node import Gather
|
|
from .tensor_parallel.identity_node import Identity
|
|
from .tensor_parallel.input_node import InputNode
|
|
from .tensor_parallel.matmul_node import MatrixMultiply
|
|
from .tensor_parallel.node import Node
|
|
from .tensor_parallel.normalization_node import Normalization
|
|
from .tensor_parallel.output_node import OuputNode
|
|
from .tensor_parallel.p2p_node import P2PNode, P2PType
|
|
from .tensor_parallel.plugin_node import PluginNode
|
|
from .tensor_parallel.plugin_nodes.gemm_node import GemmPlugin
|
|
from .tensor_parallel.plugin_nodes.gpt_attention_node import GPTAttentionPlugin
|
|
from .tensor_parallel.plugin_nodes.identity_node import IdentityPlugin
|
|
from .tensor_parallel.plugin_nodes.look_up_node import LookupPlugin
|
|
from .tensor_parallel.plugin_nodes.normalization_node import (LayernormPlugin,
|
|
RMSnormPlugin)
|
|
from .tensor_parallel.reduce_node import Reduce
|
|
from .tensor_parallel.select_node import Select
|
|
from .tensor_parallel.shape_node import Shape
|
|
from .tensor_parallel.shuffle_node import Shuffle
|
|
from .tensor_parallel.slice_node import Slice
|
|
from .tensor_parallel.softmax_node import SoftMax
|
|
from .tensor_parallel.unary_node import Unary
|
|
|
|
# Node class used to wrap a layer, keyed by the layer's TensorRT layer type.
# NodeGraph.__init__ indexes this table with `layer.type`, so a layer type
# that is missing here makes graph construction fail with a KeyError.
LAYER_TYPE_2_NODE_TYPE = {
    trt.LayerType.ACTIVATION: Activation,
    trt.LayerType.ASSERTION: Assertion,
    trt.LayerType.CAST: Cast,
    trt.LayerType.CONCATENATION: Concatenation,
    trt.LayerType.CONSTANT: Constant,
    trt.LayerType.ELEMENTWISE: ElementWise,
    trt.LayerType.FILL: Fill,
    trt.LayerType.GATHER: Gather,
    trt.LayerType.IDENTITY: Identity,
    trt.LayerType.MATRIX_MULTIPLY: MatrixMultiply,
    trt.LayerType.NORMALIZATION: Normalization,
    # Generic plugin wrapper; used when the plugin type has no dedicated
    # entry in PLUGIN_LAYER_TYPE_2_NODE_TYPE.
    trt.LayerType.PLUGIN_V2: PluginNode,
    trt.LayerType.REDUCE: Reduce,
    trt.LayerType.SELECT: Select,
    trt.LayerType.SHAPE: Shape,
    trt.LayerType.SHUFFLE: Shuffle,
    trt.LayerType.SLICE: Slice,
    trt.LayerType.SOFTMAX: SoftMax,
    trt.LayerType.UNARY: Unary,
}
|
|
# Dedicated node class per plugin, keyed by the plugin's `plugin_type` string.
# NodeGraph.__init__ consults this table for PLUGIN_V2 layers before falling
# back to the generic entry in LAYER_TYPE_2_NODE_TYPE.
# TODO: BertAttention/All Quant plugins
PLUGIN_LAYER_TYPE_2_NODE_TYPE = {
    'GPTAttention': GPTAttentionPlugin,
    'Gemm': GemmPlugin,
    'Layernorm': LayernormPlugin,
    'Rmsnorm': RMSnormPlugin,
    'Lookup': LookupPlugin,
    'Identity': IdentityPlugin,
}
|
|
|
|
|
|
class NodeGraph:
    """Name-indexed collection of nodes built from a ``PipelineGraph``.

    Every graph input, layer and output is wrapped in a node object and
    stored in ``self._nodes`` under its name. The other methods assign cost
    weights to those nodes, build and solve the cost graph, and dump the
    graph (optionally annotated with a solution) as Graphviz/CSV files.
    """

    def __init__(self, graph: PipelineGraph):
        """Wrap the inputs, layers and outputs of ``graph`` in node objects.

        Raises:
            KeyError: if a layer's type has no entry in
                ``LAYER_TYPE_2_NODE_TYPE``.
        """
        self._nodes = {}

        # construct nodes
        for graph_input in graph.inputs:
            self._nodes[graph_input.name] = InputNode(graph_input)
        for layer in graph.layers:
            layer.to_base_class()
            if "p2p_type" in layer.attrs:
                node = P2PNode(layer)
            elif layer.type == trt.LayerType.PLUGIN_V2:
                # The plugin type is only reachable through the subclass
                # view of the layer; switch back once it has been read.
                layer.to_subclass()
                plugin_type = layer.as_trt().plugin.plugin_type
                layer.to_base_class()
                # Prefer the plugin-specific node class; otherwise use the
                # generic class registered for the layer type.
                node_type = PLUGIN_LAYER_TYPE_2_NODE_TYPE.get(
                    plugin_type, LAYER_TYPE_2_NODE_TYPE[layer.type])
                node = node_type(layer)
            else:
                node = LAYER_TYPE_2_NODE_TYPE[layer.type](layer)
            self._nodes[layer.name] = node
        for graph_output in graph.outputs:
            self._nodes[graph_output.name] = OuputNode(graph_output)

        # All nodes exist now, so each one can finish its own setup
        # against the complete graph.
        for node in self.nodes:
            node.post_init(self)
            node.node_runtime_profiler = RuntimeProfiler()

    def get_node(self, name):
        """Return the node registered under ``name`` (KeyError if absent)."""
        return self._nodes[name]

    @property
    def nodes(self) -> List[Node]:
        """A new list holding all nodes in insertion order."""
        return [*self._nodes.values()]

    def assign_cost_weights(self, graph_config: GraphConfig):
        """Update ``sharding_weight``/``resharding_weight`` from the mapping.

        Each value in ``layer_mapping`` bumps both weights of the named node
        by one. Afterwards every key of ``same_spec_layer_mapping`` receives
        a copy of the weights of the node named by its value.
        """
        graph_mapping = graph_config.graph_mapping
        for layer_name in graph_mapping.layer_mapping.values():
            node = self.get_node(layer_name)
            node.sharding_weight += 1
            node.resharding_weight += 1
        same_spec_layer_mapping = graph_mapping.same_spec_layer_mapping
        for same_spec_layer_name, layer_name in same_spec_layer_mapping.items():
            node = self.get_node(layer_name)
            same_spec_node = self.get_node(same_spec_layer_name)
            same_spec_node.sharding_weight = node.sharding_weight
            same_spec_node.resharding_weight = node.resharding_weight

    def set_slowest_stage(self, stage_type: StageType,
                          graph_config: GraphConfig):
        """Recompute ``pipeline_weight`` and ``cost_level`` on every node.

        Both attributes are first reset (0 and -1) and then set according
        to the node's own stage type, its block membership flags and, for
        P2P nodes, its ``p2p_type``, relative to the given ``stage_type``.
        The checks run in sequence, so a later rule may overwrite the value
        written by an earlier one.
        """
        num_micro_batches = graph_config.num_micro_batches
        block_per_stage = graph_config.num_blocks // graph_config.num_stages
        block_pipeline_weight = block_per_stage * (num_micro_batches - 1)
        for node in self.nodes:
            node.pipeline_weight = 0
            node.cost_level = -1

            # Nodes that belong to the START stage.
            if node.stage_type == StageType.START:
                if stage_type == StageType.START:
                    node.pipeline_weight = num_micro_batches - 1
                    node.cost_level = 1
                else:
                    node.cost_level = 0
            if stage_type == StageType.START and node.in_start_block:
                node.pipeline_weight = block_pipeline_weight

            # Nodes that belong to the END stage.
            if node.stage_type == StageType.END:
                if stage_type == StageType.END:
                    node.pipeline_weight = num_micro_batches - 1
                    node.cost_level = 1
                else:
                    node.cost_level = 0
            if stage_type == StageType.END and node.in_end_block:
                node.pipeline_weight = block_pipeline_weight

            # P2P nodes: which p2p_type is handled by which branch depends
            # on the cross-host / cross-device flags of the config.
            if isinstance(node, P2PNode):
                if (graph_config.has_cross_host
                        and node.p2p_type == P2PType.CROSS_HOST) or (
                            not graph_config.has_cross_host
                            and node.p2p_type == P2PType.CROSS_DEVICE):
                    if stage_type == StageType.BLOCK:
                        node.pipeline_weight += num_micro_batches - 1
                        node.cost_level = 1
                    else:
                        node.cost_level = 0
                elif (graph_config.has_cross_device
                      and node.p2p_type == P2PType.CROSS_DEVICE) or (
                          not graph_config.has_cross_device
                          and node.p2p_type == P2PType.CROSS_HOST):
                    node.pipeline_weight += num_micro_batches - 1

            # Evaluated last, so it overrides any weight assigned above.
            if stage_type == StageType.BLOCK and node.in_slowest_block:
                node.pipeline_weight = block_pipeline_weight

    def get_cost_graph(self, lmesh):
        """Collect strategies for every node on ``lmesh`` and build a CostGraph.

        Replicated nodes get ``set_strategy(None, lmesh)``; all other nodes
        collect their strategies. Only nodes whose
        ``update_resharding_cost()`` result is non-empty contribute to the
        returned ``CostGraph``.
        """
        for node in self.nodes:
            if node.is_replicated:
                node.set_strategy(None, lmesh)
            else:
                node.collect_strategies(lmesh)

        # Resharding costs are updated only after every node has its
        # strategies, hence the second pass.
        leaf_strategies = []
        for node in self.nodes:
            strategies_vector = node.update_resharding_cost()
            if len(strategies_vector) != 0:
                leaf_strategies.append(strategies_vector)
        return CostGraph(leaf_strategies)

    def find_solution(self, cost_graph, memory_budget):
        """Run the solver and re-key resharding costs by predecessor index.

        The element at index 1 of ``Solver.find_solution()`` is taken as the
        solution. For each chosen strategy, entries of
        ``best_resharding_cost`` keyed by a predecessor's node name are
        copied to the predecessor's position index, ``node_names`` is
        updated to match, and all string keys are then removed.
        """
        solver = Solver(cost_graph, memory_budget=memory_budget)
        solution = solver.find_solution()[1]

        for node_name, strategy in solution.node_best_strategy.items():
            node = self._nodes[node_name]
            resharding_cost = strategy.best_resharding_cost
            for idx, pre_node in enumerate(node.predecessor_nodes):
                if pre_node is None:
                    continue
                if pre_node.node_name not in resharding_cost:
                    continue
                resharding_cost[idx] = resharding_cost[pre_node.node_name]
                strategy.node_names[idx] = pre_node.node_name
            # Iterate over a snapshot of the keys: entries are deleted
            # from the dict while looping.
            for key in list(resharding_cost.keys()):
                if isinstance(key, str):
                    del resharding_cost[key]

        return solution

    def visualize(self, name='pp_graph'):
        """Write the graph to ``<name>.dot`` in Graphviz DOT syntax.

        One box per node (green when the node name contains
        ``MATRIX_MULTIPLY``, white otherwise), labelled with the node name
        and the shape of its first output if it has one, plus one edge per
        non-empty successor.
        """
        with open(name + '.dot', 'w') as f:
            f.write("digraph {\n")
            f.write(" // Operation Nodes\n")
            for node_name, node in self._nodes.items():
                fillcolor = 'white'
                if 'MATRIX_MULTIPLY' in node_name:
                    fillcolor = 'green'
                label = node_name
                if len(node.outputs) > 0:
                    label = node_name + '\\n' + str(node.outputs[0].shape)
                f.write(
                    " \"{}\" [fillcolor = \"{}\", label = \"{}\", shape = \"box\", style = \"filled\"];\n"
                    .format(node_name, fillcolor, label))
            f.write(" // Edges\n")
            for node_name, node in self._nodes.items():
                for successor_node in node.successor_nodes:
                    if successor_node:
                        f.write(" \"{}\" ->\"{}\";\n".format(
                            node_name, successor_node.node_name))
            f.write(" }\n")

    @staticmethod
    def _format_cost(value):
        """Format a cost with two significant digits.

        The value is converted to ``float`` first because the ``.2`` format
        spec raises ``ValueError`` for plain ``int`` values.
        """
        return '{:.2}'.format(float(value))

    @staticmethod
    def _describe_comm(commspec):
        """Return the fields of a communication spec shown in node labels."""
        return [
            commspec.comm_pattern, commspec.gather_dim, commspec.shard_dim,
            commspec.logical_process_axis
        ]

    @staticmethod
    def _skips_shape_io(node, ignore_shape_io):
        """Tell whether ``node`` is left out of the solution dump."""
        return bool(ignore_shape_io and node.layer is not None
                    and node.layer.is_shape_io)

    def visualize_solution(self,
                           solution,
                           fname='pp_graph_solution',
                           ignore_shape_io=True):
        """Dump the graph annotated with ``solution`` to DOT and CSV files.

        ``<fname>.dot`` gets one node per graph node, labelled with its
        block id, sharding weight, input/output shapes and, for nodes that
        appear in ``solution.node_best_strategy``, the chosen sharding
        sequences, communication actions and costs. ``<fname>.csv`` gets one
        row per dumped node with its summed cost, block id, cost level,
        ``sharding_weight + pipeline_weight`` (column ``sharding_weight``),
        same-spec id and the product ``weight_cost``.

        Args:
            solution: object exposing ``node_best_strategy``, a mapping
                from node name to the selected strategy.
            fname: output path without extension.
            ignore_shape_io: when true, nodes whose layer has
                ``is_shape_io`` set (and edges into them) are omitted.
        """
        columns = {
            'node': [],
            'cost': [],
            'block_id': [],
            'cost_level': [],
            'sharding_weight': [],
            'same_spec_id': [],
        }
        with open(fname + '.dot', 'w') as f:
            f.write("digraph {\n")
            f.write(" // Operation Nodes\n")
            for name, node in self._nodes.items():
                if self._skips_shape_io(node, ignore_shape_io):
                    continue
                cost = 0.0

                fillcolor = 'white'
                if 'MATRIX_MULTIPLY' in name or 'PLUGIN_V2_Gemm' in name:
                    fillcolor = 'orange'
                elif '_same_spec' in name:
                    fillcolor = 'gray'
                elif 'p2p_block' in name:
                    fillcolor = 'blue'
                elif 'PLUGIN' in name:
                    fillcolor = 'yellow'

                shape = 'box'
                if node.node_type in ('output_node', 'input_node'):
                    shape = 'ellipse'
                    fillcolor = 'green'

                in_solution = node.node_name in solution.node_best_strategy
                strategy = (solution.node_best_strategy[node.node_name]
                            if in_solution else None)

                label = name + f'_block{node.building_block_id}_weight{node.sharding_weight}'
                for idx, node_input in enumerate(node.inputs):
                    if not node_input:
                        continue
                    label = label + f'\\ninput{idx}_' + str(
                        node_input.shape) + f'_{node_input.dtype_str_size[0]}_'
                    if not in_solution:
                        continue
                    label = label + str(
                        strategy.sharding_specs[f'input{idx}'].
                        sharding_sequence)
                    # After find_solution() the resharding costs are keyed
                    # by input index; inputs without an entry add nothing.
                    if idx not in strategy.best_resharding_cost:
                        continue
                    rcosts = strategy.best_resharding_cost[idx][0]
                    comm_action_sequence, resharding_cost = rcosts[1], rcosts[2]
                    if len(comm_action_sequence) > 0:
                        label = label + '|'
                        for commspec in comm_action_sequence:
                            label = label + '->' + str(
                                self._describe_comm(commspec))
                    if resharding_cost > 0:
                        label = label + '_rcost' + self._format_cost(
                            resharding_cost)
                    cost = cost + resharding_cost

                if len(node.outputs) > 0:
                    for idx, output in enumerate(node.outputs):
                        label = label + f'\\noutput{idx}_' + str(
                            output.shape) + f'_{output.dtype_str_size[0]}'
                        if not in_solution:
                            continue
                        shard_seq = str(
                            strategy.sharding_specs[f'output{idx}'].
                            sharding_sequence)
                        comm = None
                        if f'output{idx}' in strategy.communication_actions:
                            comm = self._describe_comm(
                                strategy.communication_actions[f'output{idx}'])
                        label = label + '_' + shard_seq
                        if comm:
                            label = label + f' | {comm}'
                    # Strategy-level costs are counted once per node, and
                    # only for nodes that have at least one output.
                    if strategy:
                        cost = cost + strategy.sharding_cost + strategy.communication_cost
                        label = label + '| scost' + self._format_cost(
                            strategy.sharding_cost)
                        if strategy.communication_cost > 0:
                            label = label + ' | ccost' + self._format_cost(
                                strategy.communication_cost)

                columns['node'].append(name)
                columns['cost'].append(cost)
                columns['block_id'].append(node.building_block_id)
                columns['cost_level'].append(node.cost_level)
                columns['sharding_weight'].append(node.sharding_weight +
                                                  node.pipeline_weight)
                columns['same_spec_id'].append(node.same_spec_id)
                f.write(
                    " \"{}\" [fillcolor = \"{}\", label = \"{}\", shape = \"{}\", style = \"filled\"];\n"
                    .format(name, fillcolor, label, shape))

            f.write(" // Edges\n")
            for name, node in self._nodes.items():
                if self._skips_shape_io(node, ignore_shape_io):
                    continue
                for successor_node in node.successor_nodes:
                    if not successor_node:
                        continue
                    if self._skips_shape_io(successor_node, ignore_shape_io):
                        continue
                    f.write(" \"{}\" ->\"{}\";\n".format(
                        name, successor_node.node_name))
            f.write(" }\n")

        df = pd.DataFrame.from_dict(columns)
        df['weight_cost'] = df['sharding_weight'] * df['cost']
        df.to_csv(fname + '.csv')
|