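# Overview (comment added for clarity): this module implements the graph
# simplification step of TensorRT-LLM's auto-parallel planner. Simplifier
# groups the layers of a TensorRT network into BuildingBlocks, detects repeated
# blocks by hashing their layers and edge structure, and emits a reduced
# PipelineGraph (plus GraphMapping/GraphConfig metadata) in which only one
# representative of each repeated block is kept, so that per-block decisions
# can be mapped back onto the full network via the recorded mappings.
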
import math
import re
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Tuple

import numpy as np

from tensorrt_llm.network import Network

from .config import AutoParallelConfig
from .device_mesh import PhysicalDeviceMesh
from .pipeline_graph import PipelineGraph
from .shape_info import ShapeInfo, ShapeType, get_shape_info
from .tensor_parallel.p2p_node import P2PType
from .utils import get_cache_key, get_sorted_layer_ids, silent_trt_logger


class StageType(Enum):
    START = 0
    BLOCK = 1
    END = 2


class BuildingBlock:

    def __init__(self, graph, layer_range) -> None:
        self.graph = graph
        self.layer_range = layer_range
        self.network = graph.as_trt()
        self.owned_inputs = {}
        self.is_edges_collected = False
        self.intra_edges = []
        self.src_inter_edges = []
        self.dst_inter_edges = []
        self.relative_src_inter_edges = []
        self.relative_dst_inter_edges = []
        self.relative_inter_edges = set()
        self.edge_hash = None
        self.outputs = None
        self.type_id = -1
        self.block_id = -1
        self.p2p_type = None
        self.is_superset = False
        self.is_subset = False
        self.sorted_layer_ids = []

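    # Comment added for clarity: collect_edges() stores edges as 4-tuples of
    # (source, source output index, destination, destination input index).
    # intra_edges keep both endpoints relative to this block (a source of -1
    # denotes a graph input owned by the block); dst_inter_edges keep the
    # source absolute (a layer index, or -1 plus the graph input index) and
    # the destination block-relative; src_inter_edges keep the source
    # block-relative and the destination absolute.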
    def collect_edges(self):
        if self.is_edges_collected:
            return
        for layer_index in self.layer_range:
            trt_layer = self.network.get_layer(layer_index)
            layer = self.graph.get_layer(trt_layer.name)
            layer_offset = layer.index - self.layer_range.start
            for input_index, input in enumerate(layer.inputs):
                if input is not None:
                    if input.is_graph_input:
                        is_owned = input.graph_input_index in self.owned_inputs
                        if not is_owned and np.all([
                                layer.index in self.layer_range or np.all([
                                    output.as_trt().is_shape_tensor
                                    for output in layer.outputs
                                ]) for layer, _ in input.consumers
                        ]):
                            self.owned_inputs[input.graph_input_index] = len(
                                self.owned_inputs)
                            is_owned = True
                        if is_owned:
                            self.intra_edges.append(
                                (-1, self.owned_inputs[input.graph_input_index],
                                 layer_offset, input_index))
                        else:
                            self.dst_inter_edges.append(
                                (-1, input.graph_input_index, layer_offset,
                                 input_index))
                    else:
                        src_layer_index = input.producer.index
                        if src_layer_index < self.layer_range.start or src_layer_index >= self.layer_range.stop:
                            self.dst_inter_edges.append(
                                (src_layer_index, input.output_index,
                                 layer_offset, input_index))
                        else:
                            src_layer_offset = src_layer_index - self.layer_range.start
                            self.intra_edges.append(
                                (src_layer_offset, input.output_index,
                                 layer_offset, input_index))
            for output_index, output in enumerate(layer.outputs):
                for dst_layer, dst_input_index in output.consumers:
                    dst_layer_index = dst_layer.index
                    if dst_layer_index < self.layer_range.start or dst_layer_index >= self.layer_range.stop:
                        self.src_inter_edges.append(
                            (layer_offset, output_index, dst_layer_index,
                             dst_input_index))
        self.edge_hash = tuple(self.intra_edges)
        self.outputs = sorted(
            set((edge[0], edge[1]) for edge in self.src_inter_edges))
        self.is_edges_collected = True

    def collect_relative_inter_edges(self, layer_to_block):
        self.collect_edges()
        for src_layer_index, src_output_index, dst_layer_index, dst_input_index in self.dst_inter_edges:
            if src_layer_index in layer_to_block:
                src_block = layer_to_block[src_layer_index]
                src_layer_offset = src_layer_index - src_block.layer_range.start
                dst = (self.type_id, dst_layer_index, dst_input_index)
                self.relative_dst_inter_edges.append(
                    (src_block.type_id, src_layer_offset, src_output_index,
                     *dst))
            else:
                self.relative_dst_inter_edges.append(
                    (-1, src_layer_index, src_output_index, self.type_id,
                     dst_layer_index, dst_input_index))
        self.relative_inter_edges = set(self.relative_dst_inter_edges +
                                        self.outputs)

    def get_input_names(self):
        self.collect_edges()
        input_tensor_names = []
        for edge in self.dst_inter_edges:
            layer_index = edge[0]
            output_index = edge[1]
            if layer_index == -1:
                tensor_name = self.network.get_input(output_index).name
            else:
                tensor_name = self.network.get_layer(layer_index).get_output(
                    output_index).name
            input_tensor_names.append(tensor_name)
        return input_tensor_names

    def get_input_mapping(self, last_blocks):
        input_mapping = {}
        for tensor_name, relative_edge in zip(self.get_input_names(),
                                              self.relative_dst_inter_edges):
            type_id = relative_edge[0]
            output_index = relative_edge[2]
            if type_id >= 0:
                last_block = last_blocks[type_id]
                layer_offset = relative_edge[1]
                mapped_layer_index = last_block.layer_range.start + layer_offset
                mapped_tensor_name = self.network.get_layer(
                    mapped_layer_index).get_output(output_index).name
                input_mapping[tensor_name] = mapped_tensor_name
            else:
                input_mapping[tensor_name] = tensor_name
        return input_mapping


@dataclass
class GraphMapping:
    layer_mapping: Dict[int, int] = None
    block_mapping: Dict[int, int] = None
    p2p_types: Dict[int, P2PType] = None
    p2p_tensors: Dict[int, List[str]] = None
    block_to_stage: Dict[int, int] = None
    same_spec_layer_mapping: Dict[str, str] = None


@dataclass
class GraphConfig:
    num_micro_batches: int = 1
    num_blocks: int = 1
    num_stages: int = 1
    has_cross_device: bool = False
    has_cross_host: bool = False
    graph_mapping: GraphMapping = None
    phy_mesh: PhysicalDeviceMesh = None
    stage_phy_meshes: List[PhysicalDeviceMesh] = None


class Simplifier:

    def __init__(self, network: Network, config: AutoParallelConfig):
        self.config = config
        self.sharded_io_allowlist = config.sharded_io_allowlist
        self.same_buffer_io = config.same_buffer_io
        self.same_spec_io = config.same_spec_io.copy()
        for key, value in self.same_buffer_io.items():
            if key not in self.same_spec_io:
                self.same_spec_io[key] = value

        self.llm_network = network
        self.network = network.trt_network
        self.module_to_layer_range_map = network._module_call_stack.module_to_layer_range_map
        self.graph = self.get_graph()
        self.init_layer_hash()

        module_tree = self.get_module_tree()
        building_blocks = self.collect_building_blocks(module_tree)
        blocks_by_module_hash = self.get_blocks_by_module_hash(building_blocks)
        self.blocks_by_edge_hash = self.get_blocks_by_edge_hash(
            blocks_by_module_hash)
        self.layer_to_block = self.get_layer_to_block()
        self.blocks = self.get_all_blocks()
        self.backbone_blocks = self.get_backbone_blocks()
        self.graph_mapping_for_shape = self.get_graph_mapping_for_shape()
        self.graph_for_shape = self.create_simplified_graph_for_shape()
        self.shape_info = None
        self.num_micro_batches = None

    def infer_shapes(self, num_micro_batches):
        if self.num_micro_batches == num_micro_batches:
            return
        with silent_trt_logger():
            self.shape_info = self.get_full_shape_info(num_micro_batches)
            self.graph.assign_shapes(self.shape_info)
            self.num_micro_batches = num_micro_batches

    def list_all_num_micro_batches(self):
        opt_batch_size = self.get_opt_batch_size()
        candidates = []
        for num_micro_batches in range(1, self.get_opt_batch_size() + 1):
            if opt_batch_size % num_micro_batches == 0:
                candidates.append(num_micro_batches)
        return candidates

    def get_graph(self):
        graph = PipelineGraph.from_trt(self.network)
        graph._unfilled_weights = self.llm_network._unfilled_weights.copy()
        graph._io_buffer_mapping
        for input in graph.inputs:
            input_name = input.name
            for pattern, repl in self.same_buffer_io.items():
                if re.match(pattern, input_name):
                    output_name = re.sub(pattern, repl, input_name)
                    output = graph.get_output(output_name)
                    if output is not None:
                        graph._io_buffer_mapping[output_name] = input_name
        return graph

    def get_opt_batch_size(self):
        input_tensors = self.llm_network._inputs
        num_profiles = len(list(input_tensors.values())[0].profiles)
        opt_batch_sizes = []
        for i in range(num_profiles):
            for input_tensor in input_tensors.values():
                shape_profile = input_tensor.profiles[i]
                opt_shape = shape_profile.opt
                for j in range(len(input_tensor.shape)):
                    name = input_tensor.trt_tensor.get_dimension_name(j)
                    if name == 'batch_size':
                        opt_batch_sizes.append(opt_shape[j])
        return min(opt_batch_sizes)

    def get_module_hash(self, layer_range):
        module_hash = ()
        for i in layer_range:
            assert i < self.network.num_layers, f"layer index {i} in {layer_range} out of range of {self.network.num_layers}"
            layer_name = self.network.get_layer(i).name
            layer = self.graph.get_layer(layer_name)
            module_hash += (layer.attrs["hash"], )
        return module_hash

    def get_network_hash(self) -> str:
        return str(self.get_module_hash(range(self.network.num_layers)))

    def collect_building_blocks(self, module_tree):
        building_blocks = {}
        queue = []
        for tree in module_tree["children"].values():
            queue.append(tree)
        while len(queue) > 0:
            while len(queue) > 0:
                tree = queue.pop(0)
                module_name = tree["name"]
                if module_name is None:
                    for child in tree["children"].values():
                        queue.append(child)
                    continue
                layer_range = self.module_to_layer_range_map[module_name]
                module_hash = self.get_module_hash(layer_range)
                if module_hash in building_blocks:
                    building_blocks[module_hash].append(tree)
                else:
                    building_blocks[module_hash] = [tree]
            for module_hash in [*building_blocks.keys()]:
                if len(building_blocks[module_hash]) == 1:
                    tree = building_blocks[module_hash][0]
                    for child in tree["children"].values():
                        queue.append(child)
                    del building_blocks[module_hash]
        blocks_by_module_hash = {
            module_hash: [
                BuildingBlock(self.graph,
                              self.module_to_layer_range_map[tree["name"]])
                for tree in trees
            ]
            for module_hash, trees in building_blocks.items()
        }
        building_blocks = []
        for block_list in blocks_by_module_hash.values():
            for block in block_list:
                building_blocks.append(block)
        building_blocks = sorted(building_blocks,
                                 key=lambda x: x.layer_range.start)
        if len(building_blocks) >= 2:
            for block, next_block in zip(building_blocks[:-1],
                                         building_blocks[1:]):
                block.layer_range = range(block.layer_range.start,
                                          next_block.layer_range.start)
        return building_blocks

    def get_all_blocks(self):
        building_blocks = []
        for block_list in self.blocks_by_edge_hash.values():
            for block in block_list:
                building_blocks.append(block)
        building_blocks = sorted(building_blocks,
                                 key=lambda x: x.layer_range.start)
        all_blocks = []
        current_layer_index = 0
        block_id = 0
        for block in building_blocks:
            assert current_layer_index <= block.layer_range.start
            if current_layer_index < block.layer_range.start:
                new_block = BuildingBlock(
                    self.graph,
                    range(current_layer_index, block.layer_range.start))
                new_block.block_id = block_id
                block_id += 1
                all_blocks.append(new_block)
            block.block_id = block_id
            block_id += 1
            all_blocks.append(block)
            current_layer_index = block.layer_range.stop
        if current_layer_index < self.graph.num_layers:
            new_block = BuildingBlock(
                self.graph, range(current_layer_index, self.graph.num_layers))
            new_block.block_id = block_id
            all_blocks.append(new_block)
        sorted_layer_ids = get_sorted_layer_ids(self.network)
        for block in all_blocks:
            block.collect_relative_inter_edges(self.layer_to_block)
            for layer_id in sorted_layer_ids:
                if layer_id in block.layer_range:
                    block.sorted_layer_ids.append(layer_id)
        return all_blocks

    def get_backbone_blocks(self):
        sorted_blocks = sorted(
            self.blocks_by_edge_hash.values(),
            key=lambda blocks: (len(blocks), len(blocks[0].layer_range)),
        )
        if len(sorted_blocks) == 0:
            return []
        else:
            return sorted_blocks[-1]

    def get_blocks_by_module_hash(self, blocks):
        blocks_by_module_hash = {}
        for block in blocks:
            module_hash = self.get_module_hash(block.layer_range)
            if module_hash not in blocks_by_module_hash:
                blocks_by_module_hash[module_hash] = []
            blocks_by_module_hash[module_hash].append(block)
        for module_hash in [*blocks_by_module_hash.keys()]:
            if len(blocks_by_module_hash[module_hash]) == 1:
                del blocks_by_module_hash[module_hash]
        return blocks_by_module_hash

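    # Comment added for clarity: get_module_tree() nests modules by the
    # components of their dotted names (e.g. a hypothetical
    # "transformer.layers.0" becomes nodes "transformer" -> "layers" -> "0");
    # a node's "name" field holds the full dotted name when that module
    # appears in module_to_layer_range_map and stays None otherwise.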
    def get_module_tree(self):
        module_tree = {"children": {}, "name": None}
        for module_name in self.module_to_layer_range_map.keys():
            full_name = module_name.split('.')
            current_tree = module_tree["children"]
            for depth, name in enumerate(full_name):
                if name not in current_tree:
                    current_tree[name] = {"children": {}, "name": None}
                if depth == len(full_name) - 1:
                    current_tree[name]["name"] = module_name
                else:
                    current_tree = current_tree[name]["children"]
        return module_tree

    def get_blocks_by_edge_hash(self, blocks_by_module_hash):
        blocks_by_edge_hash = {}
        for block_list in blocks_by_module_hash.values():
            for block in block_list:
                block.collect_edges()
                edge_hash = block.edge_hash
                if edge_hash not in blocks_by_edge_hash:
                    blocks_by_edge_hash[edge_hash] = []
                blocks_by_edge_hash[edge_hash].append(block)
        for edge_hash in [*blocks_by_edge_hash.keys()]:
            if len(blocks_by_edge_hash[edge_hash]) == 1:
                del blocks_by_edge_hash[edge_hash]
            else:
                block_list = blocks_by_edge_hash[edge_hash]
                blocks_by_edge_hash[edge_hash] = sorted(
                    block_list, key=lambda x: x.layer_range.start)
        for type_id, block_list in enumerate(blocks_by_edge_hash.values()):
            for block in block_list:
                block.type_id = type_id
        return blocks_by_edge_hash

    def get_layer_to_block(self):
        layer_to_block = {}
        for block_list in self.blocks_by_edge_hash.values():
            for block in block_list:
                for layer_index in block.layer_range:
                    layer_to_block[layer_index] = block
        return layer_to_block

    def clean_blocks(self):
        for block in self.blocks:
            block.p2p_type = None
            block.is_superset = False
            block.is_subset = False

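    # Comment added for clarity: mark_p2p_type() tags the first backbone block
    # of every stage after the first with the kind of P2P transfer required
    # from the previous stage: CROSS_HOST when the two adjacent stages end and
    # start on different hosts, CROSS_DEVICE otherwise.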
    def mark_p2p_type(self, phy_mesh, stage_phy_meshes,
                      graph_config: GraphConfig):
        if len(self.backbone_blocks) == 0 or len(stage_phy_meshes) == 1:
            return
        assert len(self.backbone_blocks) % len(stage_phy_meshes) == 0
        block_per_stage = len(self.backbone_blocks) // len(stage_phy_meshes)

        for block in self.backbone_blocks:
            block.p2p_type = None
        for stage_index, stage_phy_mesh in enumerate(stage_phy_meshes[:-1]):
            next_stage_phy_mesh = stage_phy_meshes[stage_index + 1]
            last_device_id = stage_phy_mesh.phy_devices_id.flatten()[-1]
            next_first_device_id = next_stage_phy_mesh.phy_devices_id.flatten(
            )[0]
            num_devices_per_host = phy_mesh.num_devices_per_host
            next_block = self.backbone_blocks[(stage_index + 1) *
                                              block_per_stage]
            if last_device_id // num_devices_per_host != next_first_device_id // num_devices_per_host:
                next_block.p2p_type = P2PType.CROSS_HOST
                graph_config.has_cross_host = True
            else:
                next_block.p2p_type = P2PType.CROSS_DEVICE
                graph_config.has_cross_device = True

    def get_graph_mapping(self):
        layer_mapping = {}
        block_mapping = {}
        p2p_types = {}
        p2p_tensors = {}
        for block_list in self.blocks_by_edge_hash.values():
            superset_blocks = []
            superset_block_index = {}
            for block in block_list:
                block_added = False
                for index, superset_block in enumerate(list(superset_blocks)):
                    if block.p2p_type == superset_block.p2p_type:
                        if block.relative_inter_edges.issubset(
                                superset_block.relative_inter_edges):
                            block.is_subset = True
                            block.is_superset = False
                            superset_block_index[id(block)] = index
                            block_added = True
                            break
                        elif superset_block.relative_inter_edges.issubset(
                                block.relative_inter_edges):
                            superset_block.is_subset = True
                            superset_block.is_superset = False
                            block.is_subset = False
                            block.is_superset = True
                            superset_blocks[index] = block
                            superset_block_index[id(block)] = index
                            block_added = True
                            break
                if not block_added:
                    block.is_subset = False
                    block.is_superset = True
                    superset_blocks.append(block)
                    superset_block_index[id(block)] = len(superset_blocks) - 1
            for block in block_list:
                assert not (block.is_subset and block.is_superset)
                if block.is_subset:
                    superset_block = superset_blocks[superset_block_index[id(
                        block)]]
                    block_mapping[block.block_id] = superset_block.block_id
                    owned_inputs = map(
                        lambda x: x[0],
                        sorted(block.owned_inputs.items(), key=lambda x: x[1]))
                    superset_owned_inputs = map(
                        lambda x: x[0],
                        sorted(superset_block.owned_inputs.items(),
                               key=lambda x: x[1]))
                    for from_input_id, to_input_id in zip(
                            owned_inputs, superset_owned_inputs):
                        from_input_name = self.network.get_input(
                            from_input_id).name
                        to_input_name = self.network.get_input(to_input_id).name
                        layer_mapping[from_input_name] = to_input_name
                    for from_layer_id, to_layer_id in zip(
                            block.layer_range, superset_block.layer_range):
                        from_layer = self.network.get_layer(from_layer_id)
                        to_layer = self.network.get_layer(to_layer_id)
                        layer_mapping[from_layer.name] = to_layer.name
                        for i in range(from_layer.num_outputs):
                            from_output = from_layer.get_output(i)
                            if from_output.is_network_output:
                                to_output = to_layer.get_output(i)
                                layer_mapping[from_output.name] = to_output.name
                    if block.p2p_type is not None:
                        p2p_types[block.block_id] = block.p2p_type
                        p2p_tensors[block.block_id] = [
                            *set(block.get_input_names())
                        ]
                        for from_name, to_name in zip(
                                block.get_input_names(),
                                superset_block.get_input_names()):
                            layer_mapping[
                                f"p2p_block{block.block_id}_{from_name}"] = f"p2p_block{superset_block.block_id}_{to_name}"
        stage_id = 0
        block_to_stage = {}
        for block in self.blocks:
            if block.p2p_type is not None:
                stage_id += 1
            block_to_stage[block.block_id] = stage_id
        return GraphMapping(
            layer_mapping,
            block_mapping,
            p2p_types,
            p2p_tensors,
            block_to_stage,
        )

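    # Comment added for clarity: create_simplified_graph() re-emits only
    # blocks that are not subsets of another block. For superset blocks that
    # carry a p2p_type it inserts identity layers named
    # "p2p_block{block_id}_{input}" in front of the block's inputs, and for
    # every superset block output it inserts a "<tensor>_same_spec" identity
    # layer tagged with a same_spec_id, which appears intended to let repeated
    # blocks share one sharding spec.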
    def create_simplified_graph(self, graph_config: GraphConfig):
        new_graph = PipelineGraph.create_graph()
        new_graph._io_buffer_mapping = self.graph._io_buffer_mapping
        layer_mapping = graph_config.graph_mapping.layer_mapping

        for i in range(self.network.num_inputs):
            trt_input = self.network.get_input(i)
            if trt_input.name not in layer_mapping:
                new_graph.add_input(trt_input)

        last_blocks = {}
        same_spec_mapping = {}
        same_spec_layer_mapping = {}
        shape_mapping = {}
        building_block_id = 0
        same_spec_ids = {}
        same_spec_count = 0
        for block in self.blocks:
            if not block.is_subset:
                stage_type = None
                if not block.is_superset:
                    if block.block_id == 0:
                        stage_type = StageType.START
                    elif block.block_id == len(self.blocks) - 1:
                        stage_type = StageType.END
                input_mapping = block.get_input_mapping(last_blocks)
                for from_name, to_name in [*input_mapping.items()]:
                    if to_name in same_spec_mapping:
                        input_mapping[from_name] = same_spec_mapping[to_name]
                    if to_name in layer_mapping:
                        input_mapping[from_name] = layer_mapping[to_name]
                if block.is_superset and block.p2p_type is not None:
                    for from_name, to_name in [*input_mapping.items()]:
                        output_tensor = new_graph.get_tensor(to_name)
                        p2p_layer = new_graph.as_trt().add_identity(
                            output_tensor.as_trt())
                        p2p_layer.name = f"p2p_block{block.block_id}_{from_name}"
                        p2p_tensor = p2p_layer.get_output(0)
                        p2p_tensor.name = f"{p2p_layer.name}_output"
                        wrapped_layer = new_graph.register_layer(p2p_layer)
                        wrapped_layer.attrs[
                            "building_block_id"] = building_block_id
                        wrapped_layer.attrs["p2p_type"] = block.p2p_type
                        input_mapping[from_name] = p2p_tensor.name
                        shape_mapping[p2p_tensor.name] = from_name
                    building_block_id += 1
                for i in block.sorted_layer_ids:
                    layer = self.network.get_layer(i)
                    wrapped_layer = new_graph.add_layer(
                        layer,
                        input_mapping=input_mapping,
                    )
                    wrapped_layer.attrs["building_block_id"] = building_block_id
                    wrapped_layer.attrs["stage_type"] = stage_type
                if block.is_superset:
                    last_blocks[block.type_id] = block

                    if block.type_id in same_spec_ids:
                        same_spec_id = same_spec_ids[block.type_id]
                        update_same_spec_count = False
                    else:
                        same_spec_id = same_spec_count
                        same_spec_ids[block.type_id] = same_spec_id
                        update_same_spec_count = True
                    count = same_spec_id
                    for i, (layer_offset,
                            output_index) in enumerate(block.outputs):
                        layer = self.network.get_layer(block.layer_range.start +
                                                       layer_offset)
                        tensor_name = layer.get_output(output_index).name
                        output_tensor = new_graph.get_tensor(tensor_name)
                        same_spec_layer = new_graph.as_trt().add_identity(
                            output_tensor.as_trt())
                        same_spec_layer.name = f"{tensor_name}_same_spec"
                        same_spec_tensor = same_spec_layer.get_output(0)
                        same_spec_tensor.name = f"{same_spec_layer.name}_output"
                        wrapped_layer = new_graph.register_layer(
                            same_spec_layer)
                        wrapped_layer.attrs[
                            "building_block_id"] = building_block_id
                        wrapped_layer.attrs["same_spec_id"] = count
                        count += 1
                        same_spec_mapping[tensor_name] = same_spec_tensor.name
                        same_spec_layer_mapping[
                            same_spec_layer.name] = layer.name
                        shape_mapping[same_spec_tensor.name] = tensor_name
                    for i, graph_input_index in enumerate(
                            block.owned_inputs.keys()):
                        input_name = self.network.get_input(
                            graph_input_index).name
                        input_tensor = new_graph.get_input(input_name)
                        input_tensor.attrs["same_spec_id"] = count
                        count += 1
                    if update_same_spec_count:
                        same_spec_count = count
                building_block_id += 1
        graph_config.graph_mapping.same_spec_layer_mapping = same_spec_layer_mapping

        if len(self.backbone_blocks) >= 2:
            start_block = self.backbone_blocks[0]
            if start_block.is_subset:
                start_block = self.blocks[graph_config.graph_mapping.
                                          block_mapping[start_block.block_id]]
            for i in start_block.layer_range:
                layer_name = self.network.get_layer(i).name
                layer = new_graph.get_layer(layer_name)
                layer.attrs["in_start_block"] = True
            end_block = self.backbone_blocks[-1]
            if end_block.is_subset:
                end_block = self.blocks[graph_config.graph_mapping.
                                        block_mapping[end_block.block_id]]
            for i in end_block.layer_range:
                layer_name = self.network.get_layer(i).name
                layer = new_graph.get_layer(layer_name)
                layer.attrs["in_end_block"] = True
        slowest_p2p_type = None
        if graph_config.has_cross_host:
            slowest_p2p_type = P2PType.CROSS_HOST
        elif graph_config.has_cross_device:
            slowest_p2p_type = P2PType.CROSS_DEVICE
        if slowest_p2p_type is not None:
            for block in self.blocks:
                if block.is_superset and block.p2p_type == slowest_p2p_type:
                    for i in block.layer_range:
                        layer_name = self.network.get_layer(i).name
                        layer = new_graph.get_layer(layer_name)
                        layer.attrs["in_slowest_block"] = True

        for i in range(self.network.num_outputs):
            trt_output = self.network.get_output(i)
            output = self.graph.get_output(trt_output.name)
            if output.producer is not None and output.producer.index in self.layer_to_block and self.layer_to_block[
                    output.producer.index].is_subset:
                continue
            if trt_output.is_shape_tensor:
                new_output = new_graph.add_output_shape(trt_output)
            else:
                new_output = new_graph.add_output(trt_output)
            sharded_io = False
            for pattern in self.sharded_io_allowlist:
                if re.match(pattern, new_output.name):
                    sharded_io = True
                    break
            if not sharded_io:
                new_output.producer.attrs["is_replicated"] = True

        for input in new_graph.inputs:
            input_name = input.name
            sharded_io = False
            for pattern in self.sharded_io_allowlist:
                if re.match(pattern, input_name):
                    sharded_io = True
                    break
            if not sharded_io:
                input.attrs["is_replicated"] = True
            for pattern, repl in self.same_spec_io.items():
                if re.match(pattern, input_name):
                    output_name = re.sub(pattern, repl, input_name)
                    output = new_graph.get_output(output_name)
                    if output is not None:
                        if "same_spec_id" in input.attrs:
                            same_spec_id = input.attrs["same_spec_id"]
                        else:
                            same_spec_id = same_spec_count
                            same_spec_count += 1
                            input.attrs["same_spec_id"] = same_spec_id
                        output.attrs["same_spec_id"] = same_spec_id
                        if math.prod(self.graph.get_input(
                                input_name).shape) < math.prod(
                                    self.graph.get_output(output_name).shape):
                            input.attrs["no_memory_footprint"] = True
                        else:
                            output.attrs["no_memory_footprint"] = True

        return new_graph, shape_mapping

    def enrich_shape_info(self, shape_mapping):
        shapes = self.shape_info.shapes.copy()
        max_shapes = self.shape_info.max_shapes.copy()
        values = self.shape_info.values.copy()
        shape_layers = self.shape_info.shape_layers
        for from_name, to_name in shape_mapping.items():
            if to_name in shapes:
                shapes[from_name] = shapes[to_name]
            if to_name in max_shapes:
                max_shapes[from_name] = max_shapes[to_name]
            if to_name in values:
                values[from_name] = values[to_name]
        shape_info = ShapeInfo(shapes, values, shape_layers, max_shapes)
        return shape_info

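    # Comment added for clarity: simplify_graph() expects infer_shapes() to
    # have been called first (it reads self.num_micro_batches and
    # self.shape_info). It returns (None, None) when the number of backbone
    # blocks is not divisible by num_stages; otherwise it returns the
    # simplified PipelineGraph together with a GraphConfig describing the
    # stage meshes and the mappings back to the full network.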
    def simplify_graph(
            self, phy_mesh: PhysicalDeviceMesh, num_stages: int,
            num_devices_per_stage: int) -> Tuple[PipelineGraph, GraphConfig]:
        num_blocks = len(self.backbone_blocks)
        if num_blocks % num_stages != 0:
            return None, None
        graph_config = GraphConfig()
        graph_config.num_micro_batches = self.num_micro_batches
        graph_config.num_blocks = num_blocks
        graph_config.num_stages = num_stages
        graph_config.phy_mesh = phy_mesh
        stage_phy_meshes = phy_mesh.split_pipeline_meshes(
            num_stages, num_devices_per_stage)
        graph_config.stage_phy_meshes = stage_phy_meshes
        with silent_trt_logger():
            self.clean_blocks()
            self.mark_p2p_type(phy_mesh, stage_phy_meshes, graph_config)
            graph_config.graph_mapping = self.get_graph_mapping()
            new_graph, shape_mapping = self.create_simplified_graph(
                graph_config)
            shape_info = self.enrich_shape_info(shape_mapping)
            new_graph.assign_shapes(shape_info)
            return new_graph, graph_config

    def get_graph_mapping_for_shape(self):
        layer_mapping = {}
        tensor_mapping = {}
        for block_list in self.blocks_by_edge_hash.values():
            head_block = block_list[0]
            for block in block_list[1:]:
                for from_layer_id, to_layer_id in zip(block.layer_range,
                                                      head_block.layer_range):
                    from_layer = self.network.get_layer(from_layer_id)
                    to_layer = self.network.get_layer(to_layer_id)
                    layer_mapping[from_layer.name] = to_layer.name
                    for i in range(from_layer.num_outputs):
                        tensor_mapping[from_layer.get_output(
                            i).name] = to_layer.get_output(i).name
        return layer_mapping, tensor_mapping

    def create_simplified_graph_for_shape(self):
        new_graph = PipelineGraph.create_graph()

        for i in range(self.network.num_inputs):
            trt_input = self.network.get_input(i)
            new_graph.add_input(trt_input)

        head_blocks = {}
        removed_blocks = set()
        removed_layers = set()
        for block_list in self.blocks_by_edge_hash.values():
            head_block = block_list[0]
            head_blocks[head_block.type_id] = head_block
            for block in block_list[1:]:
                removed_blocks.add(id(block))
                for layer_index in block.layer_range:
                    removed_layers.add(layer_index)

        for block in self.blocks:
            if not id(block) in removed_blocks:
                input_mapping = block.get_input_mapping(head_blocks)
                for i in block.sorted_layer_ids:
                    layer = self.network.get_layer(i)
                    new_graph.add_layer(
                        layer,
                        input_mapping=input_mapping,
                    )

        for i in range(self.network.num_outputs):
            trt_output = self.network.get_output(i)
            output = self.graph.get_output(trt_output.name)
            if output.producer is not None and output.producer.index in removed_layers:
                continue
            if trt_output.is_shape_tensor:
                new_graph.add_output_shape(trt_output)
            else:
                new_graph.add_output(trt_output)

        return new_graph

    def get_full_shape_info(self, num_micro_batches):
        layer_mapping, tensor_mapping = self.graph_mapping_for_shape
        optimization_profiles = self.llm_network._generate_optimization_profiles(
        )
        if len(optimization_profiles) > 0:
            optimization_profile = optimization_profiles[-1]
        else:
            optimization_profile = None
        shape_info = get_shape_info(self.graph_for_shape.as_trt(),
                                    optimization_profile)
        max_shape_info = get_shape_info(self.graph_for_shape.as_trt(),
                                        optimization_profile,
                                        shape_type=ShapeType.MAX)
        shape_info.max_shapes = max_shape_info.shapes
        for removed_tensor_name, tensor_name in tensor_mapping.items():
            shape_info.shapes[removed_tensor_name] = shape_info.shapes[
                tensor_name]
            shape_info.max_shapes[removed_tensor_name] = shape_info.max_shapes[
                tensor_name]
            if tensor_name in shape_info.values:
                shape_info.values[removed_tensor_name] = shape_info.values[
                    tensor_name]
        for removed_layer_name, layer_name in layer_mapping.items():
            if layer_name in shape_info.shape_layers:
                shape_info.shape_layers.add(removed_layer_name)
        return shape_info

    def init_layer_hash(self):
        with silent_trt_logger():
            optimization_profiles = self.llm_network._generate_optimization_profiles(
            )
            if len(optimization_profiles) > 0:
                optimization_profile = optimization_profiles[-1]
            else:
                optimization_profile = None
            shape_info = get_shape_info(self.network, optimization_profile)
        dtypes = {tensor.name: tensor.dtype for tensor in self.graph.tensors}
        for layer in self.graph.layers:
            layer_hash = get_cache_key(
                layer.as_trt(),
                shape_info.shapes,
                shape_info.values,
                dtypes,
            )
            layer.attrs["hash"] = layer_hash
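
# Illustrative usage sketch (comment added; not part of the original module).
# It assumes a built tensorrt_llm Network `network`, an AutoParallelConfig
# `config`, and a PhysicalDeviceMesh `phy_mesh` obtained elsewhere; the stage
# and device counts below are placeholders.
#
#   simplifier = Simplifier(network, config)
#   for num_micro_batches in simplifier.list_all_num_micro_batches():
#       simplifier.infer_shapes(num_micro_batches)
#       simplified_graph, graph_config = simplifier.simplify_graph(
#           phy_mesh, num_stages=2, num_devices_per_stage=4)
#       if simplified_graph is None:
#           continue  # backbone block count not divisible by num_stages
#       # hand simplified_graph and graph_config to the auto-parallel solver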