# tensorrt_llm/auto_parallel/simplifier.py
import math
import re
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Tuple
import numpy as np
from tensorrt_llm.network import Network
from .config import AutoParallelConfig
from .device_mesh import PhysicalDeviceMesh
from .pipeline_graph import PipelineGraph
from .shape_info import ShapeInfo, ShapeType, get_shape_info
from .tensor_parallel.p2p_node import P2PType
from .utils import get_cache_key, get_sorted_layer_ids, silent_trt_logger
class StageType(Enum):
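    """Position of a building block in the pipeline: first stage, an intermediate block, or last stage."""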
START = 0
BLOCK = 1
END = 2
class BuildingBlock:
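    """A contiguous range of network layers treated as one repeatable unit.

    Blocks whose internal edge structure matches share a ``type_id``; a block
    may additionally be marked as a superset or subset of another block of the
    same type based on its inter-block connectivity.
    """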
def __init__(self, graph, layer_range) -> None:
self.graph = graph
self.layer_range = layer_range
self.network = graph.as_trt()
self.owned_inputs = {}
self.is_edges_collected = False
self.intra_edges = []
self.src_inter_edges = []
self.dst_inter_edges = []
self.relative_src_inter_edges = []
self.relative_dst_inter_edges = []
self.relative_inter_edges = set()
self.edge_hash = None
self.outputs = None
self.type_id = -1
self.block_id = -1
self.p2p_type = None
self.is_superset = False
self.is_subset = False
self.sorted_layer_ids = []
def collect_edges(self):
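        """Classify every edge touching this block.

        Edges between two layers inside ``layer_range`` (or from a graph input
        consumed only by this block) become ``intra_edges``; edges crossing the
        block boundary are recorded as ``src_inter_edges`` (produced inside,
        consumed outside) or ``dst_inter_edges`` (produced outside, consumed
        inside). The tuple of intra edges serves as ``edge_hash`` for detecting
        structurally identical blocks.
        """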
if self.is_edges_collected:
return
for layer_index in self.layer_range:
trt_layer = self.network.get_layer(layer_index)
layer = self.graph.get_layer(trt_layer.name)
layer_offset = layer.index - self.layer_range.start
for input_index, input in enumerate(layer.inputs):
if input is not None:
if input.is_graph_input:
is_owned = input.graph_input_index in self.owned_inputs
if not is_owned and np.all([
layer.index in self.layer_range or np.all([
output.as_trt().is_shape_tensor
for output in layer.outputs
]) for layer, _ in input.consumers
]):
self.owned_inputs[input.graph_input_index] = len(
self.owned_inputs)
is_owned = True
if is_owned:
self.intra_edges.append(
(-1, self.owned_inputs[input.graph_input_index],
layer_offset, input_index))
else:
self.dst_inter_edges.append(
(-1, input.graph_input_index, layer_offset,
input_index))
else:
src_layer_index = input.producer.index
if src_layer_index < self.layer_range.start or src_layer_index >= self.layer_range.stop:
self.dst_inter_edges.append(
(src_layer_index, input.output_index,
layer_offset, input_index))
else:
src_layer_offset = src_layer_index - self.layer_range.start
self.intra_edges.append(
(src_layer_offset, input.output_index,
layer_offset, input_index))
for output_index, output in enumerate(layer.outputs):
for dst_layer, dst_input_index in output.consumers:
dst_layer_index = dst_layer.index
if dst_layer_index < self.layer_range.start or dst_layer_index >= self.layer_range.stop:
self.src_inter_edges.append(
(layer_offset, output_index, dst_layer_index,
dst_input_index))
self.edge_hash = tuple(self.intra_edges)
self.outputs = sorted(
set((edge[0], edge[1]) for edge in self.src_inter_edges))
self.is_edges_collected = True
def collect_relative_inter_edges(self, layer_to_block):
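        """Rewrite ``dst_inter_edges`` relative to their producer blocks.

        Producers that belong to a known block are referenced by
        ``(type_id, layer offset, output index)``; other producers keep their
        absolute layer index. Combined with ``outputs``, the result is used to
        test whether one block's connectivity is a subset of another's.
        """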
self.collect_edges()
for src_layer_index, src_output_index, dst_layer_index, dst_input_index in self.dst_inter_edges:
if src_layer_index in layer_to_block:
src_block = layer_to_block[src_layer_index]
src_layer_offset = src_layer_index - src_block.layer_range.start
dst = (self.type_id, dst_layer_index, dst_input_index)
self.relative_dst_inter_edges.append(
(src_block.type_id, src_layer_offset, src_output_index,
*dst))
else:
self.relative_dst_inter_edges.append(
(-1, src_layer_index, src_output_index, self.type_id,
dst_layer_index, dst_input_index))
self.relative_inter_edges = set(self.relative_dst_inter_edges +
self.outputs)
def get_input_names(self):
self.collect_edges()
input_tensor_names = []
for edge in self.dst_inter_edges:
layer_index = edge[0]
output_index = edge[1]
if layer_index == -1:
tensor_name = self.network.get_input(output_index).name
else:
tensor_name = self.network.get_layer(layer_index).get_output(
output_index).name
input_tensor_names.append(tensor_name)
return input_tensor_names
def get_input_mapping(self, last_blocks):
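        """Map each external input tensor of this block to the corresponding
        tensor produced by the last materialized block of the same type
        (``last_blocks``); tensors coming from outside any known block map to
        themselves.
        """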
input_mapping = {}
for tensor_name, relative_edge in zip(self.get_input_names(),
self.relative_dst_inter_edges):
type_id = relative_edge[0]
output_index = relative_edge[2]
if type_id >= 0:
last_block = last_blocks[type_id]
layer_offset = relative_edge[1]
mapped_layer_index = last_block.layer_range.start + layer_offset
mapped_tensor_name = self.network.get_layer(
mapped_layer_index).get_output(output_index).name
input_mapping[tensor_name] = mapped_tensor_name
else:
input_mapping[tensor_name] = tensor_name
return input_mapping
@dataclass
class GraphMapping:
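    """Mappings from the original graph onto its simplified form: layer and
    block substitutions, per-block P2P types and tensors, block-to-stage
    assignment, and the inserted same-spec identity layers."""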
layer_mapping: Dict[int, int] = None
block_mapping: Dict[int, int] = None
p2p_types: Dict[int, P2PType] = None
p2p_tensors: Dict[int, List[str]] = None
block_to_stage: Dict[int, int] = None
same_spec_layer_mapping: Dict[str, str] = None
@dataclass
class GraphConfig:
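    """Describes one simplified-graph candidate: micro-batch count, block and
    stage counts, whether pipeline P2P crosses devices or hosts, the mapping
    back to the original graph, and the physical meshes of the whole job and
    of each stage."""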
num_micro_batches: int = 1
num_blocks: int = 1
num_stages: int = 1
has_cross_device: bool = False
has_cross_host: bool = False
graph_mapping: GraphMapping = None
phy_mesh: PhysicalDeviceMesh = None
stage_phy_meshes: List[PhysicalDeviceMesh] = None
class Simplifier:
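    """Deduplicates repeated building blocks of a network to shrink the
    auto-parallel search space.

    Illustrative flow (argument values are placeholders):

        simplifier = Simplifier(network, config)
        simplifier.infer_shapes(num_micro_batches=1)
        graph, graph_config = simplifier.simplify_graph(
            phy_mesh, num_stages=2, num_devices_per_stage=4)
    """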
def __init__(self, network: Network, config: AutoParallelConfig):
self.config = config
self.sharded_io_allowlist = config.sharded_io_allowlist
self.same_buffer_io = config.same_buffer_io
self.same_spec_io = config.same_spec_io.copy()
for key, value in self.same_buffer_io.items():
if key not in self.same_spec_io:
self.same_spec_io[key] = value
self.llm_network = network
self.network = network.trt_network
self.module_to_layer_range_map = network._module_call_stack.module_to_layer_range_map
self.graph = self.get_graph()
self.init_layer_hash()
module_tree = self.get_module_tree()
building_blocks = self.collect_building_blocks(module_tree)
blocks_by_module_hash = self.get_blocks_by_module_hash(building_blocks)
self.blocks_by_edge_hash = self.get_blocks_by_edge_hash(
blocks_by_module_hash)
self.layer_to_block = self.get_layer_to_block()
self.blocks = self.get_all_blocks()
self.backbone_blocks = self.get_backbone_blocks()
self.graph_mapping_for_shape = self.get_graph_mapping_for_shape()
self.graph_for_shape = self.create_simplified_graph_for_shape()
self.shape_info = None
self.num_micro_batches = None
def infer_shapes(self, num_micro_batches):
if self.num_micro_batches == num_micro_batches:
return
with silent_trt_logger():
self.shape_info = self.get_full_shape_info(num_micro_batches)
self.graph.assign_shapes(self.shape_info)
self.num_micro_batches = num_micro_batches
def list_all_num_micro_batches(self):
opt_batch_size = self.get_opt_batch_size()
candidates = []
        for num_micro_batches in range(1, opt_batch_size + 1):
if opt_batch_size % num_micro_batches == 0:
candidates.append(num_micro_batches)
return candidates
def get_graph(self):
graph = PipelineGraph.from_trt(self.network)
graph._unfilled_weights = self.llm_network._unfilled_weights.copy()
graph._io_buffer_mapping
for input in graph.inputs:
input_name = input.name
for pattern, repl in self.same_buffer_io.items():
if re.match(pattern, input_name):
output_name = re.sub(pattern, repl, input_name)
output = graph.get_output(output_name)
if output is not None:
graph._io_buffer_mapping[output_name] = input_name
return graph
def get_opt_batch_size(self):
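        """Return the smallest ``opt`` value of the ``batch_size`` dimension
        found across all inputs and optimization profiles."""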
input_tensors = self.llm_network._inputs
num_profiles = len(list(input_tensors.values())[0].profiles)
opt_batch_sizes = []
for i in range(num_profiles):
for input_tensor in input_tensors.values():
shape_profile = input_tensor.profiles[i]
opt_shape = shape_profile.opt
for j in range(len(input_tensor.shape)):
name = input_tensor.trt_tensor.get_dimension_name(j)
if name == 'batch_size':
opt_batch_sizes.append(opt_shape[j])
return min(opt_batch_sizes)
def get_module_hash(self, layer_range):
module_hash = ()
for i in layer_range:
assert i < self.network.num_layers, f"layer index {i} in {layer_range} out of range of {self.network.num_layers}"
layer_name = self.network.get_layer(i).name
layer = self.graph.get_layer(layer_name)
module_hash += (layer.attrs["hash"], )
return module_hash
def get_network_hash(self) -> str:
return str(self.get_module_hash(range(self.network.num_layers)))
def collect_building_blocks(self, module_tree):
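        """Find modules whose layer hash occurs more than once.

        The module tree is traversed breadth-first; modules that occur only
        once are expanded into their children and re-examined. Each repeated
        module becomes a BuildingBlock, and block layer ranges are then
        extended up to the start of the next block so that trailing glue
        layers stay attached to the preceding block.
        """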
building_blocks = {}
queue = []
for tree in module_tree["children"].values():
queue.append(tree)
while len(queue) > 0:
while len(queue) > 0:
tree = queue.pop(0)
module_name = tree["name"]
if module_name is None:
for child in tree["children"].values():
queue.append(child)
continue
layer_range = self.module_to_layer_range_map[module_name]
module_hash = self.get_module_hash(layer_range)
if module_hash in building_blocks:
building_blocks[module_hash].append(tree)
else:
building_blocks[module_hash] = [tree]
for module_hash in [*building_blocks.keys()]:
if len(building_blocks[module_hash]) == 1:
tree = building_blocks[module_hash][0]
for child in tree["children"].values():
queue.append(child)
del building_blocks[module_hash]
blocks_by_module_hash = {
module_hash: [
BuildingBlock(self.graph,
self.module_to_layer_range_map[tree["name"]])
for tree in trees
]
for module_hash, trees in building_blocks.items()
}
building_blocks = []
for block_list in blocks_by_module_hash.values():
for block in block_list:
building_blocks.append(block)
building_blocks = sorted(building_blocks,
key=lambda x: x.layer_range.start)
if len(building_blocks) >= 2:
for block, next_block in zip(building_blocks[:-1],
building_blocks[1:]):
block.layer_range = range(block.layer_range.start,
next_block.layer_range.start)
return building_blocks
def get_all_blocks(self):
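        """Return blocks covering every layer of the network: the repeated
        blocks from ``blocks_by_edge_hash`` plus filler blocks for the layer
        ranges between and around them, each with its sorted layer ids and
        relative inter-edges collected."""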
building_blocks = []
for block_list in self.blocks_by_edge_hash.values():
for block in block_list:
building_blocks.append(block)
building_blocks = sorted(building_blocks,
key=lambda x: x.layer_range.start)
all_blocks = []
current_layer_index = 0
block_id = 0
for block in building_blocks:
assert current_layer_index <= block.layer_range.start
if current_layer_index < block.layer_range.start:
new_block = BuildingBlock(
self.graph,
range(current_layer_index, block.layer_range.start))
new_block.block_id = block_id
block_id += 1
all_blocks.append(new_block)
block.block_id = block_id
block_id += 1
all_blocks.append(block)
current_layer_index = block.layer_range.stop
if current_layer_index < self.graph.num_layers:
new_block = BuildingBlock(
self.graph, range(current_layer_index, self.graph.num_layers))
new_block.block_id = block_id
all_blocks.append(new_block)
sorted_layer_ids = get_sorted_layer_ids(self.network)
for block in all_blocks:
block.collect_relative_inter_edges(self.layer_to_block)
for layer_id in sorted_layer_ids:
if layer_id in block.layer_range:
block.sorted_layer_ids.append(layer_id)
return all_blocks
def get_backbone_blocks(self):
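        """Return the group of repeated blocks with the most occurrences,
        ties broken by the longest layer range; these blocks drive the
        pipeline stage split."""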
sorted_blocks = sorted(
self.blocks_by_edge_hash.values(),
key=lambda blocks: (len(blocks), len(blocks[0].layer_range)),
)
if len(sorted_blocks) == 0:
return []
else:
return sorted_blocks[-1]
def get_blocks_by_module_hash(self, blocks):
blocks_by_module_hash = {}
for block in blocks:
module_hash = self.get_module_hash(block.layer_range)
if module_hash not in blocks_by_module_hash:
blocks_by_module_hash[module_hash] = []
blocks_by_module_hash[module_hash].append(block)
for module_hash in [*blocks_by_module_hash.keys()]:
if len(blocks_by_module_hash[module_hash]) == 1:
del blocks_by_module_hash[module_hash]
return blocks_by_module_hash
def get_module_tree(self):
module_tree = {"children": {}, "name": None}
for module_name in self.module_to_layer_range_map.keys():
full_name = module_name.split('.')
current_tree = module_tree["children"]
for depth, name in enumerate(full_name):
if name not in current_tree:
current_tree[name] = {"children": {}, "name": None}
if depth == len(full_name) - 1:
current_tree[name]["name"] = module_name
else:
current_tree = current_tree[name]["children"]
return module_tree
def get_blocks_by_edge_hash(self, blocks_by_module_hash):
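        """Regroup blocks by their intra-edge hash, drop groups with a single
        member, sort each group by layer range start, and assign one
        ``type_id`` per group."""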
blocks_by_edge_hash = {}
for block_list in blocks_by_module_hash.values():
for block in block_list:
block.collect_edges()
edge_hash = block.edge_hash
if edge_hash not in blocks_by_edge_hash:
blocks_by_edge_hash[edge_hash] = []
blocks_by_edge_hash[edge_hash].append(block)
for edge_hash in [*blocks_by_edge_hash.keys()]:
if len(blocks_by_edge_hash[edge_hash]) == 1:
del blocks_by_edge_hash[edge_hash]
else:
block_list = blocks_by_edge_hash[edge_hash]
blocks_by_edge_hash[edge_hash] = sorted(
block_list, key=lambda x: x.layer_range.start)
for type_id, block_list in enumerate(blocks_by_edge_hash.values()):
for block in block_list:
block.type_id = type_id
return blocks_by_edge_hash
def get_layer_to_block(self):
layer_to_block = {}
for block_list in self.blocks_by_edge_hash.values():
for block in block_list:
for layer_index in block.layer_range:
layer_to_block[layer_index] = block
return layer_to_block
def clean_blocks(self):
for block in self.blocks:
block.p2p_type = None
block.is_superset = False
block.is_subset = False
def mark_p2p_type(self, phy_mesh, stage_phy_meshes,
graph_config: GraphConfig):
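        """Mark the first backbone block of every stage after the first with
        the P2P type required to receive activations from the previous stage:
        CROSS_HOST when the stage boundary crosses hosts, CROSS_DEVICE
        otherwise. The flags are mirrored into ``graph_config``."""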
if len(self.backbone_blocks) == 0 or len(stage_phy_meshes) == 1:
return
assert len(self.backbone_blocks) % len(stage_phy_meshes) == 0
block_per_stage = len(self.backbone_blocks) // len(stage_phy_meshes)
for block in self.backbone_blocks:
block.p2p_type = None
for stage_index, stage_phy_mesh in enumerate(stage_phy_meshes[:-1]):
next_stage_phy_mesh = stage_phy_meshes[stage_index + 1]
last_device_id = stage_phy_mesh.phy_devices_id.flatten()[-1]
next_first_device_id = next_stage_phy_mesh.phy_devices_id.flatten(
)[0]
num_devices_per_host = phy_mesh.num_devices_per_host
next_block = self.backbone_blocks[(stage_index + 1) *
block_per_stage]
if last_device_id // num_devices_per_host != next_first_device_id // num_devices_per_host:
next_block.p2p_type = P2PType.CROSS_HOST
graph_config.has_cross_host = True
else:
next_block.p2p_type = P2PType.CROSS_DEVICE
graph_config.has_cross_device = True
def get_graph_mapping(self):
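        """For every group of identical blocks, choose superset blocks whose
        inter-edges cover the others and map each subset block's layers, owned
        inputs, network outputs and P2P identity layers onto its superset
        representative; also assigns a pipeline stage id to every block."""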
layer_mapping = {}
block_mapping = {}
p2p_types = {}
p2p_tensors = {}
for block_list in self.blocks_by_edge_hash.values():
superset_blocks = []
superset_block_index = {}
for block in block_list:
block_added = False
for index, superset_block in enumerate(list(superset_blocks)):
if block.p2p_type == superset_block.p2p_type:
if block.relative_inter_edges.issubset(
superset_block.relative_inter_edges):
block.is_subset = True
block.is_superset = False
superset_block_index[id(block)] = index
block_added = True
break
elif superset_block.relative_inter_edges.issubset(
block.relative_inter_edges):
superset_block.is_subset = True
superset_block.is_superset = False
block.is_subset = False
block.is_superset = True
superset_blocks[index] = block
superset_block_index[id(block)] = index
block_added = True
break
if not block_added:
block.is_subset = False
block.is_superset = True
superset_blocks.append(block)
superset_block_index[id(block)] = len(superset_blocks) - 1
for block in block_list:
assert not (block.is_subset and block.is_superset)
if block.is_subset:
superset_block = superset_blocks[superset_block_index[id(
block)]]
block_mapping[block.block_id] = superset_block.block_id
owned_inputs = map(
lambda x: x[0],
sorted(block.owned_inputs.items(), key=lambda x: x[1]))
superset_owned_inputs = map(
lambda x: x[0],
sorted(superset_block.owned_inputs.items(),
key=lambda x: x[1]))
for from_input_id, to_input_id in zip(
owned_inputs, superset_owned_inputs):
from_input_name = self.network.get_input(
from_input_id).name
to_input_name = self.network.get_input(to_input_id).name
layer_mapping[from_input_name] = to_input_name
for from_layer_id, to_layer_id in zip(
block.layer_range, superset_block.layer_range):
from_layer = self.network.get_layer(from_layer_id)
to_layer = self.network.get_layer(to_layer_id)
layer_mapping[from_layer.name] = to_layer.name
for i in range(from_layer.num_outputs):
from_output = from_layer.get_output(i)
if from_output.is_network_output:
to_output = to_layer.get_output(i)
layer_mapping[from_output.name] = to_output.name
if block.p2p_type is not None:
p2p_types[block.block_id] = block.p2p_type
p2p_tensors[block.block_id] = [
*set(block.get_input_names())
]
for from_name, to_name in zip(
block.get_input_names(),
superset_block.get_input_names()):
layer_mapping[
f"p2p_block{block.block_id}_{from_name}"] = f"p2p_block{superset_block.block_id}_{to_name}"
stage_id = 0
block_to_stage = {}
for block in self.blocks:
if block.p2p_type is not None:
stage_id += 1
block_to_stage[block.block_id] = stage_id
return GraphMapping(
layer_mapping,
block_mapping,
p2p_types,
p2p_tensors,
block_to_stage,
)
def create_simplified_graph(self, graph_config: GraphConfig):
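        """Build a PipelineGraph that materializes only superset blocks.

        Identity layers are inserted to model P2P transfers and to tie the
        outputs of a repeated block to the inputs of its next instance via
        ``same_spec_id``. Layers and I/O tensors are annotated with
        ``building_block_id``, stage type, start/end/slowest-block flags and
        replication attributes. Returns the new graph and a mapping used to
        copy shapes onto the inserted identity tensors.
        """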
new_graph = PipelineGraph.create_graph()
new_graph._io_buffer_mapping = self.graph._io_buffer_mapping
layer_mapping = graph_config.graph_mapping.layer_mapping
for i in range(self.network.num_inputs):
trt_input = self.network.get_input(i)
if trt_input.name not in layer_mapping:
new_graph.add_input(trt_input)
last_blocks = {}
same_spec_mapping = {}
same_spec_layer_mapping = {}
shape_mapping = {}
building_block_id = 0
same_spec_ids = {}
same_spec_count = 0
for block in self.blocks:
if not block.is_subset:
stage_type = None
if not block.is_superset:
if block.block_id == 0:
stage_type = StageType.START
elif block.block_id == len(self.blocks) - 1:
stage_type = StageType.END
input_mapping = block.get_input_mapping(last_blocks)
for from_name, to_name in [*input_mapping.items()]:
if to_name in same_spec_mapping:
input_mapping[from_name] = same_spec_mapping[to_name]
if to_name in layer_mapping:
input_mapping[from_name] = layer_mapping[to_name]
if block.is_superset and block.p2p_type is not None:
for from_name, to_name in [*input_mapping.items()]:
output_tensor = new_graph.get_tensor(to_name)
p2p_layer = new_graph.as_trt().add_identity(
output_tensor.as_trt())
p2p_layer.name = f"p2p_block{block.block_id}_{from_name}"
p2p_layer.metadata = p2p_layer.name
p2p_tensor = p2p_layer.get_output(0)
p2p_tensor.name = f"{p2p_layer.name}_output"
wrapped_layer = new_graph.register_layer(p2p_layer)
wrapped_layer.attrs[
"building_block_id"] = building_block_id
wrapped_layer.attrs["p2p_type"] = block.p2p_type
input_mapping[from_name] = p2p_tensor.name
shape_mapping[p2p_tensor.name] = from_name
building_block_id += 1
for i in block.sorted_layer_ids:
layer = self.network.get_layer(i)
wrapped_layer = new_graph.add_layer(
layer,
input_mapping=input_mapping,
)
wrapped_layer.attrs["building_block_id"] = building_block_id
wrapped_layer.attrs["stage_type"] = stage_type
if block.is_superset:
last_blocks[block.type_id] = block
if block.type_id in same_spec_ids:
same_spec_id = same_spec_ids[block.type_id]
update_same_spec_count = False
else:
same_spec_id = same_spec_count
same_spec_ids[block.type_id] = same_spec_id
update_same_spec_count = True
count = same_spec_id
for i, (layer_offset,
output_index) in enumerate(block.outputs):
layer = self.network.get_layer(block.layer_range.start +
layer_offset)
tensor_name = layer.get_output(output_index).name
output_tensor = new_graph.get_tensor(tensor_name)
same_spec_layer = new_graph.as_trt().add_identity(
output_tensor.as_trt())
same_spec_layer.name = f"{tensor_name}_same_spec"
same_spec_layer.metadata = same_spec_layer.name
same_spec_tensor = same_spec_layer.get_output(0)
same_spec_tensor.name = f"{same_spec_layer.name}_output"
wrapped_layer = new_graph.register_layer(
same_spec_layer)
wrapped_layer.attrs[
"building_block_id"] = building_block_id
wrapped_layer.attrs["same_spec_id"] = count
count += 1
same_spec_mapping[tensor_name] = same_spec_tensor.name
same_spec_layer_mapping[
same_spec_layer.name] = layer.name
shape_mapping[same_spec_tensor.name] = tensor_name
for i, graph_input_index in enumerate(
block.owned_inputs.keys()):
input_name = self.network.get_input(
graph_input_index).name
input_tensor = new_graph.get_input(input_name)
input_tensor.attrs["same_spec_id"] = count
count += 1
if update_same_spec_count:
same_spec_count = count
building_block_id += 1
graph_config.graph_mapping.same_spec_layer_mapping = same_spec_layer_mapping
if len(self.backbone_blocks) >= 2:
start_block = self.backbone_blocks[0]
if start_block.is_subset:
start_block = self.blocks[graph_config.graph_mapping.
block_mapping[start_block.block_id]]
for i in start_block.layer_range:
layer_name = self.network.get_layer(i).name
layer = new_graph.get_layer(layer_name)
layer.attrs["in_start_block"] = True
end_block = self.backbone_blocks[-1]
if end_block.is_subset:
end_block = self.blocks[graph_config.graph_mapping.
block_mapping[end_block.block_id]]
for i in end_block.layer_range:
layer_name = self.network.get_layer(i).name
layer = new_graph.get_layer(layer_name)
layer.attrs["in_end_block"] = True
slowest_p2p_type = None
if graph_config.has_cross_host:
slowest_p2p_type = P2PType.CROSS_HOST
elif graph_config.has_cross_device:
slowest_p2p_type = P2PType.CROSS_DEVICE
if slowest_p2p_type is not None:
for block in self.blocks:
if block.is_superset and block.p2p_type == slowest_p2p_type:
for i in block.layer_range:
layer_name = self.network.get_layer(i).name
layer = new_graph.get_layer(layer_name)
layer.attrs["in_slowest_block"] = True
for i in range(self.network.num_outputs):
trt_output = self.network.get_output(i)
output = self.graph.get_output(trt_output.name)
if output.producer is not None and output.producer.index in self.layer_to_block and self.layer_to_block[
output.producer.index].is_subset:
continue
if trt_output.is_shape_tensor:
new_output = new_graph.add_output_shape(trt_output)
else:
new_output = new_graph.add_output(trt_output)
sharded_io = False
for pattern in self.sharded_io_allowlist:
if re.match(pattern, new_output.name):
sharded_io = True
break
if not sharded_io:
new_output.producer.attrs["is_replicated"] = True
for input in new_graph.inputs:
input_name = input.name
sharded_io = False
for pattern in self.sharded_io_allowlist:
if re.match(pattern, input_name):
sharded_io = True
break
if not sharded_io:
input.attrs["is_replicated"] = True
for pattern, repl in self.same_spec_io.items():
if re.match(pattern, input_name):
output_name = re.sub(pattern, repl, input_name)
output = new_graph.get_output(output_name)
if output is not None:
if "same_spec_id" in input.attrs:
same_spec_id = input.attrs["same_spec_id"]
else:
same_spec_id = same_spec_count
same_spec_count += 1
input.attrs["same_spec_id"] = same_spec_id
output.attrs["same_spec_id"] = same_spec_id
if math.prod(self.graph.get_input(
input_name).shape) < math.prod(
self.graph.get_output(output_name).shape):
input.attrs["no_memory_footprint"] = True
else:
output.attrs["no_memory_footprint"] = True
return new_graph, shape_mapping
def enrich_shape_info(self, shape_mapping):
shapes = self.shape_info.shapes.copy()
max_shapes = self.shape_info.max_shapes.copy()
values = self.shape_info.values.copy()
shape_layers = self.shape_info.shape_layers
for from_name, to_name in shape_mapping.items():
if to_name in shapes:
shapes[from_name] = shapes[to_name]
if to_name in max_shapes:
max_shapes[from_name] = max_shapes[to_name]
if to_name in values:
values[from_name] = values[to_name]
shape_info = ShapeInfo(shapes, values, shape_layers, max_shapes)
return shape_info
def simplify_graph(
self, phy_mesh: PhysicalDeviceMesh, num_stages: int,
num_devices_per_stage: int) -> Tuple[PipelineGraph, GraphConfig]:
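        """Build a simplified graph for ``phy_mesh`` split into ``num_stages``
        pipeline stages. Returns ``(None, None)`` when the backbone blocks
        cannot be divided evenly among the stages."""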
num_blocks = len(self.backbone_blocks)
if num_blocks % num_stages != 0:
return None, None
graph_config = GraphConfig()
graph_config.num_micro_batches = self.num_micro_batches
graph_config.num_blocks = num_blocks
graph_config.num_stages = num_stages
graph_config.phy_mesh = phy_mesh
stage_phy_meshes = phy_mesh.split_pipeline_meshes(
num_stages, num_devices_per_stage)
graph_config.stage_phy_meshes = stage_phy_meshes
with silent_trt_logger():
self.clean_blocks()
self.mark_p2p_type(phy_mesh, stage_phy_meshes, graph_config)
graph_config.graph_mapping = self.get_graph_mapping()
new_graph, shape_mapping = self.create_simplified_graph(
graph_config)
shape_info = self.enrich_shape_info(shape_mapping)
new_graph.assign_shapes(shape_info)
return new_graph, graph_config
def get_graph_mapping_for_shape(self):
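        """Map the layers and output tensors of every duplicate block onto the
        first block of its group, for propagating shapes computed on the
        shape-only graph back to all duplicates."""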
layer_mapping = {}
tensor_mapping = {}
for block_list in self.blocks_by_edge_hash.values():
head_block = block_list[0]
for block in block_list[1:]:
for from_layer_id, to_layer_id in zip(block.layer_range,
head_block.layer_range):
from_layer = self.network.get_layer(from_layer_id)
to_layer = self.network.get_layer(to_layer_id)
layer_mapping[from_layer.name] = to_layer.name
for i in range(from_layer.num_outputs):
tensor_mapping[from_layer.get_output(
i).name] = to_layer.get_output(i).name
return layer_mapping, tensor_mapping
def create_simplified_graph_for_shape(self):
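        """Build a reduced graph that keeps only the first block of each
        repeated group (plus all non-repeated layers); it is used solely for
        shape inference."""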
new_graph = PipelineGraph.create_graph()
for i in range(self.network.num_inputs):
trt_input = self.network.get_input(i)
new_graph.add_input(trt_input)
head_blocks = {}
removed_blocks = set()
removed_layers = set()
for block_list in self.blocks_by_edge_hash.values():
head_block = block_list[0]
head_blocks[head_block.type_id] = head_block
for block in block_list[1:]:
removed_blocks.add(id(block))
for layer_index in block.layer_range:
removed_layers.add(layer_index)
for block in self.blocks:
if not id(block) in removed_blocks:
input_mapping = block.get_input_mapping(head_blocks)
for i in block.sorted_layer_ids:
layer = self.network.get_layer(i)
new_graph.add_layer(
layer,
input_mapping=input_mapping,
)
for i in range(self.network.num_outputs):
trt_output = self.network.get_output(i)
output = self.graph.get_output(trt_output.name)
if output.producer is not None and output.producer.index in removed_layers:
continue
if trt_output.is_shape_tensor:
new_graph.add_output_shape(trt_output)
else:
new_graph.add_output(trt_output)
return new_graph
def get_full_shape_info(self, num_micro_batches):
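        """Infer regular and MAX shapes on the shape-only graph using the last
        optimization profile, then copy the results onto the duplicate layers
        and tensors that were removed from that graph."""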
layer_mapping, tensor_mapping = self.graph_mapping_for_shape
optimization_profiles = self.llm_network._generate_optimization_profiles(
)
if len(optimization_profiles) > 0:
optimization_profile = optimization_profiles[-1]
else:
optimization_profile = None
shape_info = get_shape_info(self.graph_for_shape.as_trt(),
optimization_profile)
max_shape_info = get_shape_info(self.graph_for_shape.as_trt(),
optimization_profile,
shape_type=ShapeType.MAX)
shape_info.max_shapes = max_shape_info.shapes
for removed_tensor_name, tensor_name in tensor_mapping.items():
shape_info.shapes[removed_tensor_name] = shape_info.shapes[
tensor_name]
shape_info.max_shapes[removed_tensor_name] = shape_info.max_shapes[
tensor_name]
if tensor_name in shape_info.values:
shape_info.values[removed_tensor_name] = shape_info.values[
tensor_name]
for removed_layer_name, layer_name in layer_mapping.items():
if layer_name in shape_info.shape_layers:
shape_info.shape_layers.add(removed_layer_name)
return shape_info
def init_layer_hash(self):
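        """Run shape inference on the full network and assign every layer a
        hash derived from its TRT layer, shapes, values and dtypes; these
        hashes back ``get_module_hash``."""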
with silent_trt_logger():
optimization_profiles = self.llm_network._generate_optimization_profiles(
)
if len(optimization_profiles) > 0:
optimization_profile = optimization_profiles[-1]
else:
optimization_profile = None
shape_info = get_shape_info(self.network, optimization_profile)
dtypes = {tensor.name: tensor.dtype for tensor in self.graph.tensors}
for layer in self.graph.layers:
layer_hash = get_cache_key(
layer.as_trt(),
shape_info.shapes,
shape_info.values,
dtypes,
)
layer.attrs["hash"] = layer_hash