# TensorRT-LLM: tensorrt_llm/auto_parallel/parallelization.py
import contextlib
import copy
import itertools
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, ClassVar, Dict, List, Sequence, Set, Tuple, Union
import numpy as np
import tensorrt as trt
import torch
from filelock import FileLock
from tensorrt_llm import serialization
from tensorrt_llm._utils import (str_dtype_to_trt, trt_dtype_to_np,
trt_dtype_to_torch)
from tensorrt_llm.functional import AllReduceParams, create_allreduce_plugin
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.network import (PluginInfo, delete_plugin_info, get_np_weight,
get_plugin_info, set_plugin_info)
from tensorrt_llm.plugin import TRT_LLM_PLUGIN_NAMESPACE, init_all_reduce_helper
from tensorrt_llm.plugin.plugin import CustomAllReduceHelper
from tensorrt_llm.version import __version__
from .config import AutoParallelConfig
from .device_mesh import LogicalDeviceMesh
from .pipeline_graph import Layer, PipelineGraph, Tensor
from .shape_info import (ShapeInfo, get_per_layer_graph, get_shape_layers,
infer_per_layer_shapes)
from .simplifier import GraphConfig, GraphMapping, Simplifier, StageType
from .tensor_parallel.comm_spec import CommSpec
from .tensor_parallel.plugin_nodes.gpt_attention_node import (
GPTAttentionPlugin, IdxEntry, IdxEntryParser)
from .tensor_parallel.sharding_spec import ShardingSpec, get_sharding_sequence
from .tensor_parallel.sharding_strategy import ShardingStrategy
from .utils import (get_updated_plugin, to_base_class_layer, to_subclass_layer,
to_trt_weights)
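# This module materializes an auto-parallel solution: given a solved
# ParallelConfig (device mesh, graph config, per-layer sharding strategies), it
# rewrites the original TensorRT network into per-device sharded networks,
# inserting the split/gather/reduce communication layers that the chosen
# strategies require.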
default_int_dtype = trt.int64
# These dataclasses are used in ParallelConfig serialization. If other classes need to be serialized, add them to this list.
BASE_AUTOPP_CLASSES = {
"tensorrt_llm.auto_parallel.parallelization": ["ParallelConfig"],
"tensorrt_llm.auto_parallel.config": ["AutoParallelConfig", "CostModel"],
"tensorrt_llm.auto_parallel.simplifier": ["GraphConfig", "StageType"]
}
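# ParallelConfig bundles the auto-parallel solver's results: the device mesh,
# graph configuration, per-layer sharding strategies, estimated cost, plus the
# library version and network hash used to validate cached solutions. It is
# written with serialization.dump and read back with serialization.load
# restricted to BASE_AUTOPP_CLASSES, so deserialization of unexpected classes
# is rejected. A minimal usage sketch (the file name is illustrative only):
#
#   config.save("parallel_config.pkl")
#   restored = ParallelConfig.from_file("parallel_config.pkl")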
@dataclass
class ParallelConfig:
VERSION: ClassVar[str] = __version__
version: str = VERSION
network_hash: str = None
auto_parallel_config: AutoParallelConfig = None
graph_config: GraphConfig = None
lmesh: LogicalDeviceMesh = None
cost: float = None
graph_strategy: Dict[str, ShardingStrategy] = None
stage_type: StageType = None
def save(self, filename):
with open(filename, 'wb') as file:
serialization.dump(self, file)
@staticmethod
def from_file(filename) -> "ParallelConfig":
with open(filename, "rb") as file:
return serialization.load(file,
approved_imports=BASE_AUTOPP_CLASSES)
def print_graph_strategy(self, file=None):
for index, (node_name,
strategy) in enumerate(self.graph_strategy.items()):
print(f'\n[{index}]: node_name = {node_name}', file=file)
strategy.print_strategy(best_resharding_cost_only=True, file=file)
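# desimplify_strategy expands strategies computed on the simplified graph back
# onto the full graph: communication specs are wrapped into per-output lists,
# the communication actions of "*_same_spec" helper layers are appended to
# their producer layer's outputs, and strategies for layers removed by the
# simplifier are recreated from their mapped counterparts.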
def desimplify_strategy(
graph: PipelineGraph,
graph_strategy: Dict[str, ShardingStrategy],
graph_mapping: GraphMapping,
):
for strategy in graph_strategy.values():
for name, commspec in list(strategy.communication_actions.items()):
strategy.communication_actions[name] = [commspec]
strategy.sharding_specs[
f"{name}_after_comm"] = strategy.sharding_specs[name]
# insert same-spec layers' communication actions after
# their producers' communication actions
same_spec_layer_mapping = graph_mapping.same_spec_layer_mapping
for same_spec_layer_name in same_spec_layer_mapping.keys():
same_spec_strategy = graph_strategy[same_spec_layer_name]
same_spec_commspecs = same_spec_strategy.best_resharding_cost[0][0][1]
if len(same_spec_commspecs) == 0:
continue
output_name = same_spec_layer_name[:-len("_same_spec")]
output = graph.get_tensor(output_name)
layer_name = output.producer.name
output_index = output.output_index
strategy = graph_strategy[layer_name]
commspecs = strategy.communication_actions.get(f"output{output_index}",
[])
commspecs.extend(same_spec_commspecs)
strategy.communication_actions[f"output{output_index}"] = commspecs
strategy.sharding_specs[
f"output{output_index}_after_comm"] = same_spec_strategy.sharding_specs[
"output0"]
layer_mapping = graph_mapping.layer_mapping
for removed_layer_name, layer_name in layer_mapping.items():
if layer_name in graph_strategy:
strategy = copy.copy(graph_strategy[layer_name])
layer = graph.get_layer(removed_layer_name)
if layer is not None:
strategy.node_names = strategy.node_names.copy()
for index, name in list(strategy.node_names.items()):
input = layer.get_input(index)
node_name = input.name if input.producer is None else input.producer.name
strategy.node_names[index] = node_name
graph_strategy[removed_layer_name] = strategy
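# SplitInfo records how one tensor dimension was sharded: the original size of
# that dimension and the number of partitions it was split into. TensorInfo and
# TensorContext (below) keep this bookkeeping per device, together with the
# device-local tensor name after rewriting.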
@dataclass
class SplitInfo:
input_dim: Union[int, trt.ITensor]
partition: int
def __deepcopy__(self, memo) -> "SplitInfo":
return SplitInfo(self.input_dim, self.partition)
@dataclass
class TensorInfo:
name: str = None
split_infos: Dict[int, SplitInfo] = field(default_factory=dict)
def set_split_info(self, dim, split_info):
self.split_infos[dim] = split_info
def __deepcopy__(self, memo) -> "TensorInfo":
return TensorInfo(self.name, copy.deepcopy(self.split_infos))
@dataclass
class TensorContext:
info_by_device: Dict[int, TensorInfo] = field(default_factory=dict)
device_dims_for_shape: Set[int] = field(default_factory=set)
def update_name_mapping(self, device_id, new_name):
if device_id not in self.info_by_device:
self.info_by_device[device_id] = TensorInfo()
self.info_by_device[device_id].name = new_name
def set_split_info(self, device_id, dim, split_info):
if device_id not in self.info_by_device:
self.info_by_device[device_id] = TensorInfo()
self.info_by_device[device_id].set_split_info(dim, split_info)
def set_split_infos(self, device_id, split_infos: Dict[int, SplitInfo]):
if device_id not in self.info_by_device:
self.info_by_device[device_id] = TensorInfo()
self.info_by_device[device_id].split_infos = split_infos
def __deepcopy__(self, memo) -> "TensorContext":
return TensorContext(copy.deepcopy(self.info_by_device),
set(self.device_dims_for_shape))
@dataclass
class LayerUpdate:
updated_attrs: Dict[str, Any] = field(default_factory=dict)
updated_inputs: Dict[int, trt.ITensor] = field(default_factory=dict)
split_info_updated: bool = False
@staticmethod
def none() -> "LayerUpdate":
return LayerUpdate()
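# GraphContext tracks, for every original tensor name, its per-device
# replacement name and split information while the network is being rewritten.
# Helper layers (slices, casts, gathers) update the name mapping so later
# layers are wired to the device-local tensors; update_layer_context propagates
# split information from a layer's inputs to its outputs using the layer's
# sharding specs.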
@dataclass
class GraphContext:
tensor_contexts: Dict[str, TensorContext] = field(default_factory=dict)
def get_name(self, tensor_name, device_id):
if tensor_name not in self.tensor_contexts:
return None
if device_id not in self.tensor_contexts[tensor_name].info_by_device:
return None
return self.tensor_contexts[tensor_name].info_by_device[device_id].name
def update_name_mapping(self, tensor_name, device_id, new_name):
if tensor_name not in self.tensor_contexts:
self.tensor_contexts[tensor_name] = TensorContext()
self.tensor_contexts[tensor_name].update_name_mapping(
device_id, new_name)
def get_name_mapping(self, device_id, prefix: str) -> Dict[str, str]:
name_mapping = {}
for tensor_name in self.tensor_contexts.keys():
new_name = self.get_name(tensor_name, device_id)
if new_name is not None:
name_mapping[f"{prefix}{tensor_name}"] = new_name
return name_mapping
def add_device_dims_for_shape(self, tensor_name: str,
device_dims: Sequence[int]):
if tensor_name not in self.tensor_contexts:
self.tensor_contexts[tensor_name] = TensorContext()
self.tensor_contexts[tensor_name].device_dims_for_shape.update(
device_dims)
def get_device_dims_for_shape(self, tensor_name: str):
if tensor_name not in self.tensor_contexts:
return set()
return self.tensor_contexts[tensor_name].device_dims_for_shape
def get_split_infos(self, tensor_name, device_id):
if tensor_name not in self.tensor_contexts:
return None
if device_id not in self.tensor_contexts[tensor_name].info_by_device:
return None
return self.tensor_contexts[tensor_name].info_by_device[
device_id].split_infos
def set_split_info(self, tensor_name, device_id, dim, split_info):
if tensor_name not in self.tensor_contexts:
self.tensor_contexts[tensor_name] = TensorContext()
self.tensor_contexts[tensor_name].set_split_info(
device_id, dim, split_info)
def set_split_infos(self, tensor_name, device_id,
split_infos: Dict[int, SplitInfo]):
if tensor_name not in self.tensor_contexts:
self.tensor_contexts[tensor_name] = TensorContext()
self.tensor_contexts[tensor_name].set_split_infos(
device_id, split_infos)
def update_layer_context(self, wrapped_layer: Layer,
layer_update: LayerUpdate,
local_context: "GraphContext", device_id: int,
device_ids: np.ndarray,
sharding_specs: Dict[str, ShardingSpec]):
layer = wrapped_layer.as_trt()
for i in range(layer.num_outputs):
output = layer.get_output(i)
new_name = local_context.get_name(output.name, device_id)
if new_name is not None:
self.update_name_mapping(output.name, device_id, new_name)
if layer_update.split_info_updated:
for i in range(layer.num_outputs):
output = layer.get_output(i)
split_infos = local_context.get_split_infos(
output.name, device_id)
if split_infos is not None:
self.set_split_infos(output.name, device_id, split_infos)
return
split_info_by_device_dim = {}
for i in range(layer.num_inputs):
input = layer.get_input(i)
if input is None:
continue
sharding_spec = sharding_specs[f"input{i}"]
split_infos = local_context.get_split_infos(input.name, device_id)
if split_infos is None:
continue
for dim, split_info in split_infos.items():
device_dim = tuple(sharding_spec.dim_partition_dict[dim])
split_info_by_device_dim[device_dim] = split_info
for i in range(layer.num_outputs):
output = layer.get_output(i)
sharding_spec = sharding_specs[f"output{i}"]
for dim, device_dim in sharding_spec.dim_partition_dict.items():
split_info = split_info_by_device_dim.get(tuple(device_dim))
if split_info is None:
if device_dim == [0, 1] or device_dim == [1, 0]:
if (0, ) in split_info_by_device_dim and (
1, ) in split_info_by_device_dim:
split_info = SplitInfo(
split_info_by_device_dim[(0, )].input_dim *
split_info_by_device_dim[(1, )].input_dim,
split_info_by_device_dim[(0, )].partition *
split_info_by_device_dim[(1, )].partition,
)
assert split_info is not None
partition = get_partition(device_dim, device_ids)
if split_info.input_dim != output.shape[dim]:
assert output.shape[
dim] > 0 and output.shape[dim] % partition == 0
output_split_info = SplitInfo(output.shape[dim], partition)
self.set_split_info(output.name, device_id, dim,
output_split_info)
def get_local_context(self, layer: trt.ILayer) -> "GraphContext":
local_context = GraphContext()
for i in range(layer.num_inputs):
input = layer.get_input(i)
if input is None:
continue
local_context.tensor_contexts[input.name] = copy.deepcopy(
self.tensor_contexts[input.name])
return local_context
def get_local_context_for_output(self,
output: trt.ITensor) -> "GraphContext":
local_context = GraphContext()
local_context.tensor_contexts[output.name] = copy.deepcopy(
self.tensor_contexts[output.name])
return local_context
def merge_context(self, context: "GraphContext"):
self.tensor_contexts.update(context.tensor_contexts)
@dataclass
class ShardContext:
graph_context: GraphContext
layer: Layer
nditer: np.nditer
device_ids: np.ndarray
strategy: ShardingStrategy
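# device_dim is a list of logical mesh axes a dimension is sharded over:
# [0] means the first mesh axis (rows of device_ids), [1] the second axis
# (columns), and [0, 1] / [1, 0] both axes, i.e. all devices. For example, on a
# 2 x 4 mesh, get_partition([1], device_ids) == 4 and
# get_partition([0, 1], device_ids) == 8.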
def get_partition(device_dim, device_ids):
if device_dim == [0]:
partition = device_ids.shape[0]
elif device_dim == [1]:
partition = device_ids.shape[1]
else:
assert device_dim == [0, 1] or device_dim == [1, 0]
partition = device_ids.size
return partition
def get_index(device_dim, iter):
if device_dim == [0]:
index = iter.multi_index[0]
elif device_dim == [1]:
index = iter.multi_index[1]
else:
assert device_dim == [0, 1] or device_dim == [1, 0]
index = iter.iterindex
return index
def get_full_sharding_spec(sharding_spec):
return ShardingSpec(sharding_spec.device_mesh,
sharding_spec.data_type_size,
sharding_spec.entire_shape,
sharding_spec.max_entire_shape,
sharding_spec.raw_shape,
dim_partition_dict={})
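# get_comm_action_sequence asks the device mesh's shape-consistency manager for
# the communication steps that turn one sharding spec into another. Two
# consecutive all_gather steps along the same tensor dimension are merged into a
# single all_gather over both mesh axes; the result is expected to contain at
# most one communication action.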
def get_comm_action_sequence(from_sharding_spec, to_sharding_spec):
comm_action_sequence = from_sharding_spec.device_mesh.shape_consistency_manager.shape_consistency(
from_sharding_spec, to_sharding_spec)[1]
# TODO: should be merged by shape_consistency
if len(comm_action_sequence) == 2:
if comm_action_sequence[0].comm_pattern == comm_action_sequence[
1].comm_pattern == "all_gather":
if comm_action_sequence[0].gather_dim == comm_action_sequence[
1].gather_dim:
comm_action_sequence = [
CommSpec(
comm_action_sequence[0].comm_pattern,
comm_action_sequence[0].sharding_spec,
comm_action_sequence[0].gather_dim,
comm_action_sequence[0].shard_dim, [[
*comm_action_sequence[0].logical_process_axis[0],
*comm_action_sequence[1].logical_process_axis[0]
]], comm_action_sequence[0].mix_gather,
comm_action_sequence[0].forward_only)
]
assert len(comm_action_sequence[0].logical_process_axis[0]) <= 2
assert len(comm_action_sequence) <= 1
return comm_action_sequence
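# GraphGroup is the abstract builder that materializes the sharded network.
# Two implementations exist: DistributedGraphGroup builds one TensorRT network
# per device with real communication plugins, while PrefixedGraphGroup (debug
# mode) places all devices into a single network under "device{N}_" prefixes so
# the parallelization can be inspected without multiple GPUs.
# Typical entry point (argument names follow the signatures in this file):
#   group = GraphGroup.from_graph(graph, config, auto_parallel_config)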
class GraphGroup(ABC):
@staticmethod
def from_graph(
graph: PipelineGraph,
config: ParallelConfig,
auto_parallel_config: AutoParallelConfig,
) -> "GraphGroup":
if auto_parallel_config.debug_mode:
return PrefixedGraphGroup(graph, config, auto_parallel_config)
else:
return DistributedGraphGroup(graph, config, auto_parallel_config)
@property
@abstractmethod
def auto_parallel_config(self) -> AutoParallelConfig:
...
@abstractmethod
def add_input(self, tensor, device_ids, strategy: ShardingStrategy):
...
@abstractmethod
def add_layer(self, layer, device_ids, strategy: ShardingStrategy):
...
@abstractmethod
def add_output(self, tensor, device_ids, sharding_spec: ShardingSpec):
...
@abstractmethod
def get_network(self, device_id) -> trt.INetworkDefinition:
...
@abstractmethod
def get_graph(self, device_id) -> PipelineGraph:
...
@property
@abstractmethod
def full_graph(self) -> PipelineGraph:
...
@abstractmethod
def get_prefix(self, device_id) -> str:
...
@abstractmethod
def get_shapes(self, device_id) -> Dict[str, Tuple[int, ...]]:
...
@abstractmethod
def get_values(self, device_id) -> Dict[str, List[int]]:
...
@abstractmethod
def add_all_reduce_layer(self, context: GraphContext, input_name,
output_name, device_ids, to_reduce_tensors):
...
@abstractmethod
def add_all_gather_layer(self, context: GraphContext, input_name,
output_name, device_ids, to_gather_tensors):
...
@abstractmethod
def register_layer(self,
layer,
base_name,
input_name,
output_name=None,
device_id=None,
keep_tensor_name=False) -> Layer:
...
def get_tensor(self, context: GraphContext, tensor_name: str,
device_id: int) -> Tensor:
name = context.get_name(tensor_name, device_id)
return self.get_graph(device_id).get_tensor(name)
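# add_comm first drops mesh axes whose partition size is 1 (no communication is
# needed along them), then emits the communication either once across the whole
# mesh, or separately for each mesh row/column when only one axis remains.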
def add_comm(self,
context: GraphContext,
input_name,
device_ids,
commspec,
output_name=None,
is_singleton=False):
remove_index = []
for i, device_dim in enumerate(commspec.logical_process_axis):
partition = get_partition(device_dim, device_ids)
if partition == 1:
remove_index.append(i)
if len(remove_index) > 0:
if commspec.comm_pattern in ["all_gather", "all_to_all"]:
commspec.gather_dim = [
dim for i, dim in enumerate(commspec.gather_dim)
if i not in remove_index
]
if commspec.comm_pattern in [
"split", "reduce_scatter", "all_to_all"
]:
commspec.shard_dim = [
dim for i, dim in enumerate(commspec.shard_dim)
if i not in remove_index
]
commspec.logical_process_axis = [
dim for i, dim in enumerate(commspec.logical_process_axis)
if i not in remove_index
]
flatten_device_dim = list(
itertools.chain.from_iterable(commspec.logical_process_axis))
if flatten_device_dim == []:
return
if flatten_device_dim == [0, 1] or flatten_device_dim == [1, 0]:
self._add_comm(context, input_name, device_ids, commspec,
output_name, is_singleton)
elif flatten_device_dim == [0]:
for i in range(device_ids.shape[1]):
self._add_comm(context, input_name, device_ids[:, i:i + 1],
commspec, output_name, is_singleton)
elif flatten_device_dim == [1]:
for i in range(device_ids.shape[0]):
self._add_comm(context, input_name, device_ids[i:i + 1, :],
commspec, output_name, is_singleton)
else:
raise RuntimeError(
f"Invalid flatten device_dim: {flatten_device_dim}")
def _add_comm(self,
context: GraphContext,
input_name,
device_ids,
commspec,
output_name=None,
is_singleton=False):
input_tensors = [
self.get_tensor(context, input_name, device_id.item())
for device_id in np.nditer(device_ids)
]
comm_pattern = commspec.comm_pattern
if comm_pattern == "split":
self.add_split(context, input_name, output_name, device_ids,
commspec.shard_dim, commspec.logical_process_axis)
elif comm_pattern == "all_gather":
self.add_all_gather(context, input_name, output_name, device_ids,
commspec.gather_dim,
commspec.logical_process_axis, is_singleton)
elif comm_pattern == "all_reduce":
self.add_all_reduce(context, input_name, output_name, device_ids)
elif comm_pattern == "reduce_scatter":
self.add_reduce_scatter(context, input_name, output_name,
device_ids, commspec.shard_dim,
commspec.logical_process_axis)
elif comm_pattern == "all_to_all":
self.add_all_to_all(context, input_name, output_name, device_ids,
commspec.gather_dim, commspec.shard_dim,
commspec.logical_process_axis)
else:
raise NotImplementedError
output_tensors = [
self.get_tensor(context, input_name, device_id.item())
for device_id in np.nditer(device_ids)
]
for input_tensor, output_tensor in zip(input_tensors, output_tensors):
if input_tensor.dtype != output_tensor.dtype:
raise ValueError(
f"Input tensor and output tensor should have the same dtype for communication layers, "
f"input dtype is {input_tensor.dtype} for {input_tensor.name}, "
f"but output dtype is {output_tensor.dtype} for {output_tensor.name}"
)
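# add_all_reduce casts the inputs to the network's plugin dtype when necessary,
# emits the all-reduce layer(s), and casts the results back to the original
# dtype when a cast was inserted.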
def add_all_reduce(self, context: GraphContext, input_name, output_name,
device_ids):
dtype = str_dtype_to_trt(self.full_graph._plugin_config.dtype)
to_reduce_tensors = []
for device_id in np.nditer(device_ids):
device_id = device_id.item()
layer_info = (input_name, output_name, device_id)
network = self.get_network(device_id)
input_tensor = self.get_tensor(context, input_name,
device_id).as_trt()
input_dtype = input_tensor.dtype
if input_dtype != dtype:
to_reduce_tensor = self.cast(
network,
input_tensor,
dtype,
layer_info,
)
else:
to_reduce_tensor = input_tensor
to_reduce_tensors.append(to_reduce_tensor)
self.add_all_reduce_layer(context, input_name, output_name, device_ids,
to_reduce_tensors)
if input_dtype != dtype:
for device_id in np.nditer(device_ids):
device_id = device_id.item()
layer_info = (input_name, output_name, device_id)
network = self.get_network(device_id)
input_tensor = self.get_tensor(
context,
input_name,
device_id,
).as_trt()
output_tensor = self.cast(
network,
input_tensor,
input_dtype,
layer_info,
)
context.update_name_mapping(
input_name,
device_id,
output_tensor.name,
)
def add_reduce_scatter(self, context: GraphContext, input_name, output_name,
device_ids, shard_dims, device_dims):
self.add_all_reduce(context, input_name, output_name, device_ids)
self.add_split(context, input_name, output_name, device_ids, shard_dims,
device_dims)
# TODO: use native all_to_all operation
def add_all_to_all(self, context: GraphContext, input_name, output_name,
device_ids, gather_dims, shard_dims, device_dims):
self.add_all_gather(context, input_name, output_name, device_ids,
gather_dims, device_dims)
self.add_split(context, input_name, output_name, device_ids, shard_dims,
device_dims)
def get_item(self, network, tensor, index, layer_info):
get_item_layer = network.add_slice(tensor, [index], [1], [1])
self.register_layer(get_item_layer, f"get_item{index}", *layer_info)
return get_item_layer.get_output(0)
def get_shape(self, network, tensor, layer_info):
shape_layer = network.add_shape(tensor)
self.register_layer(shape_layer, "shape", *layer_info)
return shape_layer.get_output(0)
def concat(self, network, tensors, layer_info):
concat_layer = network.add_concatenation(tensors)
self.register_layer(concat_layer, "concat", *layer_info)
return concat_layer.get_output(0)
def flatten(self, network, tensor, layer_info):
shuffle_layer = network.add_shuffle(tensor)
shuffle_layer.reshape_dims = [-1]
shuffle_layer.zero_is_placeholder = False
self.register_layer(shuffle_layer, "flatten", *layer_info)
return shuffle_layer.get_output(0)
def reshape(self, network, tensor, reshape_dims, layer_info):
reshape_layer = network.add_shuffle(tensor)
reshape_layer.set_input(1, reshape_dims)
reshape_layer.zero_is_placeholder = False
self.register_layer(reshape_layer, "reshape", *layer_info)
return reshape_layer.get_output(0)
def cast(self, network, tensor, dtype, layer_info):
if tensor.dtype == dtype:
return tensor
cast_layer = network.add_cast(tensor, dtype)
self.register_layer(cast_layer, "cast", *layer_info)
return cast_layer.get_output(0)
def const_int(self, network, name, value, layer_info):
const_layer = network.add_constant(
[1], np.array([value], dtype=trt_dtype_to_np(default_int_dtype)))
self.register_layer(const_layer, name, *layer_info)
return const_layer.get_output(0)
def get_dim_size(self, network, tensor, dim, layer_info, shape_tensor=None):
raw_shape = tensor.shape
dim_size = raw_shape[dim]
if dim_size != -1:
return dim_size
else:
if shape_tensor is None:
shape_tensor = self.get_shape(network, tensor, layer_info)
return self.get_item(network, shape_tensor, dim, layer_info)
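# add_split gives each device its shard of the input: for every sharded
# dimension the slice starts at index * (dim_size // partition) and the split is
# recorded in the context as SplitInfo, while non-sharded dimensions keep their
# (possibly dynamic) size taken from the runtime shape tensor.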
def add_split(self, context: GraphContext, input_name, output_name,
device_ids, shard_dims, device_dims):
it = np.nditer(device_ids, flags=['multi_index'])
for device_id in it:
device_id = device_id.item()
layer_info = (input_name, output_name, device_id)
network = self.get_network(device_id)
input_tensor = self.get_tensor(context, input_name,
device_id).as_trt()
raw_input_shape = input_tensor.shape
start = []
output_dims = []
stride = []
input_shape_tensor = self.get_shape(network, input_tensor,
layer_info)
for dim in range(len(raw_input_shape)):
stride.append(1)
if dim not in shard_dims:
start.append(0)
output_dims.append(
self.get_item(network, input_shape_tensor, dim,
layer_info))
else:
start.append(None)
output_dims.append(None)
for dim, device_dim in zip(shard_dims, device_dims):
partition = get_partition(device_dim, device_ids)
index = get_index(device_dim, it)
input_dim = raw_input_shape[dim]
assert input_dim != -1
assert input_dim % partition == 0
quotient = input_dim // partition
start[dim] = index * quotient
output_dims[dim] = self.const_int(network, f"output_dim{dim}",
quotient, layer_info)
context.set_split_info(input_name, device_id, dim,
SplitInfo(input_dim, partition))
output_dims_tensor = self.concat(network, output_dims, layer_info)
split_layer = network.add_slice(input_tensor, start, [], stride)
split_layer.set_input(2, output_dims_tensor)
wrapped_layer = self.register_layer(split_layer, "split",
*layer_info)
wrapped_layer.attrs["strategy"] = get_sharding_sequence(
len(raw_input_shape),
shard_dims,
device_dims,
)
output_tensor = split_layer.get_output(0)
context.update_name_mapping(input_name, device_id,
output_tensor.name)
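# add_all_gather flattens each device's shard, emits the gather layer, and then
# reshapes and transposes the concatenated result back into the original layout
# using the split information recorded earlier; with is_singleton only the first
# device's reconstruction is materialized.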
def add_all_gather(self,
context: GraphContext,
input_name,
output_name,
device_ids,
gather_dims,
device_dims,
is_singleton=False):
to_gather_tensors = []
for device_id in np.nditer(device_ids):
device_id = device_id.item()
layer_info = (input_name, output_name, device_id)
network = self.get_network(device_id)
input_tensor = self.get_tensor(context, input_name,
device_id).as_trt()
to_gather_tensor = self.flatten(network, input_tensor, layer_info)
to_gather_tensors.append(to_gather_tensor)
all_gather_layers = self.add_all_gather_layer(
context,
input_name,
output_name,
device_ids,
to_gather_tensors,
)
if len(device_dims) == 1:
gather_indices = [0]
elif len(device_dims) == 2 and device_dims[0] == [1]:
gather_indices = [1, 0]
else:
gather_indices = [0, 1]
for device_id, all_gather_layer in zip(np.nditer(device_ids),
all_gather_layers):
device_id = device_id.item()
layer_info = (input_name, output_name, device_id)
network = self.get_network(device_id)
input_tensor = self.get_tensor(context, input_name,
device_id).as_trt()
permutation = []
gathered_dims = []
output_dims = []
partitions = []
raw_input_shape = input_tensor.shape
wrapped_layer = self.get_graph(device_id).get_layer(
all_gather_layer.name)
wrapped_layer.attrs["strategy"] = get_sharding_sequence(
len(raw_input_shape),
gather_dims,
device_dims,
)
input_shape_layer = network.add_shape(input_tensor)
self.register_layer(input_shape_layer, "input_shape", *layer_info)
input_shape_tensor = input_shape_layer.get_output(0)
split_infos = context.get_split_infos(input_name, device_id)
for index in gather_indices:
gather_dim = gather_dims[index]
device_dim = device_dims[index]
partition = get_partition(device_dim, device_ids)
assert partition == split_infos[gather_dim].partition
partitions.append(
self.const_int(network, f"partition_num{gather_dim}",
partition, layer_info))
for dim in range(len(raw_input_shape)):
if dim in gather_dims:
gather_index = gather_dims.index(dim)
device_dim = device_dims[gather_index]
permutation.append(gather_indices.index(gather_index))
permutation.append(dim + len(gather_dims))
if dim not in split_infos:
output_dim_layer = network.add_slice(
input_shape_tensor, [dim], [1], [1])
self.register_layer(output_dim_layer, f"output_dim{dim}",
*layer_info)
dim_tensor = output_dim_layer.get_output(0)
output_dims.append(dim_tensor)
gathered_dims.append(dim_tensor)
else:
input_dim = split_infos[dim].input_dim
partition = split_infos[dim].partition
assert input_dim != -1
assert input_dim % partition == 0
quotient = input_dim // partition
output_dims.append(
self.const_int(network, f"output_dim{dim}", quotient,
layer_info))
if dim in gather_dims:
gathered_dims.append(
self.const_int(network, f"gathered_dim{dim}",
quotient * partition, layer_info))
del split_infos[dim]
else:
gathered_dims.append(output_dim_layer.get_output(0))
reshape_dims_for_transpose_layer = network.add_concatenation(
[*partitions, *output_dims])
self.register_layer(reshape_dims_for_transpose_layer,
"reshape_dims_for_transpose", *layer_info)
reshape_dims_tensor = reshape_dims_for_transpose_layer.get_output(0)
transpose_layer = network.add_shuffle(
all_gather_layer.get_output(0))
transpose_layer.set_input(1, reshape_dims_tensor)
transpose_layer.second_transpose = permutation
transpose_layer.zero_is_placeholder = False
self.register_layer(transpose_layer, "transpose", *layer_info)
reshape_dims_for_reshape_layer = network.add_concatenation(
gathered_dims)
self.register_layer(reshape_dims_for_reshape_layer,
"reshape_dims_for_reshape", *layer_info)
reshape_dims_tensor = reshape_dims_for_reshape_layer.get_output(0)
output_tensor = self.reshape(
network,
transpose_layer.get_output(0),
reshape_dims_tensor,
layer_info,
)
context.update_name_mapping(input_name, device_id,
output_tensor.name)
if is_singleton:
break
def register_unfilled_weights(self, graph, layer):
if (layer.name in self.full_graph._unfilled_weights
and layer.name not in graph._unfilled_weights):
weights, values = self.full_graph._unfilled_weights[layer.name]
graph._register_unfilled_weights(
layer.name,
weights,
values,
)
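# shard_constant slices a constant layer's weights so that each device only
# stores its own shard. An empty buffer of the sharded shape is registered per
# device ("device{N}_<layer>") together with the sharded values (if already
# available), so the real weights can be filled in later.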
def shard_constant(self, context: ShardContext):
sharding_spec = context.strategy.sharding_specs["output0"]
shard_dims = sharding_spec.dim_partition_dict
device_id = context.nditer.value.item()
device_ids = context.device_ids
layer = context.layer.as_trt()
graph = self.get_graph(device_id)
if len(shard_dims) == 0:
self.register_unfilled_weights(graph, layer)
return LayerUpdate(split_info_updated=True)
flatten_device_dim = list(
itertools.chain.from_iterable(shard_dims.values()))
output_name = layer.get_output(0).name
output_dtype = layer.get_output(0).dtype
output_shape = layer.shape
output_dims = []
weight_index = []
for dim in range(len(output_shape)):
output_dim = output_shape[dim]
if dim in shard_dims:
device_dim = shard_dims[dim]
partition = get_partition(device_dim, device_ids)
index = get_index(device_dim, context.nditer)
assert output_dim % partition == 0
quotient = output_dim // partition
output_dims.append(quotient)
weight_index.append(
slice(index * quotient, (index + 1) * quotient))
context.graph_context.set_split_info(
output_name, device_id, dim,
SplitInfo(output_dim, partition))
else:
output_dims.append(output_dim)
weight_index.append(slice(None))
if layer.name in self.full_graph._unfilled_weights:
values = self.full_graph._unfilled_weights[layer.name][1]
else:
values = layer.weights
if isinstance(values, trt.Weights):
values = values.numpy()
# TODO: remove this workaround after https://nvbugs/4359151 is fixed.
if isinstance(values, trt.Weights):
network = context.layer.graph.as_trt()
values = get_np_weight(network, layer.name)
if values is not None:
values = values.reshape(layer.shape)
assert values.size == np.prod(layer.shape)
sharded_values = values[tuple(weight_index)]
assert sharded_values.size * get_partition(
flatten_device_dim, device_ids) == np.prod(layer.shape)
else:
sharded_values = None
dtype = trt_dtype_to_np(output_dtype)
sharded_weights = np.empty(tuple(output_dims), dtype)
graph._register_unfilled_weights(
f"device{device_id}_{layer.name}",
sharded_weights,
sharded_values,
)
sharded_weights = to_trt_weights(sharded_weights)
return LayerUpdate(
updated_attrs=dict(
shape=trt.Dims(output_dims),
weights=sharded_weights,
),
split_info_updated=True,
)
def shard_fill(self, context: ShardContext):
sharding_spec = context.strategy.sharding_specs["output0"]
shard_dims = sharding_spec.dim_partition_dict
if len(shard_dims) == 0:
return LayerUpdate(split_info_updated=True)
device_id = context.nditer.value.item()
device_ids = context.device_ids
layer = context.layer.as_trt()
output_name = layer.get_output(0).name
output_shape = layer.shape
output_dims = []
for dim in range(len(output_shape)):
output_dim = output_shape[dim]
if dim in shard_dims:
device_dim = shard_dims[dim]
partition = get_partition(device_dim, device_ids)
assert output_dim % partition == 0
quotient = output_dim // partition
output_dims.append(quotient)
context.graph_context.set_split_info(
output_name, device_id, dim,
SplitInfo(output_dim, partition))
else:
output_dims.append(output_dim)
return LayerUpdate(
updated_attrs=dict(shape=trt.Dims(output_dims), ),
split_info_updated=True,
)
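# update_shape only applies to shape I/O layers: when the shape layer's input
# was sharded, the sharded entries of the produced shape tensor are replaced by
# constants holding the full (pre-shard) dimension sizes, so downstream shape
# arithmetic still sees the original shape.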
def update_shape(self, context: ShardContext):
if not context.layer.is_shape_io:
return
layer = context.layer.as_trt()
input_name = layer.get_input(0).name
output_name = layer.get_output(0).name
device_id = context.nditer.value.item()
layer_info = (output_name, None, device_id)
split_infos = context.graph_context.get_split_infos(
input_name, device_id)
if len(split_infos) == 0:
return
network = self.get_network(device_id)
shape_tensor = self.get_tensor(context.graph_context, output_name,
device_id).as_trt()
output_dims = []
for dim in range(len(context.layer.get_input(0).shape)):
if dim not in split_infos:
output_dim_layer = network.add_slice(shape_tensor, [dim], [1],
[1])
else:
input_dim = split_infos[dim].input_dim
output_dim_layer = network.add_constant(
[1], np.array([input_dim], dtype=trt_dtype_to_np(default_int_dtype)))
self.register_layer(output_dim_layer, f"output_dim{dim}",
*layer_info)
output_dims.append(output_dim_layer.get_output(0))
new_shape_layer = network.add_concatenation(output_dims)
self.register_layer(new_shape_layer, "new_shape", *layer_info)
new_shape_tensor = new_shape_layer.get_output(0)
context.graph_context.update_name_mapping(output_name, device_id,
new_shape_tensor.name)
def shard_slice(self, context: ShardContext):
sharding_spec = context.strategy.sharding_specs["output0"]
shard_dims = sharding_spec.dim_partition_dict
if len(shard_dims) == 0:
return LayerUpdate.none()
device_id = context.nditer.value.item()
network = self.get_network(device_id)
device_ids = context.device_ids
layer = context.layer.as_trt()
output_dims = []
updated_attrs = {}
updated_inputs = {}
if layer.num_inputs >= 3:
raw_output_shape = layer.get_output(0).shape
input_name = layer.get_input(2).name
layer_info = (input_name, layer.name, device_id)
shape_tensor = self.get_tensor(context.graph_context, input_name,
device_id).as_trt()
for dim in range(len(raw_output_shape)):
output_dim_layer = network.add_slice(shape_tensor, [dim], [1],
[1])
self.register_layer(output_dim_layer, f"output_dim{dim}",
*layer_info)
if dim in shard_dims:
device_dim = shard_dims[dim]
partition = get_partition(device_dim, device_ids)
partition_num_tensor = self.const_int(
network, f"partition_num{dim}", partition, layer_info)
quotient_layer = network.add_elementwise(
output_dim_layer.get_output(0), partition_num_tensor,
trt.ElementWiseOperation.FLOOR_DIV)
self.register_layer(quotient_layer, f"quotient{dim}",
*layer_info)
output_dim = self.cast(network,
quotient_layer.get_output(0),
default_int_dtype, layer_info)
output_dims.append(output_dim)
else:
output_dims.append(output_dim_layer.get_output(0))
output_dims_layer = network.add_concatenation(output_dims)
self.register_layer(output_dims_layer, "output_dims", *layer_info)
updated_inputs[2] = output_dims_layer.get_output(0)
else:
output_shape = layer.shape
for dim in range(len(output_shape)):
output_dim = output_shape[dim]
assert output_dim != -1
if dim in shard_dims:
device_dim = shard_dims[dim]
partition = get_partition(device_dim, device_ids)
assert output_dim % partition == 0
quotient = output_dim // partition
output_dims.append(quotient)
else:
output_dims.append(output_dim)
updated_attrs["shape"] = trt.Dims(output_dims)
return LayerUpdate(updated_attrs, updated_inputs)
def shard_shuffle(self, context: ShardContext):
sharding_spec = context.strategy.sharding_specs["output0"]
shard_dims = sharding_spec.dim_partition_dict
if len(shard_dims) == 0:
return LayerUpdate.none()
device_id = context.nditer.value.item()
network = self.get_network(device_id)
device_ids = context.device_ids
layer = context.layer.as_trt()
updated_attrs = {}
updated_inputs = {}
updated_reshape_dims = {}
second_transpose = layer.second_transpose
if layer.num_inputs >= 2:
raw_output_shape = layer.get_output(0).shape
input_name = layer.get_input(1).name
layer_info = (input_name, layer.name, device_id)
reshape_dims_tensor = self.get_tensor(context.graph_context,
input_name, device_id)
reshape_dims = context.layer.get_input(1).value
reshape_dims_tensor = reshape_dims_tensor.as_trt()
for dim in range(len(raw_output_shape)):
if second_transpose is not None:
reshape_dim = second_transpose[dim]
else:
reshape_dim = dim
output_dim_layer = network.add_slice(reshape_dims_tensor,
[reshape_dim], [1], [1])
self.register_layer(output_dim_layer, f"output_dim{dim}",
*layer_info)
output_dim = reshape_dims[reshape_dim]
if dim in shard_dims and output_dim != -1:
device_dim = shard_dims[dim]
partition = get_partition(device_dim, device_ids)
partition_num_tensor = self.const_int(
network, f"partition_num{dim}", partition, layer_info)
quotient_layer = network.add_elementwise(
output_dim_layer.get_output(0), partition_num_tensor,
trt.ElementWiseOperation.FLOOR_DIV)
self.register_layer(quotient_layer, f"quotient{dim}",
*layer_info)
updated_reshape_dims[reshape_dim] = self.cast(
network,
quotient_layer.get_output(0),
default_int_dtype,
layer_info,
)
else:
updated_reshape_dims[
reshape_dim] = output_dim_layer.get_output(0)
updated_reshape_dims = list(
map(lambda x: x[1], sorted(updated_reshape_dims.items())))
reshape_dims_layer = network.add_concatenation(updated_reshape_dims)
self.register_layer(reshape_dims_layer, "reshape_dims", *layer_info)
updated_inputs[1] = reshape_dims_layer.get_output(0)
else:
reshape_dims = layer.reshape_dims
if reshape_dims.__len__() < 0:
return LayerUpdate.none()
for dim in range(len(reshape_dims)):
if second_transpose is not None:
reshape_dim = second_transpose[dim]
else:
reshape_dim = dim
output_dim = reshape_dims[reshape_dim]
if dim in shard_dims and output_dim != -1:
device_dim = shard_dims[dim]
partition = get_partition(device_dim, device_ids)
quotient = output_dim // partition
updated_reshape_dims[reshape_dim] = quotient
else:
updated_reshape_dims[reshape_dim] = output_dim
updated_reshape_dims = list(
map(lambda x: x[1], sorted(updated_reshape_dims.items())))
updated_attrs["reshape_dims"] = trt.Dims(updated_reshape_dims)
return LayerUpdate(updated_attrs, updated_inputs)
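# shard_gpt_attention implements tensor parallelism for the GPTAttention
# plugin: attention heads are divided across devices by rewriting the plugin's
# num_heads, num_kv_heads, tp_size and tp_rank fields; num_kv_heads is clamped
# to at least 1, so models with fewer KV heads than the partition count keep a
# replicated KV head instead of splitting it.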
def shard_gpt_attention(self, context: ShardContext):
layer = context.layer.as_trt()
plugin_info = get_plugin_info(
self.full_graph.as_trt(),
layer.name,
)
parser = IdxEntryParser(plugin_info)
head_dim = 1 if parser.remove_input_padding else 2
sharding_spec = context.strategy.sharding_specs[
f"input{parser.get_index(IdxEntry.QKV_TENSOR)}"]
shard_dims = sharding_spec.dim_partition_dict
if head_dim not in shard_dims:
return LayerUpdate.none()
device_id = context.nditer.value.item()
network = self.get_network(device_id)
device_ids = context.device_ids
updated_attrs = {}
updated_inputs = {}
device_dim = shard_dims[head_dim]
partition = get_partition(device_dim, device_ids)
index = get_index(device_dim, context.nditer)
if parser.is_entry_used(IdxEntry.K_TENSOR):
kv_sharding_spec = context.strategy.sharding_specs[
f"input{parser.get_index(IdxEntry.K_TENSOR)}"]
kv_shard_dims = kv_sharding_spec.dim_partition_dict
if head_dim in kv_shard_dims:
kv_device_dim = kv_shard_dims[head_dim]
kv_partition = get_partition(kv_device_dim, device_ids)
else:
kv_partition = 1
else:
kv_partition = 1
num_heads = plugin_info.pfc_as_ndarray["num_heads"].copy()
num_kv_heads = plugin_info.pfc_as_ndarray["num_kv_heads"].copy()
tp_size = plugin_info.pfc_as_ndarray["tp_size"].copy()
tp_rank = plugin_info.pfc_as_ndarray["tp_rank"].copy()
num_kv_heads = np.maximum(num_kv_heads // kv_partition, 1)
num_heads = np.maximum(num_heads // partition, 1)
tp_size[0] = partition
tp_rank[0] = index
new_plugin, new_plugin_info = get_updated_plugin(
plugin_info,
dict(
num_heads=num_heads,
num_kv_heads=num_kv_heads,
tp_size=tp_size,
tp_rank=tp_rank,
))
prefix = self.get_prefix(device_id)
new_layer_name = f"{prefix}{layer.name}"
set_plugin_info(network, new_layer_name, new_plugin_info)
updated_attrs["plugin"] = new_plugin
return LayerUpdate(updated_attrs, updated_inputs)
def shard_lookup(self, context: ShardContext):
sharding_spec = context.strategy.sharding_specs["input1"]
shard_dims = sharding_spec.dim_partition_dict
if 0 not in shard_dims:
return LayerUpdate.none()
layer = context.layer.as_trt()
plugin_info = get_plugin_info(
self.full_graph.as_trt(),
layer.name,
)
device_id = context.nditer.value.item()
network = self.get_network(device_id)
updated_attrs = {}
device_dim = shard_dims[0]
index = get_index(device_dim, context.nditer)
rank = plugin_info.pfc_as_ndarray["rank"].copy()
rank[0] = index
new_plugin, new_plugin_info = get_updated_plugin(
plugin_info, dict(rank=rank, ))
prefix = self.get_prefix(device_id)
new_layer_name = f"{prefix}{layer.name}"
set_plugin_info(network, new_layer_name, new_plugin_info)
updated_attrs["plugin"] = new_plugin
return LayerUpdate(updated_attrs)
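# GraphGroupBase implements the device-agnostic parts of GraphGroup:
# register_layer names helper layers, transfers plugin metadata, tags them with
# the current block id, and runs per-layer shape inference; add_layer reshards
# the inputs, clones the layer once per device with the updates returned by the
# shard_* handlers, and finally emits the strategy's output communications.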
class GraphGroupBase(GraphGroup):
def __init__(
self,
full_graph: PipelineGraph,
config: ParallelConfig,
auto_parallel_config: AutoParallelConfig,
) -> None:
self._full_graph = full_graph
self.config = config
self._auto_parallel_config = auto_parallel_config
self.infer_shape = auto_parallel_config.infer_shape
self.global_context = GraphContext()
self.shape_cache = {}
self.suffix = 0
self.current_block_id = -1
@property
def auto_parallel_config(self) -> AutoParallelConfig:
return self._auto_parallel_config
@property
def full_graph(self) -> PipelineGraph:
return self._full_graph
def register_layer(self,
layer,
base_name,
input_name,
output_name=None,
device_id=None,
keep_tensor_name=False) -> Layer:
layer_name = f"{base_name}_{input_name}"
if device_id is not None:
layer_name = f"{self.get_prefix(device_id)}{layer_name}"
if output_name is not None:
layer_name = f"{layer_name}_to_{output_name}"
suffix = self.suffix
self.suffix += 1
layer_name = f"{layer_name}_{suffix}"
if layer.type == trt.LayerType.PLUGIN_V2:
network = self.get_network(device_id)
plugin_info = get_plugin_info(network, layer.name)
if plugin_info is not None:
set_plugin_info(network, layer_name, plugin_info)
delete_plugin_info(network, layer.name)
layer.name = layer_name
layer.metadata = layer.name
if not keep_tensor_name:
for i in range(layer.num_outputs):
output_tensor = layer.get_output(i)
assert output_tensor.shape.__len__() >= 0
output_tensor.name = f"{layer.name}_output_{i}"
wrapped_layer = self.get_graph(device_id).register_layer(layer)
if self.current_block_id != -1:
wrapped_layer.attrs["block_id"] = self.current_block_id
wrapped_layer.attrs["role"] = "helper"
if self.infer_shape:
infer_per_layer_shapes(
layer,
self.get_shapes(device_id),
self.get_values(device_id),
self.shape_cache,
is_shape_io=True,
)
wrapped_layer.assign_shapes(
self.get_shapes(device_id),
self.get_values(device_id),
)
return wrapped_layer
def add_layer(self, wrapped_layer: Layer, device_ids,
strategy: ShardingStrategy):
layer = wrapped_layer.as_trt()
local_context = self.global_context.get_local_context(layer)
self.current_block_id = wrapped_layer.attrs["block_id"]
for i, input in enumerate(wrapped_layer.inputs):
if input is None:
continue
if i not in strategy.best_resharding_cost:
continue
comm_action_sequence = strategy.best_resharding_cost[i][0][1]
for commspec in comm_action_sequence:
self.add_comm(local_context,
input.name,
device_ids,
commspec,
output_name=layer.name)
it = np.nditer(device_ids, flags=['multi_index'])
for device_id in it:
device_id = device_id.item()
layer_type = layer.type
to_subclass_layer(layer)
shard_context = ShardContext(
local_context,
wrapped_layer,
it,
device_ids,
strategy,
)
if layer_type == trt.LayerType.CONSTANT:
layer_update = self.shard_constant(shard_context)
elif layer_type == trt.LayerType.FILL:
layer_update = self.shard_fill(shard_context)
elif layer_type == trt.LayerType.SLICE:
layer_update = self.shard_slice(shard_context)
elif layer_type == trt.LayerType.SHUFFLE:
layer_update = self.shard_shuffle(shard_context)
elif layer_type == trt.LayerType.PLUGIN_V2:
if layer.plugin.plugin_type == "GPTAttention":
layer_update = self.shard_gpt_attention(shard_context)
elif layer.plugin.plugin_type == "Lookup":
layer_update = self.shard_lookup(shard_context)
else:
layer_update = LayerUpdate.none()
else:
layer_update = LayerUpdate.none()
to_base_class_layer(layer)
for i, updated_input in layer_update.updated_inputs.items():
input_name = layer.get_input(i).name
local_context.update_name_mapping(input_name, device_id,
updated_input.name)
if layer.get_input(i).dtype != updated_input.dtype:
raise ValueError(
f"Input dtype mismatch for {layer.name}, "
f"expect {layer.get_input(i).dtype} for {input_name}, "
f"get {updated_input.dtype} for {updated_input.name}")
prefix = self.get_prefix(device_id)
new_wrapped_layer = self.get_graph(device_id).add_layer(
layer,
prefix=prefix,
input_mapping=local_context.get_name_mapping(device_id,
prefix=prefix),
updated_attrs=layer_update.updated_attrs,
)
new_wrapped_layer.attrs["strategy"] = strategy.name
new_wrapped_layer.attrs["block_id"] = self.current_block_id
new_layer = new_wrapped_layer.as_trt()
if self.infer_shape:
infer_per_layer_shapes(
new_layer,
self.get_shapes(device_id),
self.get_values(device_id),
self.shape_cache,
is_shape_io=wrapped_layer.is_shape_io,
)
new_wrapped_layer.assign_shapes(
self.get_shapes(device_id),
self.get_values(device_id),
)
for i in range(layer.num_outputs):
output_tensor = new_layer.get_output(i)
assert output_tensor.shape.__len__() >= 0
local_context.update_name_mapping(
layer.get_output(i).name, device_id, output_tensor.name)
if layer.type == trt.LayerType.SHAPE:
self.update_shape(shard_context)
self.global_context.update_layer_context(
wrapped_layer,
layer_update,
local_context,
device_id,
device_ids,
strategy.sharding_specs,
)
for i in range(layer.num_outputs):
commspecs = strategy.communication_actions.get(f"output{i}")
if commspecs is None:
continue
output = layer.get_output(i)
for commspec in commspecs:
self.add_comm(
self.global_context,
output.name,
device_ids,
commspec,
)
self.current_block_id = -1
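# DistributedGraphGroup builds one PipelineGraph per physical device. Each
# graph carries a Mapping (tp_size = devices per pipeline stage, pp_size =
# number of stages), and communication is emitted with the TRT-LLM NCCL plugins
# (AllReduce, AllGather, ReduceScatter).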
class DistributedGraphGroup(GraphGroupBase):
def __init__(
self,
full_graph: PipelineGraph,
config: ParallelConfig,
auto_parallel_config: AutoParallelConfig,
) -> None:
super().__init__(full_graph, config, auto_parallel_config)
self.graphs = {}
self.io_tensor_shards = {}
self.shapes_by_device = {}
self.values_by_device = {}
phy_mesh = config.graph_config.phy_mesh
device_ids = phy_mesh.phy_devices_id
for device_id in np.nditer(device_ids):
device_id = device_id.item()
graph = PipelineGraph.create_graph()
graph._auto_parallel_config = {
"io_shards": {},
"mapping":
Mapping(
world_size=device_ids.size,
rank=device_id,
gpus_per_node=device_ids.shape[1],
tp_size=device_ids.size // config.graph_config.num_stages,
pp_size=config.graph_config.num_stages,
),
}
self.graphs[device_id] = graph
self.shapes_by_device[device_id] = {}
self.values_by_device[device_id] = {}
@contextlib.contextmanager
def disable_infer_shape(self):
infer_shape = self.infer_shape
self.infer_shape = False
yield
self.infer_shape = infer_shape
def get_network(self, device_id) -> trt.INetworkDefinition:
return self.graphs[device_id].as_trt()
def get_graph(self, device_id) -> PipelineGraph:
return self.graphs[device_id]
def get_prefix(self, device_id) -> str:
return ""
def get_shapes(self, device_id) -> Dict[str, Tuple[int, ...]]:
return self.shapes_by_device[device_id]
def get_values(self, device_id) -> Dict[str, List[int]]:
return self.values_by_device[device_id]
def add_reduce_scatter(self, context: GraphContext, input_name, output_name,
device_ids, shard_dims, device_dims):
dtype = str_dtype_to_trt(self.full_graph._plugin_config.dtype)
it = np.nditer(device_ids, flags=['multi_index'])
for device_id in it:
device_id = device_id.item()
layer_info = (input_name, output_name, device_id)
network = self.get_network(device_id)
input_tensor = self.get_tensor(context, input_name,
device_id).as_trt()
raw_input_shape = input_tensor.shape
input_shape_tensor = self.get_shape(network, input_tensor,
layer_info)
if shard_dims != [0]:
permutation = list(range(len(raw_input_shape)))
for dim in shard_dims:
permutation.remove(dim)
permutation = shard_dims + permutation
transpose_layer = network.add_shuffle(input_tensor)
transpose_layer.second_transpose = permutation
self.register_layer(transpose_layer, "input_transpose",
*layer_info)
input_tensor = transpose_layer.get_output(0)
flatten_tensor = self.flatten(network, input_tensor, layer_info)
input_dtype = flatten_tensor.dtype
if input_dtype != dtype:
to_reduce_tensor = self.cast(
network,
flatten_tensor,
dtype,
layer_info,
)
else:
to_reduce_tensor = flatten_tensor
reduce_scatter_plg_creator = trt.get_plugin_registry(
).get_plugin_creator('ReduceScatter', '1', TRT_LLM_PLUGIN_NAMESPACE)
assert reduce_scatter_plg_creator is not None
group = trt.PluginField(
"group",
np.ascontiguousarray(device_ids.reshape(-1).astype(np.int32)),
trt.PluginFieldType.INT32)
pf_type = trt.PluginField(
"type_id", np.array([int(to_reduce_tensor.dtype)], np.int32),
trt.PluginFieldType.INT32)
pfc = trt.PluginFieldCollection([group, pf_type])
rs_plug = reduce_scatter_plg_creator.create_plugin(
"reduce_scatter", pfc)
reduce_scatter_layer = network.add_plugin_v2([to_reduce_tensor],
rs_plug)
plugin_info = PluginInfo(reduce_scatter_plg_creator,
"reduce_scatter", pfc)
set_plugin_info(network, reduce_scatter_layer.name, plugin_info)
with self.disable_infer_shape():
wrapped_tensor = self.register_layer(
reduce_scatter_layer,
"reduce_scatter",
*layer_info,
).get_output(0)
reduce_scatter_tensor = reduce_scatter_layer.get_output(0)
if self.infer_shape:
shape = self.shapes_by_device[device_id][to_reduce_tensor.name]
assert len(shape) == 1
output_shape = (shape[0] // device_ids.size, )
self.shapes_by_device[device_id][
reduce_scatter_tensor.name] = output_shape
wrapped_tensor.shape = output_shape
if input_dtype != dtype:
reduce_scatter_tensor = self.cast(
network,
reduce_scatter_tensor,
input_dtype,
layer_info,
)
start = []
output_dims = []
stride = []
for dim in range(len(raw_input_shape)):
stride.append(1)
if dim not in shard_dims:
start.append(0)
output_dims.append(
self.get_item(network, input_shape_tensor, dim,
layer_info))
else:
start.append(None)
output_dims.append(None)
for dim, device_dim in zip(shard_dims, device_dims):
partition = get_partition(device_dim, device_ids)
index = get_index(device_dim, it)
input_dim = raw_input_shape[dim]
assert input_dim != -1
assert input_dim % partition == 0
quotient = input_dim // partition
start[dim] = index * quotient
output_dims[dim] = self.const_int(network, f"output_dim{dim}",
quotient, layer_info)
context.set_split_info(input_name, device_id, dim,
SplitInfo(input_dim, partition))
if shard_dims != [0]:
output_dims = [
output_dims[permutation[i]] for i in range(len(output_dims))
]
output_dims_tensor = self.concat(network, output_dims, layer_info)
output_tensor = self.reshape(
network,
reduce_scatter_tensor,
output_dims_tensor,
layer_info,
)
if shard_dims != [0]:
transpose_layer = network.add_shuffle(output_tensor)
transpose_layer.second_transpose = permutation
self.register_layer(transpose_layer, "output_transpose",
*layer_info)
output_tensor = transpose_layer.get_output(0)
context.update_name_mapping(input_name, device_id,
output_tensor.name)
def add_all_reduce_layer(self, context: GraphContext, input_name,
output_name, device_ids, to_reduce_tensors):
for device_id, to_reduce_tensor in zip(np.nditer(device_ids),
to_reduce_tensors):
device_id = device_id.item()
layer_info = (input_name, output_name, device_id)
network = self.get_network(device_id)
graph = self.get_graph(device_id)
workspace = graph.get_input("all_reduce_workspace").as_trt()
all_reduce_layer, allreduce_plg_creator, pfc = create_allreduce_plugin(
network=network,
tensor=to_reduce_tensor,
workspace=workspace,
group=np.ascontiguousarray(
device_ids.reshape(-1).astype(np.int32)),
dtype=to_reduce_tensor.dtype,
all_reduce_params=AllReduceParams(),
)
plugin_info = PluginInfo(allreduce_plg_creator, "allreduce", pfc)
set_plugin_info(network, all_reduce_layer.name, plugin_info)
with self.disable_infer_shape():
wrapped_tensor = self.register_layer(
all_reduce_layer,
"all_reduce",
*layer_info,
).get_output(0)
output_tensor = all_reduce_layer.get_output(0)
if self.infer_shape:
shape = self.shapes_by_device[device_id][to_reduce_tensor.name]
self.shapes_by_device[device_id][output_tensor.name] = shape
wrapped_tensor.shape = shape
context.update_name_mapping(input_name, device_id,
output_tensor.name)
def add_all_gather_layer(self, context: GraphContext, input_name,
output_name, device_ids, to_gather_tensors):
all_gather_layers = []
for device_id, to_gather_tensor in zip(np.nditer(device_ids),
to_gather_tensors):
device_id = device_id.item()
layer_info = (input_name, output_name, device_id)
network = self.get_network(device_id)
allgather_plg_creator = trt.get_plugin_registry(
).get_plugin_creator('AllGather', '1', TRT_LLM_PLUGIN_NAMESPACE)
assert allgather_plg_creator is not None
group = trt.PluginField(
"group",
np.ascontiguousarray(device_ids.reshape(-1).astype(np.int32)),
trt.PluginFieldType.INT32)
pf_type = trt.PluginField(
"type_id", np.array([int(to_gather_tensor.dtype)], np.int32),
trt.PluginFieldType.INT32)
pfc = trt.PluginFieldCollection([group, pf_type])
allgather = allgather_plg_creator.create_plugin("allgather", pfc)
all_gather_layer = network.add_plugin_v2([to_gather_tensor],
allgather)
plugin_info = PluginInfo(allgather_plg_creator, "allgather", pfc)
set_plugin_info(network, all_gather_layer.name, plugin_info)
with self.disable_infer_shape():
wrapped_tensor = self.register_layer(
all_gather_layer,
"all_gather",
*layer_info,
).get_output(0)
if self.infer_shape:
output_tensor = all_gather_layer.get_output(0)
shape = self.shapes_by_device[device_id][to_gather_tensor.name]
assert len(shape) == 1
output_shape = (shape[0] * device_ids.size, )
self.shapes_by_device[device_id][
output_tensor.name] = output_shape
wrapped_tensor.shape = output_shape
all_gather_layers.append(all_gather_layer)
return all_gather_layers
def set_shard_num(self, tensor_name, dim, shard_num):
for graph in self.graphs.values():
io_shards = graph._auto_parallel_config["io_shards"]
if tensor_name not in io_shards:
io_shards[tensor_name] = {}
io_shards[tensor_name][dim] = shard_num
def add_input(self, tensor: Tensor, device_ids, strategy: ShardingStrategy):
context = self.global_context
sharding_spec = strategy.sharding_specs["output0"]
shard_dims = sharding_spec.dim_partition_dict
for dim, device_dim in shard_dims.items():
partition = get_partition(device_dim, device_ids)
self.set_shard_num(tensor.name, dim, partition)
for device_id in np.nditer(device_ids):
device_id = device_id.item()
graph = self.get_graph(device_id)
new_input = graph.add_input(tensor.as_trt())
shape = [*tensor.shape]
if len(shard_dims) != 0:
output_shape = [*tensor.raw_shape]
for dim, device_dim in shard_dims.items():
partition = get_partition(device_dim, device_ids)
output_dim = output_shape[dim]
assert output_dim != -1
assert output_dim % partition == 0
quotient = output_dim // partition
output_shape[dim] = quotient
shape[dim] = quotient
assert tensor.value is None
context.set_split_info(tensor.name, device_id, dim,
SplitInfo(output_dim, partition))
new_input.raw_shape = output_shape
context.update_name_mapping(tensor.name, device_id, tensor.name)
if self.infer_shape:
self.shapes_by_device[device_id][tensor.name] = tuple(shape)
new_input.shape = tuple(shape)
if tensor.value is not None:
self.values_by_device[device_id][tensor.name] = tensor.value
new_input.value = tensor.value
def add_output(self, tensor: Tensor, device_ids,
strategy: ShardingStrategy):
comm_action_sequence = strategy.best_resharding_cost[0][0][1]
for commspec in comm_action_sequence:
self.add_comm(self.global_context, tensor.name, device_ids,
commspec)
for device_id in np.nditer(device_ids):
device_id = device_id.item()
graph = self.get_graph(device_id)
output_name = tensor.name
new_output_name = self.global_context.get_name(
output_name, device_id)
if new_output_name != output_name:
suffix = self.suffix
self.suffix += 1
original_name = f"original_{output_name}_{suffix}"
original_tensor = graph.get_tensor(output_name)
original_tensor.as_trt().name = original_name
output_tensor = graph.get_tensor(new_output_name)
output_tensor.as_trt().name = output_name
graph._tensors[original_name] = original_tensor
graph._tensors[output_name] = output_tensor
del graph._tensors[new_output_name]
else:
output_tensor = graph.get_tensor(output_name)
trt_output = output_tensor.as_trt()
if trt_output.is_shape_tensor:
graph.add_output_shape(trt_output)
else:
graph.add_output(trt_output)
trt_output.dtype = tensor.dtype
if tensor.dtype != output_tensor.dtype:
raise ValueError(
f"Output dtype mismatch, "
f"expect {tensor.dtype} for {tensor.name}, "
f"get {output_tensor.dtype} for {output_tensor.name}")
shard_dims = strategy.sharding_specs["input0"].dim_partition_dict
for dim, device_dim in shard_dims.items():
partition = get_partition(device_dim, device_ids)
self.set_shard_num(tensor.name, dim, partition)
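# PrefixedGraphGroup is the debug-mode builder: every device's copy of the
# network lives in a single graph under a "device{N}_" prefix, all-reduce is
# emulated with concatenation followed by a reduce-sum, and all-gather with a
# plain concatenation, so the parallelized graph can be inspected or validated
# on a single GPU.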
class PrefixedGraphGroup(GraphGroupBase):
def __init__(
self,
full_graph: PipelineGraph = None,
config: ParallelConfig = None,
auto_parallel_config: AutoParallelConfig = None,
) -> None:
auto_parallel_config = auto_parallel_config or AutoParallelConfig(
infer_shape=False,
validation_mode=False,
)
super().__init__(full_graph, config, auto_parallel_config)
self.validation_mode = auto_parallel_config.validation_mode
if not self.infer_shape:
self.validation_mode = False
self.prefixed_graph = PipelineGraph.create_graph()
if self.validation_mode:
self.layer_mapping = config.graph_config.graph_mapping.layer_mapping
self.graph_strategy = config.graph_strategy
self.shapes = {}
self.values = {}
self.timing_cache = None
def get_network(self, device_id) -> trt.INetworkDefinition:
return self.prefixed_graph.as_trt()
def get_graph(self, device_id) -> PipelineGraph:
return self.prefixed_graph
def get_prefix(self, device_id) -> str:
return f"device{device_id}_"
def get_shapes(self, device_id) -> Dict[str, Tuple[int, ...]]:
return self.shapes
def get_values(self, device_id) -> Dict[str, List[int]]:
return self.values
def add_all_reduce_layer(self, context: GraphContext, input_name,
output_name, device_ids, to_reduce_tensors):
reshaped_tensors = []
for device_id, to_reduce_tensor in zip(np.nditer(device_ids),
to_reduce_tensors):
device_id = device_id.item()
layer_info = (input_name, output_name, device_id)
network = self.get_network(device_id)
reshape_dims_tensor = self.concat(
network,
[
self.get_shape(network, to_reduce_tensor, layer_info),
self.const_int(network, "expanded_dim", 1, layer_info)
],
layer_info,
)
reshaped_tensor = self.reshape(
network,
to_reduce_tensor,
reshape_dims_tensor,
layer_info,
)
reshaped_tensors.append(reshaped_tensor)
for device_id in np.nditer(device_ids):
device_id = device_id.item()
layer_info = (input_name, output_name, device_id)
input_tensor = self.get_tensor(context, input_name, 0).as_trt()
num_dims = len(input_tensor.shape)
network = self.get_network(device_id)
concat_layer = network.add_concatenation(reshaped_tensors)
concat_layer.axis = num_dims
self.register_layer(concat_layer, "concat", *layer_info)
reduce_layer = network.add_reduce(concat_layer.get_output(0),
trt.ReduceOperation.SUM,
axes=1 << num_dims,
keep_dims=False)
dtype = to_reduce_tensors[0].dtype
reduce_layer.precision = dtype
reduce_layer.set_output_type(0, dtype)
self.register_layer(reduce_layer, "reduce", *layer_info)
output_tensor = reduce_layer.get_output(0)
context.update_name_mapping(input_name, device_id,
output_tensor.name)
def add_all_gather_layer(self, context: GraphContext, input_name,
output_name, device_ids, to_gather_tensors):
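        # Emulate an all-gather by concatenating the per-device tensors along
        # dimension 0 once for every participating device.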
all_gather_layers = []
for device_id in np.nditer(device_ids):
device_id = device_id.item()
layer_info = (input_name, output_name, device_id)
network = self.get_network(device_id)
all_gather_layer = network.add_concatenation(to_gather_tensors)
all_gather_layer.axis = 0
self.register_layer(all_gather_layer, "all_gather", *layer_info)
all_gather_layers.append(all_gather_layer)
return all_gather_layers
def add_input(self, tensor: Tensor, device_ids, strategy: ShardingStrategy):
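        # Register the original network input once in the prefixed graph,
        # optionally emit a "ref_" identity copy for validation mode, then
        # reshard the full tensor to this input's target sharding spec on
        # each device.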
def add_identity():
identity_layer = network.add_identity(input.as_trt())
return identity_layer
input = self.prefixed_graph.add_input(tensor.as_trt())
if self.infer_shape:
self.shapes[tensor.name] = tensor.shape
input.shape = tensor.shape
if tensor.value is not None:
self.values[tensor.name] = tensor.value
input.value = tensor.value
network = self.get_network(None)
if self.validation_mode:
identity_layer = add_identity()
identity_layer.get_output(0).name = f"ref_{tensor.name}"
layer_info = (tensor.name, None, None)
self.register_layer(identity_layer,
"identity",
*layer_info,
keep_tensor_name=True)
input.attrs["strategy"] = strategy.name
sharding_spec = strategy.sharding_specs["output0"]
        pre_sharding_spec = get_full_sharding_spec(sharding_spec)
        comm_action_sequence = get_comm_action_sequence(pre_sharding_spec,
sharding_spec)
context = self.global_context
for device_id in np.nditer(device_ids):
device_id = device_id.item()
layer_info = (tensor.name, None, device_id)
context.update_name_mapping(tensor.name, device_id, tensor.name)
if len(comm_action_sequence
) == 0 and not tensor.as_trt().is_shape_tensor:
identity_layer = add_identity()
self.register_layer(identity_layer, "identity", *layer_info)
context.update_name_mapping(
tensor.name,
device_id,
identity_layer.get_output(0).name,
)
for commspec in comm_action_sequence:
self.add_comm(context, tensor.name, device_ids, commspec)
def get_graph_in_range(self, graph_group, src_layer, layer_range,
device_ids, shapes, values):
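        # Copy the layers in layer_range from this prefixed graph into
        # graph_group, remapping tensors that are already graph inputs there
        # and materializing any other missing inputs as constants built from
        # the provided values.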
src_network = self.prefixed_graph.as_trt()
graph = graph_group.prefixed_graph
network = graph.as_trt()
input_mapping = {}
for device_id in np.nditer(device_ids):
device_id = device_id.item()
for i in range(src_layer.num_inputs):
src_input = src_layer.get_input(i)
if src_input is not None:
input = self.get_tensor(
self.global_context,
src_input.name,
device_id,
).as_trt()
if graph.get_input(src_input.name) is not None:
new_input = graph_group.get_tensor(
graph_group.global_context,
src_input.name,
device_id,
).as_trt()
input_mapping[input.name] = new_input.name
continue
if graph.get_tensor(input.name) is not None:
continue
shape = shapes[input.name]
assert input.name in values
value = values[input.name]
weights = np.asarray(value,
dtype=trt_dtype_to_np(input.dtype))
weights = to_trt_weights(weights)
input_layer = network.add_constant(shape, weights)
new_input = input_layer.get_output(0)
new_input.name = input.name
graph.register_layer(input_layer)
for i in layer_range:
layer = src_network.get_layer(i)
graph.add_layer(layer, input_mapping=input_mapping)
def add_layer_singleton(self, output, device_ids, sharding_spec):
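        # Gather a possibly sharded tensor back to a full (replicated) copy
        # and publish it under its original name via an identity layer.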
assert self.prefixed_graph.get_tensor(output.name) is None
network = self.prefixed_graph.as_trt()
        full_sharding_spec = get_full_sharding_spec(sharding_spec)
        comm_action_sequence = get_comm_action_sequence(sharding_spec,
                                                        full_sharding_spec)
output_context = self.global_context.get_local_context_for_output(
output)
if len(comm_action_sequence) != 0:
for commspec in comm_action_sequence[:-1]:
self.add_comm(output_context, output.name, device_ids, commspec)
self.add_comm(
output_context,
output.name,
device_ids,
comm_action_sequence[-1],
is_singleton=True,
)
device_id = next(np.nditer(device_ids)).item()
layer_info = (output.name, None, device_id)
output_tensor = self.get_tensor(output_context, output.name,
device_id).as_trt()
singleton_layer = network.add_identity(output_tensor)
singleton_layer.get_output(0).name = output.name
self.register_layer(singleton_layer,
"singleton",
*layer_info,
keep_tensor_name=True)
def add_layer(self, wrapped_layer: Layer, device_ids,
strategy: ShardingStrategy):
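        # Let the base class emit the parallelized layer(s); in validation
        # mode, additionally build a small standalone graph that runs the
        # sharded layers against a single-device reference of the same layer
        # and compares the outputs within a tolerance, warning on mismatch.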
graph = self.prefixed_graph
network = graph.as_trt()
start_layer_id = network.num_layers
super().add_layer(wrapped_layer, device_ids, strategy)
layer = wrapped_layer.as_trt()
if self.validation_mode:
is_shape = (wrapped_layer.is_shape_io
or layer.type == trt.LayerType.SHAPE)
if not is_shape:
self.current_block_id = wrapped_layer.attrs["block_id"]
for i, wrapped_output in enumerate(wrapped_layer.outputs):
if wrapped_output.is_graph_output:
continue
output = wrapped_output.as_trt()
output_name = f"output{i}"
if strategy.communication_actions.get(
output_name) is not None:
output_name += "_after_comm"
sharding_spec = strategy.sharding_specs[output_name]
self.add_layer_singleton(output, device_ids, sharding_spec)
self.current_block_id = -1
end_layer_id = network.num_layers
is_skip = (is_shape or layer.type == trt.LayerType.CONSTANT
or layer.name in self.layer_mapping)
sharded = False
for sharding_spec in strategy.sharding_specs.values():
if len(sharding_spec.dim_partition_dict) > 0:
sharded = True
break
if not sharded:
is_skip = True
ref_layer = graph.add_layer(layer, prefix="ref_")
ref_layer.attrs["strategy"] = strategy.name
ref_layer.attrs["block_id"] = wrapped_layer.attrs["block_id"]
if layer.type == trt.LayerType.CONSTANT:
self.register_unfilled_weights(graph, layer)
if is_skip:
return
logger.debug(f"validating layer {layer.name}")
layer_type = layer.type
generated_input_values = {}
to_subclass_layer(layer)
if layer_type == trt.LayerType.PLUGIN_V2:
if layer.plugin.plugin_type == "GPTAttention":
sharding_specs = {}
for name, sharding_spec in strategy.sharding_specs.items():
sharding_specs[name] = get_full_sharding_spec(
sharding_spec)
plugin_info = get_plugin_info(
self.full_graph.as_trt(),
layer.name,
)
generated_input_values = GPTAttentionPlugin.parameter_generator(
sharding_specs, plugin_info)
to_base_class_layer(layer)
validation_graph_group = PrefixedGraphGroup()
validation_graph = validation_graph_group.prefixed_graph
validation_graph._io_buffer_mapping = self.full_graph._io_buffer_mapping
extra_input_values = {}
validation_shapes = {}
for i, wrapped_input in enumerate(wrapped_layer.inputs):
if wrapped_input is None:
continue
input = wrapped_input.as_trt()
validation_shapes[input.name] = wrapped_input.shape
if wrapped_input.value is None:
if i in generated_input_values:
extra_input_value = generated_input_values[i]
else:
extra_input_value = torch.empty(
tuple(wrapped_input.shape),
dtype=trt_dtype_to_torch(input.dtype),
device=torch.cuda.current_device(),
)
if torch.is_floating_point(extra_input_value):
extra_input_value.normal_()
# extra_input_value[:] = random.choice([2, 3, 5, 7])
extra_input_values[input.name] = extra_input_value
self.values[input.name] = extra_input_value
if wrapped_input.producer is not None:
node_name = wrapped_input.producer.name
output_index = wrapped_input.output_index
else:
node_name = wrapped_input.name
output_index = 0
sharding_spec = self.graph_strategy[
node_name].sharding_specs[f"output{output_index}"]
validation_graph_group.add_input(
wrapped_input,
device_ids,
ShardingStrategy(
sharding_specs={"output0": sharding_spec}),
)
validation_graph.get_input(
input.name).raw_shape = wrapped_input.shape
self.get_graph_in_range(
validation_graph_group,
layer,
range(start_layer_id, end_layer_id),
device_ids,
self.shapes,
self.values,
)
for i, wrapped_output in enumerate(wrapped_layer.outputs):
output = wrapped_output.as_trt()
if wrapped_output.is_graph_output:
output_name = f"output{i}"
if strategy.communication_actions.get(
output_name) is not None:
output_name += "_after_comm"
sharding_spec = strategy.sharding_specs[output_name]
validation_graph_group.global_context.merge_context(
self.global_context.get_local_context_for_output(
output))
validation_graph_group.add_layer_singleton(
output, device_ids, sharding_spec)
validation_graph.add_output(output)
validation_shapes[output.name] = wrapped_output.shape
if not self.timing_cache:
self.timing_cache = network.builder.create_builder_config(
).create_timing_cache(b"")
logger.debug(f"run validation graph for layer {layer.name}")
validation_runner = validation_graph.get_runner(
validation_shapes,
self.values,
timing_cache=self.timing_cache,
opt_level=0,
)
values = validation_runner.run()
refer_input_values = {}
for wrapped_input in wrapped_layer.inputs:
if wrapped_input is None:
continue
if wrapped_input.value is not None:
refer_input_values[wrapped_input.name] = wrapped_input.value
refer_graph, output_mapping = get_per_layer_graph(
layer,
validation_shapes,
refer_input_values,
is_shape_io=False,
)
refer_graph._io_buffer_mapping = self.full_graph._io_buffer_mapping
for proxy_output, output in output_mapping.items():
validation_shapes[proxy_output] = validation_shapes[output]
logger.debug(f"run refer graph for layer {layer.name}")
refer_runner = refer_graph.get_runner(
validation_shapes,
self.values,
timing_cache=self.timing_cache,
opt_level=0,
)
refer_outputs = refer_runner.run()
for name, refer_output in refer_outputs.items():
if name in output_mapping:
refer_output = refer_output.bool()
output = values[name]
            # |output - refer_output| <= atol + rtol * |refer_output|
atol = 1e-02
rtol = 1e-02
if not torch.allclose(
output,
refer_output,
rtol=rtol,
atol=atol,
equal_nan=True,
):
size = output.nelement()
diff = (output - refer_output).abs()
diff_index = (~torch.isnan(diff)) & (diff > (
atol + rtol * refer_output.abs()))
diff_output = diff[diff_index]
diff_size = diff_output.nelement()
logger.warning(
f"output {name} of {layer.name} is not accurate after parallelization. "
f"{diff_size} out of {size} elements ({diff_size / size * 100:.2f}%) are not close. "
f"max: {diff_output.max():.5f}, mean: {diff_output.float().mean():.5f}, std: {diff_output.float().std():.5f}. "
f"mean of reference: {refer_output.float().mean():.5f}, mean of output: {output.float().mean():.5f}."
)
for name in extra_input_values.keys():
del self.values[name]
def add_output(self, tensor: Tensor, device_ids,
strategy: ShardingStrategy):
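        # Reshard the output according to the strategy, gather it back to a
        # full tensor named after the original output, and register it as a
        # graph (shape) output of the prefixed graph.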
trt_output = tensor.as_trt()
comm_action_sequence = strategy.best_resharding_cost[0][0][1]
for commspec in comm_action_sequence:
self.add_comm(self.global_context, tensor.name, device_ids,
commspec)
self.add_layer_singleton(trt_output, device_ids,
strategy.sharding_specs["input0"])
if trt_output.is_shape_tensor:
output = self.prefixed_graph.add_output_shape(trt_output)
else:
output = self.prefixed_graph.add_output(trt_output)
trt_output.dtype = tensor.dtype
output.attrs["strategy"] = strategy.name
def assign_shapes(self, shape_info: ShapeInfo):
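        # In validation mode the reference copies need shape/value entries
        # under their "ref_" names; afterwards run shape assignment over the
        # whole prefixed graph.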
if self.validation_mode:
shapes = {
f"ref_{name}": shape
for name, shape in shape_info.shapes.items()
}
values = {
f"ref_{name}": value
for name, value in shape_info.values.items()
}
self.shapes.update(shapes)
self.values.update(values)
shape_layers = get_shape_layers(self.prefixed_graph.as_trt())
shape_info = ShapeInfo(self.shapes, self.values, shape_layers)
self.prefixed_graph.assign_shapes(shape_info)
def parallelize(
simplifier: Simplifier,
config: ParallelConfig,
):
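    # Top-level entry point: replay the per-node sharding strategies from the
    # ParallelConfig onto the full graph. Inputs, layers (using each pipeline
    # stage's device sub-mesh) and outputs are added through a GraphGroup,
    # producing either a single combined debug graph or one sharded graph per
    # device.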
auto_parallel_config = simplifier.config
debug_mode = auto_parallel_config.debug_mode
dump_path = auto_parallel_config.dump_path
debug_outputs = auto_parallel_config.debug_outputs
simplifier.infer_shapes(config.graph_config.num_micro_batches)
network = simplifier.network
graph = simplifier.graph
phy_mesh = config.graph_config.phy_mesh
# TODO: test device_ids = [[0]]
device_ids = phy_mesh.phy_devices_id
stage_phy_meshes = config.graph_config.stage_phy_meshes
block_to_stage = config.graph_config.graph_mapping.block_to_stage
graph_strategy = config.graph_strategy
desimplify_strategy(
graph,
graph_strategy,
config.graph_config.graph_mapping,
)
graph._plugin_config = simplifier.llm_network.plugin_config
graph_group = GraphGroup.from_graph(graph, config, auto_parallel_config)
if not debug_mode:
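        # The custom all-reduce path needs a workspace tensor; register it as
        # an extra graph input with a replicated (unsharded) strategy.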
init_all_reduce_helper()
tp_size = phy_mesh.size // config.graph_config.num_stages
shape = (CustomAllReduceHelper.POINTERS_PER_RANK * tp_size +
CustomAllReduceHelper.POINTERS_OF_COUNTER, )
workspace = graph.as_trt().add_input(
name="all_reduce_workspace",
dtype=trt.int64,
shape=shape,
)
tensor = graph.register_input(workspace)
tensor.shape = shape
graph_strategy["all_reduce_workspace"] = ShardingStrategy(
sharding_specs={
"output0":
ShardingSpec(
device_mesh=phy_mesh.as_logical_mesh(),
data_type_size=tensor.dtype_str_size,
data_shape=shape,
max_data_shape=shape,
raw_data_shape=shape,
dim_partition_dict={},
)
})
if dump_path is not None:
lock = FileLock(f"{dump_path}/path.lock", thread_local=False)
with lock:
with open(f'{dump_path}/sharded_graph.log', 'w+') as file:
config.print_graph_strategy(file)
for input in graph.inputs:
graph_group.add_input(input, device_ids, graph_strategy[input.name])
for block in simplifier.blocks:
stage_id = block_to_stage[block.block_id]
stage_phy_mesh = stage_phy_meshes[stage_id]
stage_device_ids = stage_phy_mesh.phy_devices_id.reshape(
config.lmesh.mesh_shape)
for i in block.sorted_layer_ids:
layer = graph.get_layer(network.get_layer(i).name)
layer.attrs["block_id"] = block.block_id
graph_group.add_layer(
layer,
stage_device_ids,
graph_strategy[layer.name],
)
for output in graph.outputs:
graph_group.add_output(output, device_ids, graph_strategy[output.name])
if debug_mode:
new_graph = graph_group.prefixed_graph
debug_outputs = debug_outputs or []
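        # debug_outputs is either an explicit list of tensor names, the
        # literal string 'validation' (dump every ref_/parallel tensor pair),
        # or a regex pattern matched against tensor names.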
if isinstance(debug_outputs, str):
if debug_outputs == 'validation':
debug_outputs = []
for tensor in new_graph.tensors:
if tensor.name.startswith('ref_'):
original_name = tensor.name[4:]
original_tensor = new_graph.get_tensor(original_name)
if original_tensor is not None:
if not original_tensor.is_graph_io:
debug_outputs.append(tensor.name)
debug_outputs.append(original_name)
if original_tensor.is_graph_output:
debug_outputs.append(tensor.name)
else:
pattern = debug_outputs
debug_outputs = []
for tensor in new_graph.tensors:
if tensor.as_trt().is_shape_tensor:
continue
if tensor.producer is not None:
layer = tensor.producer
if layer.type == trt.LayerType.SHAPE:
continue
if re.match(pattern, tensor.name):
debug_outputs.append(tensor.name)
for output_name in debug_outputs:
trt_output = new_graph.get_tensor(output_name).as_trt()
if trt_output.is_shape_tensor:
output = new_graph.add_output_shape(trt_output)
else:
output = new_graph.add_output(trt_output)
graph_group.assign_shapes(simplifier.shape_info)
if dump_path is not None:
with lock:
new_graph.to_dot(
f'{dump_path}/sharded_graph.dot',
per_device=True,
per_block=True,
# ignore_shape_io=True,
extra_attrs=['strategy'],
)
return [new_graph]
else:
graphs = []
for device_id in np.nditer(device_ids):
device_id = device_id.item()
graph = graph_group.graphs[device_id]
graphs.append(graph)
return graphs