import contextlib
import copy
import itertools
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, ClassVar, Dict, List, Sequence, Set, Tuple, Union

import numpy as np
import tensorrt as trt
import torch
from filelock import FileLock

from tensorrt_llm import serialization
from tensorrt_llm._utils import (str_dtype_to_trt, trt_dtype_to_np,
                                 trt_dtype_to_torch)
from tensorrt_llm.functional import AllReduceParams, create_allreduce_plugin
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.network import (PluginInfo, delete_plugin_info, get_np_weight,
                                  get_plugin_info, set_plugin_info)
from tensorrt_llm.plugin import TRT_LLM_PLUGIN_NAMESPACE, init_all_reduce_helper
from tensorrt_llm.plugin.plugin import CustomAllReduceHelper
from tensorrt_llm.version import __version__

from .config import AutoParallelConfig
from .device_mesh import LogicalDeviceMesh
from .pipeline_graph import Layer, PipelineGraph, Tensor
from .shape_info import (ShapeInfo, get_per_layer_graph, get_shape_layers,
                         infer_per_layer_shapes)
from .simplifier import GraphConfig, GraphMapping, Simplifier, StageType
from .tensor_parallel.comm_spec import CommSpec
from .tensor_parallel.plugin_nodes.gpt_attention_node import (
    GPTAttentionPlugin, IdxEntry, IdxEntryParser)
from .tensor_parallel.sharding_spec import ShardingSpec, get_sharding_sequence
from .tensor_parallel.sharding_strategy import ShardingStrategy
from .utils import (get_updated_plugin, to_base_class_layer, to_subclass_layer,
                    to_trt_weights)

default_int_dtype = trt.int64

# These dataclasses are used in ParallelConfig serialization. If other classes
# need to be serialized, please add them to this list.
BASE_AUTOPP_CLASSES = {
    "tensorrt_llm.auto_parallel.parallelization": ["ParallelConfig"],
    "tensorrt_llm.auto_parallel.config": ["AutoParallelConfig", "CostModel"],
    "tensorrt_llm.auto_parallel.simplifier": ["GraphConfig", "StageType"]
}

@dataclass
class ParallelConfig:
    VERSION: ClassVar[str] = __version__

    version: str = VERSION
    network_hash: str = None
    auto_parallel_config: AutoParallelConfig = None
    graph_config: GraphConfig = None
    lmesh: LogicalDeviceMesh = None
    cost: float = None
    graph_strategy: Dict[str, ShardingStrategy] = None
    stage_type: StageType = None

    def save(self, filename):
        with open(filename, 'wb') as file:
            serialization.dump(self, file)

    @staticmethod
    def from_file(filename) -> "ParallelConfig":
        with open(filename, "rb") as file:
            return serialization.load(file,
                                      approved_imports=BASE_AUTOPP_CLASSES)

    def print_graph_strategy(self, file=None):
        for index, (node_name,
                    strategy) in enumerate(self.graph_strategy.items()):
            print(f'\n[{index}]: node_name = {node_name}', file=file)
            strategy.print_strategy(best_resharding_cost_only=True, file=file)

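# Illustrative usage of ParallelConfig above (not part of the library; the
# path is hypothetical). Deserialization is restricted to BASE_AUTOPP_CLASSES,
# so any new class stored inside a ParallelConfig must also be registered
# there:
#
#     config = ParallelConfig(network_hash="...", cost=1.0)
#     config.save("/tmp/parallel_config.pkl")
#     restored = ParallelConfig.from_file("/tmp/parallel_config.pkl")
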
def desimplify_strategy(
    graph: PipelineGraph,
    graph_strategy: Dict[str, ShardingStrategy],
    graph_mapping: GraphMapping,
):
    for strategy in graph_strategy.values():
        for name, commspec in list(strategy.communication_actions.items()):
            strategy.communication_actions[name] = [commspec]
            strategy.sharding_specs[
                f"{name}_after_comm"] = strategy.sharding_specs[name]

    # insert same spec layers' communication actions after
    # their producers' communication actions
    same_spec_layer_mapping = graph_mapping.same_spec_layer_mapping
    for same_spec_layer_name in same_spec_layer_mapping.keys():
        same_spec_strategy = graph_strategy[same_spec_layer_name]
        same_spec_commspecs = same_spec_strategy.best_resharding_cost[0][0][1]
        if len(same_spec_commspecs) == 0:
            continue
        output_name = same_spec_layer_name[:-len("_same_spec")]
        output = graph.get_tensor(output_name)
        layer_name = output.producer.name
        output_index = output.output_index
        strategy = graph_strategy[layer_name]
        commspecs = strategy.communication_actions.get(f"output{output_index}",
                                                       [])
        commspecs.extend(same_spec_commspecs)
        strategy.communication_actions[f"output{output_index}"] = commspecs
        strategy.sharding_specs[
            f"output{output_index}_after_comm"] = same_spec_strategy.sharding_specs[
                "output0"]

    layer_mapping = graph_mapping.layer_mapping
    for removed_layer_name, layer_name in layer_mapping.items():
        if layer_name in graph_strategy:
            strategy = copy.copy(graph_strategy[layer_name])
            layer = graph.get_layer(removed_layer_name)
            if layer is not None:
                strategy.node_names = strategy.node_names.copy()
                for index, name in list(strategy.node_names.items()):
                    input = layer.get_input(index)
                    node_name = input.name if input.producer is None else input.producer.name
                    strategy.node_names[index] = node_name
            graph_strategy[removed_layer_name] = strategy

@dataclass
class SplitInfo:
    input_dim: Union[int, trt.ITensor]
    partition: int

    def __deepcopy__(self, memo) -> "SplitInfo":
        return SplitInfo(self.input_dim, self.partition)


@dataclass
class TensorInfo:
    name: str = None
    split_infos: Dict[int, SplitInfo] = field(default_factory=dict)

    def set_split_info(self, dim, split_info):
        self.split_infos[dim] = split_info

    def __deepcopy__(self, memo) -> "TensorInfo":
        return TensorInfo(self.name, copy.deepcopy(self.split_infos))


@dataclass
class TensorContext:
    info_by_device: Dict[int, TensorInfo] = field(default_factory=dict)
    device_dims_for_shape: Set[int] = field(default_factory=set)

    def update_name_mapping(self, device_id, new_name):
        if device_id not in self.info_by_device:
            self.info_by_device[device_id] = TensorInfo()
        self.info_by_device[device_id].name = new_name

    def set_split_info(self, device_id, dim, split_info):
        if device_id not in self.info_by_device:
            self.info_by_device[device_id] = TensorInfo()
        self.info_by_device[device_id].set_split_info(dim, split_info)

    def set_split_infos(self, device_id, split_infos: Dict[int, SplitInfo]):
        if device_id not in self.info_by_device:
            self.info_by_device[device_id] = TensorInfo()
        self.info_by_device[device_id].split_infos = split_infos

    def __deepcopy__(self, memo) -> "TensorContext":
        return TensorContext(copy.deepcopy(self.info_by_device),
                             set(self.device_dims_for_shape))


@dataclass
class LayerUpdate:
    updated_attrs: Dict[str, Any] = field(default_factory=dict)
    updated_inputs: Dict[int, trt.ITensor] = field(default_factory=dict)
    split_info_updated: bool = False

    @staticmethod
    def none() -> "LayerUpdate":
        return LayerUpdate()

@dataclass
class GraphContext:
    tensor_contexts: Dict[str, TensorContext] = field(default_factory=dict)

    def get_name(self, tensor_name, device_id):
        if tensor_name not in self.tensor_contexts:
            return None
        if device_id not in self.tensor_contexts[tensor_name].info_by_device:
            return None
        return self.tensor_contexts[tensor_name].info_by_device[device_id].name

    def update_name_mapping(self, tensor_name, device_id, new_name):
        if tensor_name not in self.tensor_contexts:
            self.tensor_contexts[tensor_name] = TensorContext()
        self.tensor_contexts[tensor_name].update_name_mapping(
            device_id, new_name)

    def get_name_mapping(self, device_id, prefix: str) -> Dict[str, str]:
        name_mapping = {}
        for tensor_name in self.tensor_contexts.keys():
            new_name = self.get_name(tensor_name, device_id)
            if new_name is not None:
                name_mapping[f"{prefix}{tensor_name}"] = new_name
        return name_mapping

    def add_device_dims_for_shape(self, tensor_name: str,
                                  device_dims: Sequence[int]):
        if tensor_name not in self.tensor_contexts:
            self.tensor_contexts[tensor_name] = TensorContext()
        self.tensor_contexts[tensor_name].device_dims_for_shape.update(
            device_dims)

    def get_device_dims_for_shape(self, tensor_name: str):
        if tensor_name not in self.tensor_contexts:
            return set()
        return self.tensor_contexts[tensor_name].device_dims_for_shape

    def get_split_infos(self, tensor_name, device_id):
        if tensor_name not in self.tensor_contexts:
            return None
        if device_id not in self.tensor_contexts[tensor_name].info_by_device:
            return None
        return self.tensor_contexts[tensor_name].info_by_device[
            device_id].split_infos

    def set_split_info(self, tensor_name, device_id, dim, split_info):
        if tensor_name not in self.tensor_contexts:
            self.tensor_contexts[tensor_name] = TensorContext()
        self.tensor_contexts[tensor_name].set_split_info(
            device_id, dim, split_info)

    def set_split_infos(self, tensor_name, device_id,
                        split_infos: Dict[int, SplitInfo]):
        if tensor_name not in self.tensor_contexts:
            self.tensor_contexts[tensor_name] = TensorContext()
        self.tensor_contexts[tensor_name].set_split_infos(
            device_id, split_infos)

    def update_layer_context(self, wrapped_layer: Layer,
                             layer_update: LayerUpdate,
                             local_context: "GraphContext", device_id: int,
                             device_ids: np.ndarray,
                             sharding_specs: Dict[str, ShardingSpec]):
        layer = wrapped_layer.as_trt()
        for i in range(layer.num_outputs):
            output = layer.get_output(i)
            new_name = local_context.get_name(output.name, device_id)
            if new_name is not None:
                self.update_name_mapping(output.name, device_id, new_name)
        if layer_update.split_info_updated:
            for i in range(layer.num_outputs):
                output = layer.get_output(i)
                split_infos = local_context.get_split_infos(
                    output.name, device_id)
                if split_infos is not None:
                    self.set_split_infos(output.name, device_id, split_infos)
            return
        split_info_by_device_dim = {}
        for i in range(layer.num_inputs):
            input = layer.get_input(i)
            if input is None:
                continue
            sharding_spec = sharding_specs[f"input{i}"]
            split_infos = local_context.get_split_infos(input.name, device_id)
            if split_infos is None:
                continue
            for dim, split_info in split_infos.items():
                device_dim = tuple(sharding_spec.dim_partition_dict[dim])
                split_info_by_device_dim[device_dim] = split_info
        for i in range(layer.num_outputs):
            output = layer.get_output(i)
            sharding_spec = sharding_specs[f"output{i}"]
            for dim, device_dim in sharding_spec.dim_partition_dict.items():
                split_info = split_info_by_device_dim.get(tuple(device_dim))
                if split_info is None:
                    if device_dim == [0, 1] or device_dim == [1, 0]:
                        if (0, ) in split_info_by_device_dim and (
                                1, ) in split_info_by_device_dim:
                            split_info = SplitInfo(
                                split_info_by_device_dim[(0, )].input_dim *
                                split_info_by_device_dim[(1, )].input_dim,
                                split_info_by_device_dim[(0, )].partition *
                                split_info_by_device_dim[(1, )].partition,
                            )
                assert split_info is not None
                partition = get_partition(device_dim, device_ids)
                if split_info.input_dim != output.shape[dim]:
                    assert output.shape[
                        dim] > 0 and output.shape[dim] % partition == 0
                    output_split_info = SplitInfo(output.shape[dim], partition)
                    self.set_split_info(output.name, device_id, dim,
                                        output_split_info)

    def get_local_context(self, layer: trt.ILayer) -> "GraphContext":
        local_context = GraphContext()
        for i in range(layer.num_inputs):
            input = layer.get_input(i)
            if input is None:
                continue
            local_context.tensor_contexts[input.name] = copy.deepcopy(
                self.tensor_contexts[input.name])
        return local_context

    def get_local_context_for_output(self,
                                     output: trt.ITensor) -> "GraphContext":
        local_context = GraphContext()
        local_context.tensor_contexts[output.name] = copy.deepcopy(
            self.tensor_contexts[output.name])
        return local_context

    def merge_context(self, context: "GraphContext"):
        self.tensor_contexts.update(context.tensor_contexts)

@dataclass
class ShardContext:
    graph_context: GraphContext
    layer: Layer
    nditer: np.nditer
    device_ids: np.ndarray
    strategy: ShardingStrategy

def get_partition(device_dim, device_ids):
    if device_dim == [0]:
        partition = device_ids.shape[0]
    elif device_dim == [1]:
        partition = device_ids.shape[1]
    else:
        assert device_dim == [0, 1] or device_dim == [1, 0]
        partition = device_ids.size
    return partition


def get_index(device_dim, iter):
    if device_dim == [0]:
        index = iter.multi_index[0]
    elif device_dim == [1]:
        index = iter.multi_index[1]
    else:
        assert device_dim == [0, 1] or device_dim == [1, 0]
        index = iter.iterindex
    return index

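# Illustration for get_partition/get_index above (assumed 2 x 4 mesh):
# with device_ids = np.arange(8).reshape(2, 4),
#   get_partition([0], device_ids) == 2      # rows of the mesh
#   get_partition([1], device_ids) == 4      # columns of the mesh
#   get_partition([0, 1], device_ids) == 8   # the whole mesh
# and get_index returns the matching coordinate of the current np.nditer
# position (the flat iterindex when both mesh axes are used).
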
def get_full_sharding_spec(sharding_spec):
    return ShardingSpec(sharding_spec.device_mesh,
                        sharding_spec.data_type_size,
                        sharding_spec.entire_shape,
                        sharding_spec.max_entire_shape,
                        sharding_spec.raw_shape,
                        dim_partition_dict={})


def get_comm_action_sequence(from_sharding_spec, to_sharding_spec):
    comm_action_sequence = from_sharding_spec.device_mesh.shape_consistency_manager.shape_consistency(
        from_sharding_spec, to_sharding_spec)[1]
    # TODO: should be merged by shape_consistency
    # Two consecutive all_gathers on the same gather dim are merged here into
    # a single all_gather over both mesh axes.
    if len(comm_action_sequence) == 2:
        if comm_action_sequence[0].comm_pattern == comm_action_sequence[
                1].comm_pattern == "all_gather":
            if comm_action_sequence[0].gather_dim == comm_action_sequence[
                    1].gather_dim:
                comm_action_sequence = [
                    CommSpec(
                        comm_action_sequence[0].comm_pattern,
                        comm_action_sequence[0].sharding_spec,
                        comm_action_sequence[0].gather_dim,
                        comm_action_sequence[0].shard_dim, [[
                            *comm_action_sequence[0].logical_process_axis[0],
                            *comm_action_sequence[1].logical_process_axis[0]
                        ]], comm_action_sequence[0].mix_gather,
                        comm_action_sequence[0].forward_only)
                ]
                assert len(comm_action_sequence[0].logical_process_axis[0]) <= 2
    assert len(comm_action_sequence) <= 1
    return comm_action_sequence

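# GraphGroup below abstracts "one logical graph, N per-device graphs". Judging
# from from_graph, the concrete subclasses differ in where sharded layers
# live: DistributedGraphGroup builds one TensorRT network per device, while
# PrefixedGraphGroup (debug mode, defined later in this file) keeps every
# device's layers in a single network under per-device name prefixes.
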
class GraphGroup(ABC):

    @staticmethod
    def from_graph(
        graph: PipelineGraph,
        config: ParallelConfig,
        auto_parallel_config: AutoParallelConfig,
    ) -> "GraphGroup":
        if auto_parallel_config.debug_mode:
            return PrefixedGraphGroup(graph, config, auto_parallel_config)
        else:
            return DistributedGraphGroup(graph, config, auto_parallel_config)

    @property
    @abstractmethod
    def auto_parallel_config(self) -> AutoParallelConfig:
        ...

    @abstractmethod
    def add_input(self, tensor, device_ids, strategy: ShardingStrategy):
        ...

    @abstractmethod
    def add_layer(self, layer, device_ids, strategy: ShardingStrategy):
        ...

    @abstractmethod
    def add_output(self, tensor, device_ids, sharding_spec: ShardingSpec):
        ...

    @abstractmethod
    def get_network(self, device_id) -> trt.INetworkDefinition:
        ...

    @abstractmethod
    def get_graph(self, device_id) -> PipelineGraph:
        ...

    @property
    @abstractmethod
    def full_graph(self) -> PipelineGraph:
        ...

    @abstractmethod
    def get_prefix(self, device_id) -> str:
        ...

    @abstractmethod
    def get_shapes(self, device_id) -> Dict[str, Tuple[int, ...]]:
        ...

    @abstractmethod
    def get_values(self, device_id) -> Dict[str, List[int]]:
        ...

    @abstractmethod
    def add_all_reduce_layer(self, context: GraphContext, input_name,
                             output_name, device_ids, to_reduce_tensors):
        ...

    @abstractmethod
    def add_all_gather_layer(self, context: GraphContext, input_name,
                             output_name, device_ids, to_gather_tensors):
        ...

    @abstractmethod
    def register_layer(self,
                       layer,
                       base_name,
                       input_name,
                       output_name=None,
                       device_id=None,
                       keep_tensor_name=False) -> Layer:
        ...

    def get_tensor(self, context: GraphContext, tensor_name: str,
                   device_id: int) -> Tensor:
        name = context.get_name(tensor_name, device_id)
        return self.get_graph(device_id).get_tensor(name)

    def add_comm(self,
                 context: GraphContext,
                 input_name,
                 device_ids,
                 commspec,
                 output_name=None,
                 is_singleton=False):
        remove_index = []
        for i, device_dim in enumerate(commspec.logical_process_axis):
            partition = get_partition(device_dim, device_ids)
            if partition == 1:
                remove_index.append(i)
        if len(remove_index) > 0:
            if commspec.comm_pattern in ["all_gather", "all_to_all"]:
                commspec.gather_dim = [
                    dim for i, dim in enumerate(commspec.gather_dim)
                    if i not in remove_index
                ]
            if commspec.comm_pattern in [
                    "split", "reduce_scatter", "all_to_all"
            ]:
                commspec.shard_dim = [
                    dim for i, dim in enumerate(commspec.shard_dim)
                    if i not in remove_index
                ]
            commspec.logical_process_axis = [
                dim for i, dim in enumerate(commspec.logical_process_axis)
                if i not in remove_index
            ]
        flatten_device_dim = list(
            itertools.chain.from_iterable(commspec.logical_process_axis))
        if flatten_device_dim == []:
            return
        if flatten_device_dim == [0, 1] or flatten_device_dim == [1, 0]:
            self._add_comm(context, input_name, device_ids, commspec,
                           output_name, is_singleton)
        elif flatten_device_dim == [0]:
            for i in range(device_ids.shape[1]):
                self._add_comm(context, input_name, device_ids[:, i:i + 1],
                               commspec, output_name, is_singleton)
        elif flatten_device_dim == [1]:
            for i in range(device_ids.shape[0]):
                self._add_comm(context, input_name, device_ids[i:i + 1, :],
                               commspec, output_name, is_singleton)
        else:
            raise RuntimeError(
                f"Invalid flatten device_dim: {flatten_device_dim}")

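    # Note on add_comm above (illustrative, assumed 2 x 2 mesh): a commspec
    # over mesh axis [0] only is replayed once per column, each time on the
    # 2 x 1 sub-mesh device_ids[:, i:i + 1], so the communication group always
    # matches the devices that actually hold shards of the same slice; mesh
    # axes with partition == 1 are dropped from the commspec beforehand.
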
    def _add_comm(self,
                  context: GraphContext,
                  input_name,
                  device_ids,
                  commspec,
                  output_name=None,
                  is_singleton=False):
        input_tensors = [
            self.get_tensor(context, input_name, device_id.item())
            for device_id in np.nditer(device_ids)
        ]
        comm_pattern = commspec.comm_pattern
        if comm_pattern == "split":
            self.add_split(context, input_name, output_name, device_ids,
                           commspec.shard_dim, commspec.logical_process_axis)
        elif comm_pattern == "all_gather":
            self.add_all_gather(context, input_name, output_name, device_ids,
                                commspec.gather_dim,
                                commspec.logical_process_axis, is_singleton)
        elif comm_pattern == "all_reduce":
            self.add_all_reduce(context, input_name, output_name, device_ids)
        elif comm_pattern == "reduce_scatter":
            self.add_reduce_scatter(context, input_name, output_name,
                                    device_ids, commspec.shard_dim,
                                    commspec.logical_process_axis)
        elif comm_pattern == "all_to_all":
            self.add_all_to_all(context, input_name, output_name, device_ids,
                                commspec.gather_dim, commspec.shard_dim,
                                commspec.logical_process_axis)
        else:
            raise NotImplementedError
        output_tensors = [
            self.get_tensor(context, input_name, device_id.item())
            for device_id in np.nditer(device_ids)
        ]
        for input_tensor, output_tensor in zip(input_tensors, output_tensors):
            if input_tensor.dtype != output_tensor.dtype:
                raise ValueError(
                    f"Input tensor and output tensor should have the same dtype for communication layers, "
                    f"input dtype is {input_tensor.dtype} for {input_tensor.name}, "
                    f"but output dtype is {output_tensor.dtype} for {output_tensor.name}"
                )

    def add_all_reduce(self, context: GraphContext, input_name, output_name,
                       device_ids):
        dtype = str_dtype_to_trt(self.full_graph._plugin_config.dtype)
        to_reduce_tensors = []
        for device_id in np.nditer(device_ids):
            device_id = device_id.item()
            layer_info = (input_name, output_name, device_id)
            network = self.get_network(device_id)
            input_tensor = self.get_tensor(context, input_name,
                                           device_id).as_trt()
            input_dtype = input_tensor.dtype
            if input_dtype != dtype:
                to_reduce_tensor = self.cast(
                    network,
                    input_tensor,
                    dtype,
                    layer_info,
                )
            else:
                to_reduce_tensor = input_tensor
            to_reduce_tensors.append(to_reduce_tensor)
        self.add_all_reduce_layer(context, input_name, output_name, device_ids,
                                  to_reduce_tensors)
        if input_dtype != dtype:
            for device_id in np.nditer(device_ids):
                device_id = device_id.item()
                layer_info = (input_name, output_name, device_id)
                network = self.get_network(device_id)
                input_tensor = self.get_tensor(
                    context,
                    input_name,
                    device_id,
                ).as_trt()
                output_tensor = self.cast(
                    network,
                    input_tensor,
                    input_dtype,
                    layer_info,
                )
                context.update_name_mapping(
                    input_name,
                    device_id,
                    output_tensor.name,
                )

    def add_reduce_scatter(self, context: GraphContext, input_name, output_name,
                           device_ids, shard_dims, device_dims):
        self.add_all_reduce(context, input_name, output_name, device_ids)
        self.add_split(context, input_name, output_name, device_ids, shard_dims,
                       device_dims)

    # TODO: use native all_to_all operation
    def add_all_to_all(self, context: GraphContext, input_name, output_name,
                       device_ids, gather_dims, shard_dims, device_dims):
        self.add_all_gather(context, input_name, output_name, device_ids,
                            gather_dims, device_dims)
        self.add_split(context, input_name, output_name, device_ids, shard_dims,
                       device_dims)

    def get_item(self, network, tensor, index, layer_info):
        get_item_layer = network.add_slice(tensor, [index], [1], [1])
        self.register_layer(get_item_layer, f"get_item{index}", *layer_info)
        return get_item_layer.get_output(0)

    def get_shape(self, network, tensor, layer_info):
        shape_layer = network.add_shape(tensor)
        self.register_layer(shape_layer, "shape", *layer_info)
        return shape_layer.get_output(0)

    def concat(self, network, tensors, layer_info):
        concat_layer = network.add_concatenation(tensors)
        self.register_layer(concat_layer, "concat", *layer_info)
        return concat_layer.get_output(0)

    def flatten(self, network, tensor, layer_info):
        shuffle_layer = network.add_shuffle(tensor)
        shuffle_layer.reshape_dims = [-1]
        shuffle_layer.zero_is_placeholder = False
        self.register_layer(shuffle_layer, "flatten", *layer_info)
        return shuffle_layer.get_output(0)

    def reshape(self, network, tensor, reshape_dims, layer_info):
        reshape_layer = network.add_shuffle(tensor)
        reshape_layer.set_input(1, reshape_dims)
        reshape_layer.zero_is_placeholder = False
        self.register_layer(reshape_layer, "reshape", *layer_info)
        return reshape_layer.get_output(0)

    def cast(self, network, tensor, dtype, layer_info):
        if tensor.dtype == dtype:
            return tensor
        cast_layer = network.add_cast(tensor, dtype)
        self.register_layer(cast_layer, "cast", *layer_info)
        return cast_layer.get_output(0)

    def const_int(self, network, name, value, layer_info):
        const_layer = network.add_constant(
            [1], np.array([value], dtype=trt_dtype_to_np(default_int_dtype)))
        self.register_layer(const_layer, name, *layer_info)
        return const_layer.get_output(0)

    def get_dim_size(self, network, tensor, dim, layer_info, shape_tensor=None):
        raw_shape = tensor.shape
        dim_size = raw_shape[dim]
        if dim_size != -1:
            return dim_size
        else:
            if shape_tensor is None:
                shape_tensor = self.get_shape(network, tensor, layer_info)
            return self.get_item(network, shape_tensor, dim, layer_info)

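    # Sketch of the slice emitted by add_split below (illustrative values):
    # splitting dim 1 of a [4, 8] tensor over a 2-way mesh axis gives each
    # device a slice with start = [0, index * 4], size = [4, 4] and stride =
    # [1, 1]; the split is recorded in the context as
    # SplitInfo(input_dim=8, partition=2) so later gathers can undo it.
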
    def add_split(self, context: GraphContext, input_name, output_name,
                  device_ids, shard_dims, device_dims):
        it = np.nditer(device_ids, flags=['multi_index'])
        for device_id in it:
            device_id = device_id.item()
            layer_info = (input_name, output_name, device_id)
            network = self.get_network(device_id)
            input_tensor = self.get_tensor(context, input_name,
                                           device_id).as_trt()
            raw_input_shape = input_tensor.shape
            start = []
            output_dims = []
            stride = []
            input_shape_tensor = self.get_shape(network, input_tensor,
                                                layer_info)
            for dim in range(len(raw_input_shape)):
                stride.append(1)
                if dim not in shard_dims:
                    start.append(0)
                    output_dims.append(
                        self.get_item(network, input_shape_tensor, dim,
                                      layer_info))
                else:
                    start.append(None)
                    output_dims.append(None)

            for dim, device_dim in zip(shard_dims, device_dims):
                partition = get_partition(device_dim, device_ids)
                index = get_index(device_dim, it)
                input_dim = raw_input_shape[dim]
                assert input_dim != -1
                assert input_dim % partition == 0
                quotient = input_dim // partition
                start[dim] = index * quotient
                output_dims[dim] = self.const_int(network, f"output_dim{dim}",
                                                  quotient, layer_info)
                context.set_split_info(input_name, device_id, dim,
                                       SplitInfo(input_dim, partition))
            output_dims_tensor = self.concat(network, output_dims, layer_info)
            split_layer = network.add_slice(input_tensor, start, [], stride)
            split_layer.set_input(2, output_dims_tensor)
            wrapped_layer = self.register_layer(split_layer, "split",
                                                *layer_info)
            wrapped_layer.attrs["strategy"] = get_sharding_sequence(
                len(raw_input_shape),
                shard_dims,
                device_dims,
            )

            output_tensor = split_layer.get_output(0)
            context.update_name_mapping(input_name, device_id,
                                        output_tensor.name)

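    # add_all_gather below first flattens every local shard and runs the
    # AllGather plugin, which concatenates shards partition-major. The result
    # is then viewed as [partitions..., local dims...], transposed so each
    # partition axis lands in front of its gathered dim, and reshaped back to
    # the full logical shape. Illustrative example: two [2, 3] shards gathered
    # on dim 1 become a flat [12] tensor, viewed as [2, 2, 3], transposed with
    # permutation [1, 0, 2], and reshaped to the full [2, 6].
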
    def add_all_gather(self,
                       context: GraphContext,
                       input_name,
                       output_name,
                       device_ids,
                       gather_dims,
                       device_dims,
                       is_singleton=False):
        to_gather_tensors = []
        for device_id in np.nditer(device_ids):
            device_id = device_id.item()
            layer_info = (input_name, output_name, device_id)
            network = self.get_network(device_id)
            input_tensor = self.get_tensor(context, input_name,
                                           device_id).as_trt()
            to_gather_tensor = self.flatten(network, input_tensor, layer_info)
            to_gather_tensors.append(to_gather_tensor)

        all_gather_layers = self.add_all_gather_layer(
            context,
            input_name,
            output_name,
            device_ids,
            to_gather_tensors,
        )

        if len(device_dims) == 1:
            gather_indices = [0]
        elif len(device_dims) == 2 and device_dims[0] == [1]:
            gather_indices = [1, 0]
        else:
            gather_indices = [0, 1]

        for device_id, all_gather_layer in zip(np.nditer(device_ids),
                                               all_gather_layers):
            device_id = device_id.item()
            layer_info = (input_name, output_name, device_id)
            network = self.get_network(device_id)
            input_tensor = self.get_tensor(context, input_name,
                                           device_id).as_trt()
            permutation = []
            gathered_dims = []
            output_dims = []
            partitions = []
            raw_input_shape = input_tensor.shape

            wrapped_layer = self.get_graph(device_id).get_layer(
                all_gather_layer.name)
            wrapped_layer.attrs["strategy"] = get_sharding_sequence(
                len(raw_input_shape),
                gather_dims,
                device_dims,
            )

            input_shape_layer = network.add_shape(input_tensor)
            self.register_layer(input_shape_layer, "input_shape", *layer_info)
            input_shape_tensor = input_shape_layer.get_output(0)
            split_infos = context.get_split_infos(input_name, device_id)
            for index in gather_indices:
                gather_dim = gather_dims[index]
                device_dim = device_dims[index]
                partition = get_partition(device_dim, device_ids)
                assert partition == split_infos[gather_dim].partition
                partitions.append(
                    self.const_int(network, f"partition_num{gather_dim}",
                                   partition, layer_info))
            for dim in range(len(raw_input_shape)):
                if dim in gather_dims:
                    gather_index = gather_dims.index(dim)
                    device_dim = device_dims[gather_index]
                    permutation.append(gather_indices.index(gather_index))
                permutation.append(dim + len(gather_dims))
                if dim not in split_infos:
                    output_dim_layer = network.add_slice(
                        input_shape_tensor, [dim], [1], [1])
                    self.register_layer(output_dim_layer, f"output_dim{dim}",
                                        *layer_info)
                    dim_tensor = output_dim_layer.get_output(0)
                    output_dims.append(dim_tensor)
                    gathered_dims.append(dim_tensor)
                else:
                    input_dim = split_infos[dim].input_dim
                    partition = split_infos[dim].partition
                    assert input_dim != -1
                    assert input_dim % partition == 0
                    quotient = input_dim // partition
                    output_dims.append(
                        self.const_int(network, f"output_dim{dim}", quotient,
                                       layer_info))
                    if dim in gather_dims:
                        gathered_dims.append(
                            self.const_int(network, f"gathered_dim{dim}",
                                           quotient * partition, layer_info))
                        del split_infos[dim]
                    else:
                        gathered_dims.append(output_dim_layer.get_output(0))

            reshape_dims_for_transpose_layer = network.add_concatenation(
                [*partitions, *output_dims])
            self.register_layer(reshape_dims_for_transpose_layer,
                                "reshape_dims_for_transpose", *layer_info)
            reshape_dims_tensor = reshape_dims_for_transpose_layer.get_output(0)
            transpose_layer = network.add_shuffle(
                all_gather_layer.get_output(0))
            transpose_layer.set_input(1, reshape_dims_tensor)
            transpose_layer.second_transpose = permutation
            transpose_layer.zero_is_placeholder = False
            self.register_layer(transpose_layer, "transpose", *layer_info)

            reshape_dims_for_reshape_layer = network.add_concatenation(
                gathered_dims)
            self.register_layer(reshape_dims_for_reshape_layer,
                                "reshape_dims_for_reshape", *layer_info)
            reshape_dims_tensor = reshape_dims_for_reshape_layer.get_output(0)
            output_tensor = self.reshape(
                network,
                transpose_layer.get_output(0),
                reshape_dims_tensor,
                layer_info,
            )
            context.update_name_mapping(input_name, device_id,
                                        output_tensor.name)
            if is_singleton:
                break

    def register_unfilled_weights(self, graph, layer):
        if (layer.name in self.full_graph._unfilled_weights
                and layer.name not in graph._unfilled_weights):
            weights, values = self.full_graph._unfilled_weights[layer.name]
            graph._register_unfilled_weights(
                layer.name,
                weights,
                values,
            )

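    # shard_constant below slices a constant's weights per device. Illustrative
    # example: an [8, 4] constant sharded on dim 0 over a 2-way mesh axis gives
    # device index 0 values[0:4, :] and device index 1 values[4:8, :], with the
    # layer's shape attribute rewritten to [4, 4].
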
    def shard_constant(self, context: ShardContext):
        sharding_spec = context.strategy.sharding_specs["output0"]
        shard_dims = sharding_spec.dim_partition_dict
        device_id = context.nditer.value.item()
        device_ids = context.device_ids
        layer = context.layer.as_trt()
        graph = self.get_graph(device_id)
        if len(shard_dims) == 0:
            self.register_unfilled_weights(graph, layer)
            return LayerUpdate(split_info_updated=True)
        flatten_device_dim = list(
            itertools.chain.from_iterable(shard_dims.values()))
        output_name = layer.get_output(0).name
        output_dtype = layer.get_output(0).dtype
        output_shape = layer.shape
        output_dims = []
        weight_index = []
        for dim in range(len(output_shape)):
            output_dim = output_shape[dim]
            if dim in shard_dims:
                device_dim = shard_dims[dim]
                partition = get_partition(device_dim, device_ids)
                index = get_index(device_dim, context.nditer)
                assert output_dim % partition == 0
                quotient = output_dim // partition
                output_dims.append(quotient)
                weight_index.append(
                    slice(index * quotient, (index + 1) * quotient))
                context.graph_context.set_split_info(
                    output_name, device_id, dim,
                    SplitInfo(output_dim, partition))
            else:
                output_dims.append(output_dim)
                weight_index.append(slice(None))
        if layer.name in self.full_graph._unfilled_weights:
            values = self.full_graph._unfilled_weights[layer.name][1]
        else:
            values = layer.weights
            if isinstance(values, trt.Weights):
                values = values.numpy()
        # TODO: remove this WAR after https://nvbugs/4359151 is fixed.
        if isinstance(values, trt.Weights):
            network = context.layer.graph.as_trt()
            values = get_np_weight(network, layer.name)
        if values is not None:
            values = values.reshape(layer.shape)
            assert values.size == np.prod(layer.shape)
            sharded_values = values[tuple(weight_index)]
            assert sharded_values.size * get_partition(
                flatten_device_dim, device_ids) == np.prod(layer.shape)
        else:
            sharded_values = None
        dtype = trt_dtype_to_np(output_dtype)
        sharded_weights = np.empty(tuple(output_dims), dtype)
        graph._register_unfilled_weights(
            f"device{device_id}_{layer.name}",
            sharded_weights,
            sharded_values,
        )
        sharded_weights = to_trt_weights(sharded_weights)
        return LayerUpdate(
            updated_attrs=dict(
                shape=trt.Dims(output_dims),
                weights=sharded_weights,
            ),
            split_info_updated=True,
        )

    def shard_fill(self, context: ShardContext):
        sharding_spec = context.strategy.sharding_specs["output0"]
        shard_dims = sharding_spec.dim_partition_dict
        if len(shard_dims) == 0:
            return LayerUpdate(split_info_updated=True)
        device_id = context.nditer.value.item()
        device_ids = context.device_ids
        layer = context.layer.as_trt()
        output_name = layer.get_output(0).name
        output_shape = layer.shape
        output_dims = []
        for dim in range(len(output_shape)):
            output_dim = output_shape[dim]
            if dim in shard_dims:
                device_dim = shard_dims[dim]
                partition = get_partition(device_dim, device_ids)
                assert output_dim % partition == 0
                quotient = output_dim // partition
                output_dims.append(quotient)
                context.graph_context.set_split_info(
                    output_name, device_id, dim,
                    SplitInfo(output_dim, partition))
            else:
                output_dims.append(output_dim)
        return LayerUpdate(
            updated_attrs=dict(shape=trt.Dims(output_dims), ),
            split_info_updated=True,
        )

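    # update_shape below patches the output of a shape layer whose input is
    # sharded: for each split dimension it substitutes the full (unsharded)
    # extent recorded in SplitInfo, so downstream shape arithmetic keeps
    # seeing global sizes rather than per-device ones.
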
    def update_shape(self, context: ShardContext):
        if not context.layer.is_shape_io:
            return
        layer = context.layer.as_trt()
        input_name = layer.get_input(0).name
        output_name = layer.get_output(0).name
        device_id = context.nditer.value.item()
        layer_info = (output_name, None, device_id)
        split_infos = context.graph_context.get_split_infos(
            input_name, device_id)
        if len(split_infos) == 0:
            return
        network = self.get_network(device_id)
        shape_tensor = self.get_tensor(context.graph_context, output_name,
                                       device_id).as_trt()
        output_dims = []
        for dim in range(len(context.layer.get_input(0).shape)):
            if dim not in split_infos:
                output_dim_layer = network.add_slice(shape_tensor, [dim], [1],
                                                     [1])
            else:
                input_dim = split_infos[dim].input_dim
                # Convert the TensorRT dtype to a NumPy one (as const_int
                # does); np.array cannot take a trt.DataType directly.
                output_dim_layer = network.add_constant(
                    [1],
                    np.array([input_dim],
                             dtype=trt_dtype_to_np(default_int_dtype)))
            self.register_layer(output_dim_layer, f"output_dim{dim}",
                                *layer_info)
            output_dims.append(output_dim_layer.get_output(0))
        new_shape_layer = network.add_concatenation(output_dims)
        self.register_layer(new_shape_layer, "new_shape", *layer_info)
        new_shape_tensor = new_shape_layer.get_output(0)
        context.graph_context.update_name_mapping(output_name, device_id,
                                                  new_shape_tensor.name)

    def shard_slice(self, context: ShardContext):
        sharding_spec = context.strategy.sharding_specs["output0"]
        shard_dims = sharding_spec.dim_partition_dict
        if len(shard_dims) == 0:
            return LayerUpdate.none()
        device_id = context.nditer.value.item()
        network = self.get_network(device_id)
        device_ids = context.device_ids
        layer = context.layer.as_trt()
        output_dims = []
        updated_attrs = {}
        updated_inputs = {}

        if layer.num_inputs >= 3:
            raw_output_shape = layer.get_output(0).shape
            input_name = layer.get_input(2).name
            layer_info = (input_name, layer.name, device_id)
            shape_tensor = self.get_tensor(context.graph_context, input_name,
                                           device_id).as_trt()
            for dim in range(len(raw_output_shape)):
                output_dim_layer = network.add_slice(shape_tensor, [dim], [1],
                                                     [1])
                self.register_layer(output_dim_layer, f"output_dim{dim}",
                                    *layer_info)
                if dim in shard_dims:
                    device_dim = shard_dims[dim]
                    partition = get_partition(device_dim, device_ids)
                    partition_num_tensor = self.const_int(
                        network, f"partition_num{dim}", partition, layer_info)
                    quotient_layer = network.add_elementwise(
                        output_dim_layer.get_output(0), partition_num_tensor,
                        trt.ElementWiseOperation.FLOOR_DIV)
                    self.register_layer(quotient_layer, f"quotient{dim}",
                                        *layer_info)
                    output_dim = self.cast(network,
                                           quotient_layer.get_output(0),
                                           default_int_dtype, layer_info)
                    output_dims.append(output_dim)
                else:
                    output_dims.append(output_dim_layer.get_output(0))
            output_dims_layer = network.add_concatenation(output_dims)
            self.register_layer(output_dims_layer, "output_dims", *layer_info)
            updated_inputs[2] = output_dims_layer.get_output(0)
        else:
            output_shape = layer.shape
            for dim in range(len(output_shape)):
                output_dim = output_shape[dim]
                assert output_dim != -1
                if dim in shard_dims:
                    device_dim = shard_dims[dim]
                    partition = get_partition(device_dim, device_ids)
                    assert output_dim % partition == 0
                    quotient = output_dim // partition
                    output_dims.append(quotient)
                else:
                    output_dims.append(output_dim)
            updated_attrs["shape"] = trt.Dims(output_dims)
        return LayerUpdate(updated_attrs, updated_inputs)

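    # shard_shuffle below divides the reshape specification instead of the
    # data: static reshape_dims entries are divided by the partition count,
    # while dynamic entries (a shape tensor on input 1) get a FLOOR_DIV
    # elementwise layer; -1 entries are left untouched since TensorRT infers
    # them.
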
    def shard_shuffle(self, context: ShardContext):
        sharding_spec = context.strategy.sharding_specs["output0"]
        shard_dims = sharding_spec.dim_partition_dict
        if len(shard_dims) == 0:
            return LayerUpdate.none()
        device_id = context.nditer.value.item()
        network = self.get_network(device_id)
        device_ids = context.device_ids
        layer = context.layer.as_trt()
        updated_attrs = {}
        updated_inputs = {}
        updated_reshape_dims = {}
        second_transpose = layer.second_transpose

        if layer.num_inputs >= 2:
            raw_output_shape = layer.get_output(0).shape
            input_name = layer.get_input(1).name
            layer_info = (input_name, layer.name, device_id)
            reshape_dims_tensor = self.get_tensor(context.graph_context,
                                                  input_name, device_id)
            reshape_dims = context.layer.get_input(1).value
            reshape_dims_tensor = reshape_dims_tensor.as_trt()
            for dim in range(len(raw_output_shape)):
                if second_transpose is not None:
                    reshape_dim = second_transpose[dim]
                else:
                    reshape_dim = dim
                output_dim_layer = network.add_slice(reshape_dims_tensor,
                                                     [reshape_dim], [1], [1])
                self.register_layer(output_dim_layer, f"output_dim{dim}",
                                    *layer_info)
                output_dim = reshape_dims[reshape_dim]
                if dim in shard_dims and output_dim != -1:
                    device_dim = shard_dims[dim]
                    partition = get_partition(device_dim, device_ids)
                    partition_num_tensor = self.const_int(
                        network, f"partition_num{dim}", partition, layer_info)
                    quotient_layer = network.add_elementwise(
                        output_dim_layer.get_output(0), partition_num_tensor,
                        trt.ElementWiseOperation.FLOOR_DIV)
                    self.register_layer(quotient_layer, f"quotient{dim}",
                                        *layer_info)
                    updated_reshape_dims[reshape_dim] = self.cast(
                        network,
                        quotient_layer.get_output(0),
                        default_int_dtype,
                        layer_info,
                    )
                else:
                    updated_reshape_dims[
                        reshape_dim] = output_dim_layer.get_output(0)
            updated_reshape_dims = list(
                map(lambda x: x[1], sorted(updated_reshape_dims.items())))
            reshape_dims_layer = network.add_concatenation(updated_reshape_dims)
            self.register_layer(reshape_dims_layer, "reshape_dims", *layer_info)
            updated_inputs[1] = reshape_dims_layer.get_output(0)
        else:
            reshape_dims = layer.reshape_dims
            # __len__ is called directly: a negative nbDims marks an
            # invalid/unset Dims, which len() would refuse to return.
            if reshape_dims.__len__() < 0:
                return LayerUpdate.none()
            for dim in range(len(reshape_dims)):
                if second_transpose is not None:
                    reshape_dim = second_transpose[dim]
                else:
                    reshape_dim = dim
                output_dim = reshape_dims[reshape_dim]
                if dim in shard_dims and output_dim != -1:
                    device_dim = shard_dims[dim]
                    partition = get_partition(device_dim, device_ids)
                    quotient = output_dim // partition
                    updated_reshape_dims[reshape_dim] = quotient
                else:
                    updated_reshape_dims[reshape_dim] = output_dim
            updated_reshape_dims = list(
                map(lambda x: x[1], sorted(updated_reshape_dims.items())))
            updated_attrs["reshape_dims"] = trt.Dims(updated_reshape_dims)
        return LayerUpdate(updated_attrs, updated_inputs)

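    # shard_gpt_attention below implements tensor parallelism for the
    # GPTAttention plugin by rewriting its fields rather than slicing data:
    # with 32 heads on a 4-way partition (illustrative numbers), each device
    # gets num_heads = 8, tp_size = 4 and its own tp_rank; num_kv_heads is
    # divided by its own partition and clamped to at least 1 so grouped- or
    # multi-query attention keeps one KV head per device.
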
    def shard_gpt_attention(self, context: ShardContext):
        layer = context.layer.as_trt()
        plugin_info = get_plugin_info(
            self.full_graph.as_trt(),
            layer.name,
        )
        parser = IdxEntryParser(plugin_info)
        head_dim = 1 if parser.remove_input_padding else 2
        sharding_spec = context.strategy.sharding_specs[
            f"input{parser.get_index(IdxEntry.QKV_TENSOR)}"]
        shard_dims = sharding_spec.dim_partition_dict
        if head_dim not in shard_dims:
            return LayerUpdate.none()
        device_id = context.nditer.value.item()
        network = self.get_network(device_id)
        device_ids = context.device_ids
        updated_attrs = {}
        updated_inputs = {}
        device_dim = shard_dims[head_dim]
        partition = get_partition(device_dim, device_ids)
        index = get_index(device_dim, context.nditer)
        if parser.is_entry_used(IdxEntry.K_TENSOR):
            kv_sharding_spec = context.strategy.sharding_specs[
                f"input{parser.get_index(IdxEntry.K_TENSOR)}"]
            kv_shard_dims = kv_sharding_spec.dim_partition_dict
            if head_dim in kv_shard_dims:
                kv_device_dim = kv_shard_dims[head_dim]
                kv_partition = get_partition(kv_device_dim, device_ids)
            else:
                kv_partition = 1
        else:
            kv_partition = 1
        num_heads = plugin_info.pfc_as_ndarray["num_heads"].copy()
        num_kv_heads = plugin_info.pfc_as_ndarray["num_kv_heads"].copy()
        tp_size = plugin_info.pfc_as_ndarray["tp_size"].copy()
        tp_rank = plugin_info.pfc_as_ndarray["tp_rank"].copy()
        num_kv_heads = np.maximum(num_kv_heads // kv_partition, 1)
        num_heads = np.maximum(num_heads // partition, 1)
        tp_size[0] = partition
        tp_rank[0] = index

        new_plugin, new_plugin_info = get_updated_plugin(
            plugin_info,
            dict(
                num_heads=num_heads,
                num_kv_heads=num_kv_heads,
                tp_size=tp_size,
                tp_rank=tp_rank,
            ))
        prefix = self.get_prefix(device_id)
        new_layer_name = f"{prefix}{layer.name}"
        set_plugin_info(network, new_layer_name, new_plugin_info)
        updated_attrs["plugin"] = new_plugin
        return LayerUpdate(updated_attrs, updated_inputs)

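    # shard_lookup below handles a vocabulary-sharded embedding (input1 split
    # on dim 0): only the Lookup plugin's "rank" field needs updating, since
    # the plugin appears to derive its local vocabulary offset from the rank
    # rather than from a sliced weight tensor.
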
    def shard_lookup(self, context: ShardContext):
        sharding_spec = context.strategy.sharding_specs["input1"]
        shard_dims = sharding_spec.dim_partition_dict
        if 0 not in shard_dims:
            return LayerUpdate.none()
        layer = context.layer.as_trt()
        plugin_info = get_plugin_info(
            self.full_graph.as_trt(),
            layer.name,
        )
        device_id = context.nditer.value.item()
        network = self.get_network(device_id)
        updated_attrs = {}
        device_dim = shard_dims[0]
        index = get_index(device_dim, context.nditer)
        rank = plugin_info.pfc_as_ndarray["rank"].copy()
        rank[0] = index

        new_plugin, new_plugin_info = get_updated_plugin(
            plugin_info, dict(rank=rank, ))
        prefix = self.get_prefix(device_id)
        new_layer_name = f"{prefix}{layer.name}"
        set_plugin_info(network, new_layer_name, new_plugin_info)
        updated_attrs["plugin"] = new_plugin
        return LayerUpdate(updated_attrs)


class GraphGroupBase(GraphGroup):

    def __init__(
        self,
        full_graph: PipelineGraph,
        config: ParallelConfig,
        auto_parallel_config: AutoParallelConfig,
    ) -> None:
        self._full_graph = full_graph
        self.config = config
        self._auto_parallel_config = auto_parallel_config
        self.infer_shape = auto_parallel_config.infer_shape
        self.global_context = GraphContext()
        self.shape_cache = {}
        self.suffix = 0
        self.current_block_id = -1

    @property
    def auto_parallel_config(self) -> AutoParallelConfig:
        return self._auto_parallel_config

    @property
    def full_graph(self) -> PipelineGraph:
        return self._full_graph

    def register_layer(self,
                       layer,
                       base_name,
                       input_name,
                       output_name=None,
                       device_id=None,
                       keep_tensor_name=False) -> Layer:
        layer_name = f"{base_name}_{input_name}"
        if device_id is not None:
            layer_name = f"{self.get_prefix(device_id)}{layer_name}"
        if output_name is not None:
            layer_name = f"{layer_name}_to_{output_name}"
        suffix = self.suffix
        self.suffix += 1
        layer_name = f"{layer_name}_{suffix}"
        if layer.type == trt.LayerType.PLUGIN_V2:
            network = self.get_network(device_id)
            plugin_info = get_plugin_info(network, layer.name)
            if plugin_info is not None:
                set_plugin_info(network, layer_name, plugin_info)
                delete_plugin_info(network, layer.name)
        layer.name = layer_name
        layer.metadata = layer.name
        if not keep_tensor_name:
            for i in range(layer.num_outputs):
                output_tensor = layer.get_output(i)
                # __len__ is called directly: a negative nbDims marks an
                # invalid shape, which len() would refuse to return.
                assert output_tensor.shape.__len__() >= 0
                output_tensor.name = f"{layer.name}_output_{i}"
        wrapped_layer = self.get_graph(device_id).register_layer(layer)
        if self.current_block_id != -1:
            wrapped_layer.attrs["block_id"] = self.current_block_id
        wrapped_layer.attrs["role"] = "helper"
        if self.infer_shape:
            infer_per_layer_shapes(
                layer,
                self.get_shapes(device_id),
                self.get_values(device_id),
                self.shape_cache,
                is_shape_io=True,
            )
            wrapped_layer.assign_shapes(
                self.get_shapes(device_id),
                self.get_values(device_id),
            )
        return wrapped_layer

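    # add_layer below is the per-layer driver: (1) replay the strategy's
    # resharding communication for each input in a layer-local context,
    # (2) for every device in the mesh, compute a LayerUpdate via the
    # layer-type-specific shard_* hook and clone the layer into that device's
    # graph with remapped inputs, (3) propagate name mappings and split infos
    # back into the global context, then (4) append the strategy's
    # post-communication actions on the outputs.
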
    def add_layer(self, wrapped_layer: Layer, device_ids,
                  strategy: ShardingStrategy):
        layer = wrapped_layer.as_trt()
        local_context = self.global_context.get_local_context(layer)
        self.current_block_id = wrapped_layer.attrs["block_id"]

        for i, input in enumerate(wrapped_layer.inputs):
            if input is None:
                continue
            if i not in strategy.best_resharding_cost:
                continue
            comm_action_sequence = strategy.best_resharding_cost[i][0][1]
            for commspec in comm_action_sequence:
                self.add_comm(local_context,
                              input.name,
                              device_ids,
                              commspec,
                              output_name=layer.name)

        it = np.nditer(device_ids, flags=['multi_index'])
        for device_id in it:
            device_id = device_id.item()

            layer_type = layer.type
            to_subclass_layer(layer)
            shard_context = ShardContext(
                local_context,
                wrapped_layer,
                it,
                device_ids,
                strategy,
            )
            if layer_type == trt.LayerType.CONSTANT:
                layer_update = self.shard_constant(shard_context)
            elif layer_type == trt.LayerType.FILL:
                layer_update = self.shard_fill(shard_context)
            elif layer_type == trt.LayerType.SLICE:
                layer_update = self.shard_slice(shard_context)
            elif layer_type == trt.LayerType.SHUFFLE:
                layer_update = self.shard_shuffle(shard_context)
            elif layer_type == trt.LayerType.PLUGIN_V2:
                if layer.plugin.plugin_type == "GPTAttention":
                    layer_update = self.shard_gpt_attention(shard_context)
                elif layer.plugin.plugin_type == "Lookup":
                    layer_update = self.shard_lookup(shard_context)
                else:
                    layer_update = LayerUpdate.none()
            else:
                layer_update = LayerUpdate.none()
            to_base_class_layer(layer)

            for i, updated_input in layer_update.updated_inputs.items():
                input_name = layer.get_input(i).name
                local_context.update_name_mapping(input_name, device_id,
                                                  updated_input.name)
                if layer.get_input(i).dtype != updated_input.dtype:
                    raise ValueError(
                        f"Input dtype mismatch for {layer.name}, "
                        f"expect {layer.get_input(i).dtype} for {input_name}, "
                        f"get {updated_input.dtype} for {updated_input.name}")

            prefix = self.get_prefix(device_id)
            new_wrapped_layer = self.get_graph(device_id).add_layer(
                layer,
                prefix=prefix,
                input_mapping=local_context.get_name_mapping(device_id,
                                                             prefix=prefix),
                updated_attrs=layer_update.updated_attrs,
            )
            new_wrapped_layer.attrs["strategy"] = strategy.name
            new_wrapped_layer.attrs["block_id"] = self.current_block_id
            new_layer = new_wrapped_layer.as_trt()

            if self.infer_shape:
                infer_per_layer_shapes(
                    new_layer,
                    self.get_shapes(device_id),
                    self.get_values(device_id),
                    self.shape_cache,
                    is_shape_io=wrapped_layer.is_shape_io,
                )
                new_wrapped_layer.assign_shapes(
                    self.get_shapes(device_id),
                    self.get_values(device_id),
                )

            for i in range(layer.num_outputs):
                output_tensor = new_layer.get_output(i)
                assert output_tensor.shape.__len__() >= 0
                local_context.update_name_mapping(
                    layer.get_output(i).name, device_id, output_tensor.name)

            if layer.type == trt.LayerType.SHAPE:
                self.update_shape(shard_context)

            self.global_context.update_layer_context(
                wrapped_layer,
                layer_update,
                local_context,
                device_id,
                device_ids,
                strategy.sharding_specs,
            )

        for i in range(layer.num_outputs):
            commspecs = strategy.communication_actions.get(f"output{i}")
            if commspecs is None:
                continue
            output = layer.get_output(i)
            for commspec in commspecs:
                self.add_comm(
                    self.global_context,
                    output.name,
                    device_ids,
                    commspec,
                )

        self.current_block_id = -1


class DistributedGraphGroup(GraphGroupBase):

    def __init__(
        self,
        full_graph: PipelineGraph,
        config: ParallelConfig,
        auto_parallel_config: AutoParallelConfig,
    ) -> None:
        super().__init__(full_graph, config, auto_parallel_config)
        self.graphs = {}
        self.io_tensor_shards = {}
        self.shapes_by_device = {}
        self.values_by_device = {}
        phy_mesh = config.graph_config.phy_mesh
        device_ids = phy_mesh.phy_devices_id
        for device_id in np.nditer(device_ids):
            device_id = device_id.item()
            graph = PipelineGraph.create_graph()
            graph._auto_parallel_config = {
                "io_shards": {},
                "mapping":
                Mapping(
                    world_size=device_ids.size,
                    rank=device_id,
                    gpus_per_node=device_ids.shape[1],
                    tp_size=device_ids.size // config.graph_config.num_stages,
                    pp_size=config.graph_config.num_stages,
                ),
            }
            self.graphs[device_id] = graph
            self.shapes_by_device[device_id] = {}
            self.values_by_device[device_id] = {}

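    # Shape inference cannot see through the communication plugins, so the
    # plugin layers below are registered under disable_infer_shape() and
    # their output shapes are filled in manually right afterwards (all_reduce
    # keeps the shape, all_gather multiplies and reduce_scatter divides the
    # flat length by the group size).
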
    @contextlib.contextmanager
    def disable_infer_shape(self):
        infer_shape = self.infer_shape
        self.infer_shape = False
        yield
        self.infer_shape = infer_shape

    def get_network(self, device_id) -> trt.INetworkDefinition:
        return self.graphs[device_id].as_trt()

    def get_graph(self, device_id) -> PipelineGraph:
        return self.graphs[device_id]

    def get_prefix(self, device_id) -> str:
        return ""

    def get_shapes(self, device_id) -> Dict[str, Tuple[int, ...]]:
        return self.shapes_by_device[device_id]

    def get_values(self, device_id) -> Dict[str, List[int]]:
        return self.values_by_device[device_id]

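    # This override implements reduce_scatter natively with the ReduceScatter
    # plugin instead of the base class's all_reduce + split: the input is
    # transposed so the scattered dims lead, flattened, reduce-scattered
    # across the group, then reshaped (and transposed back) into the local
    # shard.
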
def add_reduce_scatter(self, context: GraphContext, input_name, output_name,
|
||
device_ids, shard_dims, device_dims):
|
||
dtype = str_dtype_to_trt(self.full_graph._plugin_config.dtype)
|
||
it = np.nditer(device_ids, flags=['multi_index'])
|
||
for device_id in it:
|
||
device_id = device_id.item()
|
||
layer_info = (input_name, output_name, device_id)
|
||
network = self.get_network(device_id)
|
||
input_tensor = self.get_tensor(context, input_name,
|
||
device_id).as_trt()
|
||
raw_input_shape = input_tensor.shape
|
||
input_shape_tensor = self.get_shape(network, input_tensor,
|
||
layer_info)
|
||
if shard_dims != [0]:
|
||
permutation = list(range(len(raw_input_shape)))
|
||
for dim in shard_dims:
|
||
permutation.remove(dim)
|
||
permutation = shard_dims + permutation
|
||
                transpose_layer = network.add_shuffle(input_tensor)
                transpose_layer.second_transpose = permutation
                self.register_layer(transpose_layer, "input_transpose",
                                    *layer_info)
                input_tensor = transpose_layer.get_output(0)
            flatten_tensor = self.flatten(network, input_tensor, layer_info)
            input_dtype = flatten_tensor.dtype
            if input_dtype != dtype:
                to_reduce_tensor = self.cast(
                    network,
                    flatten_tensor,
                    dtype,
                    layer_info,
                )
            else:
                to_reduce_tensor = flatten_tensor

            reduce_scatter_plg_creator = trt.get_plugin_registry(
            ).get_plugin_creator('ReduceScatter', '1', TRT_LLM_PLUGIN_NAMESPACE)
            assert reduce_scatter_plg_creator is not None

            group = trt.PluginField(
                "group",
                np.ascontiguousarray(device_ids.reshape(-1).astype(np.int32)),
                trt.PluginFieldType.INT32)
            pf_type = trt.PluginField(
                "type_id", np.array([int(to_reduce_tensor.dtype)], np.int32),
                trt.PluginFieldType.INT32)

            pfc = trt.PluginFieldCollection([group, pf_type])
            rs_plug = reduce_scatter_plg_creator.create_plugin(
                "reduce_scatter", pfc)

            reduce_scatter_layer = network.add_plugin_v2([to_reduce_tensor],
                                                         rs_plug)
            plugin_info = PluginInfo(reduce_scatter_plg_creator,
                                     "reduce_scatter", pfc)
            set_plugin_info(network, reduce_scatter_layer.name, plugin_info)
            with self.disable_infer_shape():
                wrapped_tensor = self.register_layer(
                    reduce_scatter_layer,
                    "reduce_scatter",
                    *layer_info,
                ).get_output(0)
            reduce_scatter_tensor = reduce_scatter_layer.get_output(0)
            if self.infer_shape:
                shape = self.shapes_by_device[device_id][to_reduce_tensor.name]
                assert len(shape) == 1
                output_shape = (shape[0] // device_ids.size, )
                self.shapes_by_device[device_id][
                    reduce_scatter_tensor.name] = output_shape
                wrapped_tensor.shape = output_shape
            if input_dtype != dtype:
                reduce_scatter_tensor = self.cast(
                    network,
                    reduce_scatter_tensor,
                    input_dtype,
                    layer_info,
                )

            start = []
            output_dims = []
            stride = []
            for dim in range(len(raw_input_shape)):
                stride.append(1)
                if dim not in shard_dims:
                    start.append(0)
                    output_dims.append(
                        self.get_item(network, input_shape_tensor, dim,
                                      layer_info))
                else:
                    start.append(None)
                    output_dims.append(None)

            for dim, device_dim in zip(shard_dims, device_dims):
                partition = get_partition(device_dim, device_ids)
                index = get_index(device_dim, it)
                input_dim = raw_input_shape[dim]
                assert input_dim != -1
                assert input_dim % partition == 0
                quotient = input_dim // partition
                start[dim] = index * quotient
                output_dims[dim] = self.const_int(network, f"output_dim{dim}",
                                                  quotient, layer_info)
                context.set_split_info(input_name, device_id, dim,
                                       SplitInfo(input_dim, partition))
            if shard_dims != [0]:
                output_dims = [
                    output_dims[permutation[i]] for i in range(len(output_dims))
                ]
            output_dims_tensor = self.concat(network, output_dims, layer_info)
            output_tensor = self.reshape(
                network,
                reduce_scatter_tensor,
                output_dims_tensor,
                layer_info,
            )
            if shard_dims != [0]:
                transpose_layer = network.add_shuffle(output_tensor)
                transpose_layer.second_transpose = permutation
                self.register_layer(transpose_layer, "output_transpose",
                                    *layer_info)
                output_tensor = transpose_layer.get_output(0)
            context.update_name_mapping(input_name, device_id,
                                        output_tensor.name)

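    # Lower an all-reduce CommSpec for real multi-GPU execution: each device's
    # network gets an AllReduce plugin layer over the device group, reusing the
    # shared "all_reduce_workspace" input registered in parallelize().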
    def add_all_reduce_layer(self, context: GraphContext, input_name,
                             output_name, device_ids, to_reduce_tensors):
        for device_id, to_reduce_tensor in zip(np.nditer(device_ids),
                                               to_reduce_tensors):
            device_id = device_id.item()
            layer_info = (input_name, output_name, device_id)
            network = self.get_network(device_id)
            graph = self.get_graph(device_id)
            workspace = graph.get_input("all_reduce_workspace").as_trt()

            all_reduce_layer, allreduce_plg_creator, pfc = create_allreduce_plugin(
                network=network,
                tensor=to_reduce_tensor,
                workspace=workspace,
                group=np.ascontiguousarray(
                    device_ids.reshape(-1).astype(np.int32)),
                dtype=to_reduce_tensor.dtype,
                all_reduce_params=AllReduceParams(),
            )
            plugin_info = PluginInfo(allreduce_plg_creator, "allreduce", pfc)
            set_plugin_info(network, all_reduce_layer.name, plugin_info)
            with self.disable_infer_shape():
                wrapped_tensor = self.register_layer(
                    all_reduce_layer,
                    "all_reduce",
                    *layer_info,
                ).get_output(0)
            output_tensor = all_reduce_layer.get_output(0)
            if self.infer_shape:
                shape = self.shapes_by_device[device_id][to_reduce_tensor.name]
                self.shapes_by_device[device_id][output_tensor.name] = shape
                wrapped_tensor.shape = shape
            context.update_name_mapping(input_name, device_id,
                                        output_tensor.name)

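    # Lower an all-gather CommSpec: each device's network gets an AllGather
    # plugin layer collecting the 1-D shards from every device in the group.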
    def add_all_gather_layer(self, context: GraphContext, input_name,
                             output_name, device_ids, to_gather_tensors):
        all_gather_layers = []
        for device_id, to_gather_tensor in zip(np.nditer(device_ids),
                                               to_gather_tensors):
            device_id = device_id.item()
            layer_info = (input_name, output_name, device_id)
            network = self.get_network(device_id)

            allgather_plg_creator = trt.get_plugin_registry(
            ).get_plugin_creator('AllGather', '1', TRT_LLM_PLUGIN_NAMESPACE)
            assert allgather_plg_creator is not None

            group = trt.PluginField(
                "group",
                np.ascontiguousarray(device_ids.reshape(-1).astype(np.int32)),
                trt.PluginFieldType.INT32)
            pf_type = trt.PluginField(
                "type_id", np.array([int(to_gather_tensor.dtype)], np.int32),
                trt.PluginFieldType.INT32)
            pfc = trt.PluginFieldCollection([group, pf_type])
            allgather = allgather_plg_creator.create_plugin("allgather", pfc)

            all_gather_layer = network.add_plugin_v2([to_gather_tensor],
                                                     allgather)
            plugin_info = PluginInfo(allgather_plg_creator, "allgather", pfc)
            set_plugin_info(network, all_gather_layer.name, plugin_info)
            with self.disable_infer_shape():
                wrapped_tensor = self.register_layer(
                    all_gather_layer,
                    "all_gather",
                    *layer_info,
                ).get_output(0)
            if self.infer_shape:
                output_tensor = all_gather_layer.get_output(0)
                shape = self.shapes_by_device[device_id][to_gather_tensor.name]
                assert len(shape) == 1
                output_shape = (shape[0] * device_ids.size, )
                self.shapes_by_device[device_id][
                    output_tensor.name] = output_shape
                wrapped_tensor.shape = output_shape
            all_gather_layers.append(all_gather_layer)
        return all_gather_layers

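    # Record how many shards a graph input/output is split into along a given
    # dimension; stored in each graph's "io_shards" auto-parallel config entry.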
    def set_shard_num(self, tensor_name, dim, shard_num):
        for graph in self.graphs.values():
            io_shards = graph._auto_parallel_config["io_shards"]
            if tensor_name not in io_shards:
                io_shards[tensor_name] = {}
            io_shards[tensor_name][dim] = shard_num

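    # Register a (possibly sharded) graph input on every device. For example,
    # a dim_partition_dict of {0: [1]} on a mesh axis of size 2 halves
    # dimension 0 of the per-device input shape.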
    def add_input(self, tensor: Tensor, device_ids, strategy: ShardingStrategy):
        context = self.global_context
        sharding_spec = strategy.sharding_specs["output0"]
        shard_dims = sharding_spec.dim_partition_dict
        for dim, device_dim in shard_dims.items():
            partition = get_partition(device_dim, device_ids)
            self.set_shard_num(tensor.name, dim, partition)
        for device_id in np.nditer(device_ids):
            device_id = device_id.item()
            graph = self.get_graph(device_id)
            new_input = graph.add_input(tensor.as_trt())
            shape = [*tensor.shape]
            if len(shard_dims) != 0:
                output_shape = [*tensor.raw_shape]
                for dim, device_dim in shard_dims.items():
                    partition = get_partition(device_dim, device_ids)
                    output_dim = output_shape[dim]
                    assert output_dim != -1
                    assert output_dim % partition == 0
                    quotient = output_dim // partition
                    output_shape[dim] = quotient
                    shape[dim] = quotient
                    assert tensor.value is None
                    context.set_split_info(tensor.name, device_id, dim,
                                           SplitInfo(output_dim, partition))
                new_input.raw_shape = output_shape
            context.update_name_mapping(tensor.name, device_id, tensor.name)
            if self.infer_shape:
                self.shapes_by_device[device_id][tensor.name] = tuple(shape)
                new_input.shape = tuple(shape)
                if tensor.value is not None:
                    self.values_by_device[device_id][tensor.name] = tensor.value
                    new_input.value = tensor.value

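    # Reshard a tensor back to the layout expected at the graph boundary
    # (best_resharding_cost holds the chosen CommSpec sequence), then register
    # it as a graph output on every device under its original name.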
    def add_output(self, tensor: Tensor, device_ids,
                   strategy: ShardingStrategy):
        comm_action_sequence = strategy.best_resharding_cost[0][0][1]
        for commspec in comm_action_sequence:
            self.add_comm(self.global_context, tensor.name, device_ids,
                          commspec)
        for device_id in np.nditer(device_ids):
            device_id = device_id.item()
            graph = self.get_graph(device_id)
            output_name = tensor.name
            new_output_name = self.global_context.get_name(
                output_name, device_id)
            if new_output_name != output_name:
                suffix = self.suffix
                self.suffix += 1
                original_name = f"original_{output_name}_{suffix}"
                original_tensor = graph.get_tensor(output_name)
                original_tensor.as_trt().name = original_name
                output_tensor = graph.get_tensor(new_output_name)
                output_tensor.as_trt().name = output_name
                graph._tensors[original_name] = original_tensor
                graph._tensors[output_name] = output_tensor
                del graph._tensors[new_output_name]
            else:
                output_tensor = graph.get_tensor(output_name)
            trt_output = output_tensor.as_trt()
            if trt_output.is_shape_tensor:
                graph.add_output_shape(trt_output)
            else:
                graph.add_output(trt_output)
            trt_output.dtype = tensor.dtype
            if tensor.dtype != output_tensor.dtype:
                raise ValueError(
                    f"Output dtype mismatch, "
                    f"expect {tensor.dtype} for {tensor.name}, "
                    f"get {output_tensor.dtype} for {output_tensor.name}")

        shard_dims = strategy.sharding_specs["input0"].dim_partition_dict
        for dim, device_dim in shard_dims.items():
            partition = get_partition(device_dim, device_ids)
            self.set_shard_num(tensor.name, dim, partition)


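# Single-network variant of the graph group: instead of one TRT network per
# device, all devices' layers are materialized in one "prefixed" graph whose
# tensor and layer names carry a "device{id}_" prefix, with communication ops
# emulated by ordinary concat/reduce layers. This makes the sharded graph
# runnable on a single GPU for debugging and per-layer numerical validation.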
class PrefixedGraphGroup(GraphGroupBase):

    def __init__(
        self,
        full_graph: PipelineGraph = None,
        config: ParallelConfig = None,
        auto_parallel_config: AutoParallelConfig = None,
    ) -> None:
        # Default to a config object rather than a plain dict, since
        # attributes like `validation_mode` are read below.
        auto_parallel_config = auto_parallel_config or AutoParallelConfig(
            infer_shape=False,
            validation_mode=False,
        )
        super().__init__(full_graph, config, auto_parallel_config)
        self.validation_mode = auto_parallel_config.validation_mode
        if not self.infer_shape:
            self.validation_mode = False
        self.prefixed_graph = PipelineGraph.create_graph()
        if self.validation_mode:
            self.layer_mapping = config.graph_config.graph_mapping.layer_mapping
            self.graph_strategy = config.graph_strategy
        self.shapes = {}
        self.values = {}
        self.timing_cache = None

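    # Every device maps to the same underlying network and graph; per-device
    # separation comes only from the "device{id}_" name prefix.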
    def get_network(self, device_id) -> trt.INetworkDefinition:
        return self.prefixed_graph.as_trt()

    def get_graph(self, device_id) -> PipelineGraph:
        return self.prefixed_graph

    def get_prefix(self, device_id) -> str:
        return f"device{device_id}_"

    def get_shapes(self, device_id) -> Dict[str, Tuple[int, ...]]:
        return self.shapes

    def get_values(self, device_id) -> Dict[str, List[int]]:
        return self.values

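    # Emulate all-reduce inside the single prefixed network: expand each
    # per-device tensor with a trailing unit axis, concatenate along it, and
    # sum over that axis, so no inter-GPU communication is required.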
    def add_all_reduce_layer(self, context: GraphContext, input_name,
                             output_name, device_ids, to_reduce_tensors):
        reshaped_tensors = []
        for device_id, to_reduce_tensor in zip(np.nditer(device_ids),
                                               to_reduce_tensors):
            device_id = device_id.item()
            layer_info = (input_name, output_name, device_id)
            network = self.get_network(device_id)
            reshape_dims_tensor = self.concat(
                network,
                [
                    self.get_shape(network, to_reduce_tensor, layer_info),
                    self.const_int(network, "expanded_dim", 1, layer_info)
                ],
                layer_info,
            )
            reshaped_tensor = self.reshape(
                network,
                to_reduce_tensor,
                reshape_dims_tensor,
                layer_info,
            )
            reshaped_tensors.append(reshaped_tensor)

        for device_id in np.nditer(device_ids):
            device_id = device_id.item()
            layer_info = (input_name, output_name, device_id)
            input_tensor = self.get_tensor(context, input_name, 0).as_trt()
            num_dims = len(input_tensor.shape)
            network = self.get_network(device_id)
            concat_layer = network.add_concatenation(reshaped_tensors)
            concat_layer.axis = num_dims
            self.register_layer(concat_layer, "concat", *layer_info)
            reduce_layer = network.add_reduce(concat_layer.get_output(0),
                                              trt.ReduceOperation.SUM,
                                              axes=1 << num_dims,
                                              keep_dims=False)
            dtype = to_reduce_tensors[0].dtype
            reduce_layer.precision = dtype
            reduce_layer.set_output_type(0, dtype)
            self.register_layer(reduce_layer, "reduce", *layer_info)
            output_tensor = reduce_layer.get_output(0)

            context.update_name_mapping(input_name, device_id,
                                        output_tensor.name)

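    # Emulate all-gather with a plain concatenation of the per-device shards
    # along axis 0.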
    def add_all_gather_layer(self, context: GraphContext, input_name,
                             output_name, device_ids, to_gather_tensors):
        all_gather_layers = []
        for device_id in np.nditer(device_ids):
            device_id = device_id.item()
            layer_info = (input_name, output_name, device_id)
            network = self.get_network(device_id)
            all_gather_layer = network.add_concatenation(to_gather_tensors)
            all_gather_layer.axis = 0
            self.register_layer(all_gather_layer, "all_gather", *layer_info)
            all_gather_layers.append(all_gather_layer)
        return all_gather_layers

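    # In validation mode the input also gets an identity alias named
    # "ref_{name}" so the unsharded reference computation can read the same
    # data as the sharded computation.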
    def add_input(self, tensor: Tensor, device_ids, strategy: ShardingStrategy):

        def add_identity():
            identity_layer = network.add_identity(input.as_trt())
            return identity_layer

        input = self.prefixed_graph.add_input(tensor.as_trt())
        if self.infer_shape:
            self.shapes[tensor.name] = tensor.shape
            input.shape = tensor.shape
            if tensor.value is not None:
                self.values[tensor.name] = tensor.value
                input.value = tensor.value
        network = self.get_network(None)
        if self.validation_mode:
            identity_layer = add_identity()
            identity_layer.get_output(0).name = f"ref_{tensor.name}"
            layer_info = (tensor.name, None, None)
            self.register_layer(identity_layer,
                                "identity",
                                *layer_info,
                                keep_tensor_name=True)
        input.attrs["strategy"] = strategy.name
        sharding_spec = strategy.sharding_specs["output0"]
        pre_sharding_spec = get_full_sharding_spec(sharding_spec)
        comm_action_sequence = get_comm_action_sequence(pre_sharding_spec,
                                                        sharding_spec)
        context = self.global_context
        for device_id in np.nditer(device_ids):
            device_id = device_id.item()
            layer_info = (tensor.name, None, device_id)
            context.update_name_mapping(tensor.name, device_id, tensor.name)
            if len(comm_action_sequence
                   ) == 0 and not tensor.as_trt().is_shape_tensor:
                identity_layer = add_identity()
                self.register_layer(identity_layer, "identity", *layer_info)
                context.update_name_mapping(
                    tensor.name,
                    device_id,
                    identity_layer.get_output(0).name,
                )
        for commspec in comm_action_sequence:
            self.add_comm(context, tensor.name, device_ids, commspec)

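    # Copy the layers just emitted for a source layer (given as a layer-id
    # range) into another graph group, materializing any missing inputs as
    # constant layers built from previously recorded values.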
    def get_graph_in_range(self, graph_group, src_layer, layer_range,
                           device_ids, shapes, values):
        src_network = self.prefixed_graph.as_trt()
        graph = graph_group.prefixed_graph
        network = graph.as_trt()
        input_mapping = {}
        for device_id in np.nditer(device_ids):
            device_id = device_id.item()
            for i in range(src_layer.num_inputs):
                src_input = src_layer.get_input(i)
                if src_input is not None:
                    input = self.get_tensor(
                        self.global_context,
                        src_input.name,
                        device_id,
                    ).as_trt()
                    if graph.get_input(src_input.name) is not None:
                        new_input = graph_group.get_tensor(
                            graph_group.global_context,
                            src_input.name,
                            device_id,
                        ).as_trt()
                        input_mapping[input.name] = new_input.name
                        continue
                    if graph.get_tensor(input.name) is not None:
                        continue
                    shape = shapes[input.name]
                    assert input.name in values
                    value = values[input.name]
                    weights = np.asarray(value,
                                         dtype=trt_dtype_to_np(input.dtype))
                    weights = to_trt_weights(weights)
                    input_layer = network.add_constant(shape, weights)
                    new_input = input_layer.get_output(0)
                    new_input.name = input.name
                    graph.register_layer(input_layer)
        for i in layer_range:
            layer = src_network.get_layer(i)
            graph.add_layer(layer, input_mapping=input_mapping)

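    # Reassemble the full (unsharded) tensor for `output`: emit the CommSpec
    # sequence that undoes the sharding, then expose one device's result under
    # the tensor's original name via an identity layer.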
    def add_layer_singleton(self, output, device_ids, sharding_spec):
        assert self.prefixed_graph.get_tensor(output.name) is None
        network = self.prefixed_graph.as_trt()
        full_sharding_spec = get_full_sharding_spec(sharding_spec)
        comm_action_sequence = get_comm_action_sequence(sharding_spec,
                                                        full_sharding_spec)
        output_context = self.global_context.get_local_context_for_output(
            output)
        if len(comm_action_sequence) != 0:
            for commspec in comm_action_sequence[:-1]:
                self.add_comm(output_context, output.name, device_ids, commspec)
            self.add_comm(
                output_context,
                output.name,
                device_ids,
                comm_action_sequence[-1],
                is_singleton=True,
            )
        device_id = next(np.nditer(device_ids)).item()
        layer_info = (output.name, None, device_id)
        output_tensor = self.get_tensor(output_context, output.name,
                                        device_id).as_trt()
        singleton_layer = network.add_identity(output_tensor)
        singleton_layer.get_output(0).name = output.name
        self.register_layer(singleton_layer,
                            "singleton",
                            *layer_info,
                            keep_tensor_name=True)

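    # Add the layer as usual, then (in validation mode) build a standalone
    # graph holding this layer's sharded form plus an unsharded "ref_" copy,
    # run both with generated inputs, and compare the outputs numerically.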
    def add_layer(self, wrapped_layer: Layer, device_ids,
                  strategy: ShardingStrategy):
        graph = self.prefixed_graph
        network = graph.as_trt()
        start_layer_id = network.num_layers

        super().add_layer(wrapped_layer, device_ids, strategy)

        layer = wrapped_layer.as_trt()

        if self.validation_mode:
            is_shape = (wrapped_layer.is_shape_io
                        or layer.type == trt.LayerType.SHAPE)

            if not is_shape:
                self.current_block_id = wrapped_layer.attrs["block_id"]
                for i, wrapped_output in enumerate(wrapped_layer.outputs):
                    if wrapped_output.is_graph_output:
                        continue
                    output = wrapped_output.as_trt()
                    output_name = f"output{i}"
                    if strategy.communication_actions.get(
                            output_name) is not None:
                        output_name += "_after_comm"
                    sharding_spec = strategy.sharding_specs[output_name]
                    self.add_layer_singleton(output, device_ids, sharding_spec)
                self.current_block_id = -1
            end_layer_id = network.num_layers

            is_skip = (is_shape or layer.type == trt.LayerType.CONSTANT
                       or layer.name in self.layer_mapping)
            sharded = False
            for sharding_spec in strategy.sharding_specs.values():
                if len(sharding_spec.dim_partition_dict) > 0:
                    sharded = True
                    break
            if not sharded:
                is_skip = True

            ref_layer = graph.add_layer(layer, prefix="ref_")
            ref_layer.attrs["strategy"] = strategy.name
            ref_layer.attrs["block_id"] = wrapped_layer.attrs["block_id"]
            if layer.type == trt.LayerType.CONSTANT:
                self.register_unfilled_weights(graph, layer)

            if is_skip:
                return

            logger.debug(f"validating layer {layer.name}")

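            # GPTAttention takes many scalar/metadata inputs that cannot be
            # filled with random data; its parameter_generator derives
            # consistent values for them from the full sharding specs.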
            layer_type = layer.type
            generated_input_values = {}
            to_subclass_layer(layer)
            if layer_type == trt.LayerType.PLUGIN_V2:
                if layer.plugin.plugin_type == "GPTAttention":
                    sharding_specs = {}
                    for name, sharding_spec in strategy.sharding_specs.items():
                        sharding_specs[name] = get_full_sharding_spec(
                            sharding_spec)
                    plugin_info = get_plugin_info(
                        self.full_graph.as_trt(),
                        layer.name,
                    )
                    generated_input_values = GPTAttentionPlugin.parameter_generator(
                        sharding_specs, plugin_info)
            to_base_class_layer(layer)

            validation_graph_group = PrefixedGraphGroup()
            validation_graph = validation_graph_group.prefixed_graph
            validation_graph._io_buffer_mapping = self.full_graph._io_buffer_mapping
            extra_input_values = {}
            validation_shapes = {}
            for i, wrapped_input in enumerate(wrapped_layer.inputs):
                if wrapped_input is None:
                    continue
                input = wrapped_input.as_trt()
                validation_shapes[input.name] = wrapped_input.shape
                if wrapped_input.value is None:
                    if i in generated_input_values:
                        extra_input_value = generated_input_values[i]
                    else:
                        extra_input_value = torch.empty(
                            tuple(wrapped_input.shape),
                            dtype=trt_dtype_to_torch(input.dtype),
                            device=torch.cuda.current_device(),
                        )
                        if torch.is_floating_point(extra_input_value):
                            extra_input_value.normal_()
                        # extra_input_value[:] = random.choice([2, 3, 5, 7])
                    extra_input_values[input.name] = extra_input_value
                    self.values[input.name] = extra_input_value
                if wrapped_input.producer is not None:
                    node_name = wrapped_input.producer.name
                    output_index = wrapped_input.output_index
                else:
                    node_name = wrapped_input.name
                    output_index = 0
                sharding_spec = self.graph_strategy[
                    node_name].sharding_specs[f"output{output_index}"]
                validation_graph_group.add_input(
                    wrapped_input,
                    device_ids,
                    ShardingStrategy(
                        sharding_specs={"output0": sharding_spec}),
                )
                validation_graph.get_input(
                    input.name).raw_shape = wrapped_input.shape

            self.get_graph_in_range(
                validation_graph_group,
                layer,
                range(start_layer_id, end_layer_id),
                device_ids,
                self.shapes,
                self.values,
            )

            for i, wrapped_output in enumerate(wrapped_layer.outputs):
                output = wrapped_output.as_trt()
                if wrapped_output.is_graph_output:
                    output_name = f"output{i}"
                    if strategy.communication_actions.get(
                            output_name) is not None:
                        output_name += "_after_comm"
                    sharding_spec = strategy.sharding_specs[output_name]
                    validation_graph_group.global_context.merge_context(
                        self.global_context.get_local_context_for_output(
                            output))
                    validation_graph_group.add_layer_singleton(
                        output, device_ids, sharding_spec)
                validation_graph.add_output(output)
                validation_shapes[output.name] = wrapped_output.shape
            if not self.timing_cache:
                self.timing_cache = network.builder.create_builder_config(
                ).create_timing_cache(b"")
            logger.debug(f"run validation graph for layer {layer.name}")
            validation_runner = validation_graph.get_runner(
                validation_shapes,
                self.values,
                timing_cache=self.timing_cache,
                opt_level=0,
            )
            values = validation_runner.run()
            refer_input_values = {}
            for wrapped_input in wrapped_layer.inputs:
                if wrapped_input is None:
                    continue
                if wrapped_input.value is not None:
                    refer_input_values[wrapped_input.name] = wrapped_input.value
            refer_graph, output_mapping = get_per_layer_graph(
                layer,
                validation_shapes,
                refer_input_values,
                is_shape_io=False,
            )
            refer_graph._io_buffer_mapping = self.full_graph._io_buffer_mapping
            for proxy_output, output in output_mapping.items():
                validation_shapes[proxy_output] = validation_shapes[output]
            logger.debug(f"run refer graph for layer {layer.name}")
            refer_runner = refer_graph.get_runner(
                validation_shapes,
                self.values,
                timing_cache=self.timing_cache,
                opt_level=0,
            )
            refer_outputs = refer_runner.run()
            for name, refer_output in refer_outputs.items():
                if name in output_mapping:
                    refer_output = refer_output.bool()
                output = values[name]
                # |output - refer_output| <= atol + rtol * |refer_output|
                atol = 1e-02
                rtol = 1e-02
                if not torch.allclose(
                        output,
                        refer_output,
                        rtol=rtol,
                        atol=atol,
                        equal_nan=True,
                ):
                    size = output.nelement()
                    diff = (output - refer_output).abs()
                    diff_index = (~torch.isnan(diff)) & (diff > (
                        atol + rtol * refer_output.abs()))
                    diff_output = diff[diff_index]
                    diff_size = diff_output.nelement()
                    logger.warning(
                        f"output {name} of {layer.name} is not accurate after parallelization. "
                        f"{diff_size} out of {size} elements ({diff_size / size * 100:.2f}%) are not close. "
                        f"max: {diff_output.max():.5f}, mean: {diff_output.float().mean():.5f}, std: {diff_output.float().std():.5f}. "
                        f"mean of reference: {refer_output.float().mean():.5f}, mean of output: {output.float().mean():.5f}."
                    )
            for name in extra_input_values.keys():
                del self.values[name]

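    # Graph outputs are always materialized unsharded: reshard according to
    # the chosen CommSpec sequence, then reassemble the full tensor under its
    # original name before registering it as an output of the prefixed graph.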
    def add_output(self, tensor: Tensor, device_ids,
                   strategy: ShardingStrategy):
        trt_output = tensor.as_trt()
        comm_action_sequence = strategy.best_resharding_cost[0][0][1]
        for commspec in comm_action_sequence:
            self.add_comm(self.global_context, tensor.name, device_ids,
                          commspec)
        self.add_layer_singleton(trt_output, device_ids,
                                 strategy.sharding_specs["input0"])
        if trt_output.is_shape_tensor:
            output = self.prefixed_graph.add_output_shape(trt_output)
        else:
            output = self.prefixed_graph.add_output(trt_output)
        trt_output.dtype = tensor.dtype
        output.attrs["strategy"] = strategy.name

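    # Propagate inferred shapes and values onto the prefixed graph; in
    # validation mode the "ref_"-prefixed copies reuse the original shape info.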
    def assign_shapes(self, shape_info: ShapeInfo):
        if self.validation_mode:
            shapes = {
                f"ref_{name}": shape
                for name, shape in shape_info.shapes.items()
            }
            values = {
                f"ref_{name}": value
                for name, value in shape_info.values.items()
            }
            self.shapes.update(shapes)
            self.values.update(values)
        shape_layers = get_shape_layers(self.prefixed_graph.as_trt())
        shape_info = ShapeInfo(self.shapes, self.values, shape_layers)
        self.prefixed_graph.assign_shapes(shape_info)


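# Driver that lowers a solved ParallelConfig onto concrete graphs. A typical
# (illustrative; the file name is hypothetical) call sequence is:
#
#     config = ParallelConfig.from_file("auto_parallel_config.pkl")
#     graphs = parallelize(simplifier, config)
#
# Debug mode returns a single combined (prefixed) graph for inspection;
# otherwise one sharded graph is returned per physical device.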
def parallelize(
    simplifier: Simplifier,
    config: ParallelConfig,
):
    auto_parallel_config = simplifier.config
    debug_mode = auto_parallel_config.debug_mode
    dump_path = auto_parallel_config.dump_path
    debug_outputs = auto_parallel_config.debug_outputs

    simplifier.infer_shapes(config.graph_config.num_micro_batches)
    network = simplifier.network
    graph = simplifier.graph
    phy_mesh = config.graph_config.phy_mesh
    # TODO: test device_ids = [[0]]
    device_ids = phy_mesh.phy_devices_id
    stage_phy_meshes = config.graph_config.stage_phy_meshes
    block_to_stage = config.graph_config.graph_mapping.block_to_stage
    graph_strategy = config.graph_strategy
    desimplify_strategy(
        graph,
        graph_strategy,
        config.graph_config.graph_mapping,
    )
    graph._plugin_config = simplifier.llm_network.plugin_config
    graph_group = GraphGroup.from_graph(graph, config, auto_parallel_config)

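    # The custom all-reduce plugin needs a workspace buffer shared by all TP
    # ranks; register it as an extra network input with a replicated
    # (unsharded) strategy so every device keeps the whole buffer.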
    if not debug_mode:
        init_all_reduce_helper()
        tp_size = phy_mesh.size // config.graph_config.num_stages
        shape = (CustomAllReduceHelper.POINTERS_PER_RANK * tp_size +
                 CustomAllReduceHelper.POINTERS_OF_COUNTER, )
        workspace = graph.as_trt().add_input(
            name="all_reduce_workspace",
            dtype=trt.int64,
            shape=shape,
        )
        tensor = graph.register_input(workspace)
        tensor.shape = shape
        graph_strategy["all_reduce_workspace"] = ShardingStrategy(
            sharding_specs={
                "output0":
                ShardingSpec(
                    device_mesh=phy_mesh.as_logical_mesh(),
                    data_type_size=tensor.dtype_str_size,
                    data_shape=shape,
                    max_data_shape=shape,
                    raw_data_shape=shape,
                    dim_partition_dict={},
                )
            })

    if dump_path is not None:
        lock = FileLock(f"{dump_path}/path.lock", thread_local=False)
        with lock:
            with open(f'{dump_path}/sharded_graph.log', 'w+') as file:
                config.print_graph_strategy(file)

    for input in graph.inputs:
        graph_group.add_input(input, device_ids, graph_strategy[input.name])
    for block in simplifier.blocks:
        stage_id = block_to_stage[block.block_id]
        stage_phy_mesh = stage_phy_meshes[stage_id]
        stage_device_ids = stage_phy_mesh.phy_devices_id.reshape(
            config.lmesh.mesh_shape)
        for i in block.sorted_layer_ids:
            layer = graph.get_layer(network.get_layer(i).name)
            layer.attrs["block_id"] = block.block_id
            graph_group.add_layer(
                layer,
                stage_device_ids,
                graph_strategy[layer.name],
            )
    for output in graph.outputs:
        graph_group.add_output(output, device_ids, graph_strategy[output.name])

    if debug_mode:
        new_graph = graph_group.prefixed_graph
        debug_outputs = debug_outputs or []
        if isinstance(debug_outputs, str):
            if debug_outputs == 'validation':
                debug_outputs = []
                for tensor in new_graph.tensors:
                    if tensor.name.startswith('ref_'):
                        original_name = tensor.name[4:]
                        original_tensor = new_graph.get_tensor(original_name)
                        if original_tensor is not None:
                            if not original_tensor.is_graph_io:
                                debug_outputs.append(tensor.name)
                                debug_outputs.append(original_name)
                            if original_tensor.is_graph_output:
                                debug_outputs.append(tensor.name)
            else:
                pattern = debug_outputs
                debug_outputs = []
                for tensor in new_graph.tensors:
                    if tensor.as_trt().is_shape_tensor:
                        continue
                    if tensor.producer is not None:
                        layer = tensor.producer
                        if layer.type == trt.LayerType.SHAPE:
                            continue
                    if re.match(pattern, tensor.name):
                        debug_outputs.append(tensor.name)
        for output_name in debug_outputs:
            trt_output = new_graph.get_tensor(output_name).as_trt()
            if trt_output.is_shape_tensor:
                output = new_graph.add_output_shape(trt_output)
            else:
                output = new_graph.add_output(trt_output)
        graph_group.assign_shapes(simplifier.shape_info)
        if dump_path is not None:
            with lock:
                new_graph.to_dot(
                    f'{dump_path}/sharded_graph.dot',
                    per_device=True,
                    per_block=True,
                    # ignore_shape_io=True,
                    extra_attrs=['strategy'],
                )
        return [new_graph]
    else:
        graphs = []
        for device_id in np.nditer(device_ids):
            device_id = device_id.item()
            graph = graph_group.graphs[device_id]
            graphs.append(graph)
        return graphs