Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

[TRTLLM-8682][chore] Remove auto_parallel module (#8329)

Signed-off-by: Anish Shanbhag <ashanbhag@nvidia.com>

Parent: e5865de518
Commit: 15de45d782
@@ -93,36 +93,6 @@
        ],
        "trtllm_modules_to_hf_modules": {}
    },
    "auto_parallel_config": {
        "world_size": 1,
        "gpus_per_node": 8,
        "cluster_key": "A100-PCIe-80GB",
        "cluster_info": null,
        "sharding_cost_model": "alpha_beta",
        "comm_cost_model": "alpha_beta",
        "enable_pipeline_parallelism": false,
        "enable_shard_unbalanced_shape": false,
        "enable_shard_dynamic_shape": false,
        "enable_reduce_scatter": true,
        "builder_flags": null,
        "debug_mode": false,
        "infer_shape": true,
        "validation_mode": false,
        "same_buffer_io": {
            "past_key_value_(\\d+)": "present_key_value_\\1"
        },
        "same_spec_io": {},
        "sharded_io_allowlist": [
            "past_key_value_\\d+",
            "present_key_value_\\d*"
        ],
        "fast_reduce": true,
        "fill_weights": false,
        "parallel_config_cache": null,
        "profile_cache": null,
        "dump_path": null,
        "debug_outputs": []
    },
    "weight_sparsity": false,
    "weight_streaming": false,
    "plugin_config": {
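The "same_buffer_io" entry above is a regex rename rule. A small, illustrative check of how the pattern maps a past KV-cache tensor name to its present counterpart (the tensor name is just an example):

import re

pattern, repl = r"past_key_value_(\d+)", r"present_key_value_\1"
# The capture group keeps the layer index intact in the renamed tensor.
assert re.sub(pattern, repl, "past_key_value_7") == "present_key_value_7"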
@@ -132,16 +132,6 @@ trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_fp16_wq \
             --output_dir ./tmp/llama/7B/trt_engines/weight_only/1-gpu/ \
             --gemm_plugin auto

# Build LLaMA 7B using 2-way auto parallelism (deprecated).
python convert_checkpoint.py --model_dir ./tmp/llama/7B/ \
                             --output_dir ./tllm_checkpoint_1gpu_fp16 \
                             --dtype float16

trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_fp16 \
             --output_dir ./tmp/llama/7B/trt_engines/fp16/2-gpu/ \
             --gemm_plugin auto \
             --auto_parallel 2

# Build LLaMA 7B using 2-way tensor parallelism.
python convert_checkpoint.py --model_dir ./tmp/llama/7B/ \
                             --output_dir ./tllm_checkpoint_2gpu_tp2 \
@@ -77,7 +77,6 @@ from ._utils import (default_gpus_per_node, local_mpi_rank, local_mpi_size,
                     mpi_barrier, mpi_comm, mpi_rank, mpi_world_size,
                     set_mpi_comm, str_dtype_to_torch, str_dtype_to_trt,
                     torch_dtype_to_trt)
from .auto_parallel import AutoParallelConfig, auto_parallel
from .builder import BuildConfig, Builder, BuilderConfig, build
from .disaggregated_params import DisaggregatedParams
from .functional import Tensor, constant
@@ -130,8 +129,6 @@ __all__ = [
    'Module',
    'functional',
    'models',
    'auto_parallel',
    'AutoParallelConfig',
    'quantization',
    'tools',
    'LLM',
@@ -391,10 +391,13 @@ class LlmArgs(AutoDeployConfig, BaseLlmArgs, BaseSettings):
        rank to automatically shard the model. This is just to ensure that other objects in the
        runtime that may read parallel_config can do so.
        """

+       # Set tp_size = self.world_size so that _ParallelConfig.world_size will return the
+       # correct value (computed as tp_size * pp_size * cp_size). This does not necessarily
+       # mean that TP will actually be used.
        self._parallel_config = _ParallelConfig(
-           auto_parallel=True, gpus_per_node=self.gpus_per_node
+           tp_size=self.world_size, gpus_per_node=self.gpus_per_node
        )
        self._parallel_config.world_size = self.world_size
        return self

    @model_validator(mode="after")
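The comment in the hunk above relies on the world size being the product of the per-dimension parallel sizes. A minimal sketch of that relationship, assuming a _ParallelConfig-like container with tp_size, pp_size, and cp_size fields (this sketch class and its field set are illustrative, not the real _ParallelConfig):

from dataclasses import dataclass

@dataclass
class ParallelConfigSketch:
    # Hypothetical stand-in for _ParallelConfig; real field names may differ.
    tp_size: int = 1
    pp_size: int = 1
    cp_size: int = 1
    gpus_per_node: int = 8

    @property
    def world_size(self) -> int:
        # world_size is the product of the parallel dimensions, so setting
        # tp_size = world_size with pp_size = cp_size = 1 preserves the value
        # even when tensor parallelism is not actually used.
        return self.tp_size * self.pp_size * self.cp_size

cfg = ParallelConfigSketch(tp_size=4)
assert cfg.world_size == 4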
@@ -74,18 +74,15 @@ class DeviceMeshTopologyImpl(_MappingBaseForTypeCheck):
    # Access rank
    @property
    def tp_rank(self) -> int:
        assert not self.auto_parallel, "Auto parallel is not currently supported in Ray mode."
        return self.tp_group_pg.rank()

    @property
    def pp_rank(self) -> int:
        assert not self.auto_parallel, "Auto parallel is not currently supported in Ray mode."
        return self.pp_group_pg.rank()

    @property
    def cp_rank(self) -> int:
        # TODO: WIP
        assert not self.auto_parallel, "Auto parallel is not currently supported in Ray mode."
        return self.cp_group_pg.rank()

    # Access group ranks
@@ -25,7 +25,6 @@ import tempfile
import trace
import weakref
from contextlib import contextmanager
from dataclasses import asdict
from enum import EnumMeta
from functools import lru_cache, partial, wraps
from pathlib import Path
@@ -799,38 +798,6 @@ def trace_func(func):
    return wrapper


class DictConversion:

    @classmethod
    def from_dict(cls, config: Dict[str, Any]):
        obj = cls()
        fields = obj.__dataclass_fields__
        for key, value in config.items():
            assert hasattr(obj, key), f"cannot find {key} in {obj}"
            field_cls = fields[key].type
            if (isinstance(field_cls, type)
                    and issubclass(field_cls, DictConversion)
                    and isinstance(value, dict)):
                value = field_cls.from_dict(value)
            setattr(obj, key, value)
        return obj

    def to_dict(self):
        return asdict(self)

    @classmethod
    def from_json_file(cls, file):
        with open(file) as f:
            return cls.from_dict(json.load(f))

    def set_defaults(self, **kwargs):
        for key, default in kwargs.items():
            value = getattr(self, key)
            if (value is None
                    or (isinstance(value, (list, dict)) and len(value) == 0)):
                setattr(self, key, default)


class BaseEnumMeta(EnumMeta):

    def __contains__(cls, item):
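For context, a minimal usage sketch of the DictConversion mixin removed above. ExampleConfig is a hypothetical dataclass introduced only for illustration; it assumes DictConversion as defined in the hunk above is importable:

from dataclasses import dataclass

@dataclass
class ExampleConfig(DictConversion):
    # Hypothetical config used only to exercise the mixin's helpers.
    name: str = ""
    num_layers: int = 0

cfg = ExampleConfig.from_dict({"name": "demo", "num_layers": 2})
assert cfg.to_dict() == {"name": "demo", "num_layers": 2}
# set_defaults only overwrites fields that are None or empty containers,
# so the existing "demo" value is kept here.
cfg.set_defaults(name="fallback")
assert cfg.name == "demo"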
@@ -1,9 +0,0 @@
from .auto_parallel import auto_parallel
from .cluster_info import infer_cluster_config
from .config import AutoParallelConfig

__all__ = [
    'auto_parallel',
    'AutoParallelConfig',
    'infer_cluster_config',
]
@@ -1,266 +0,0 @@
import gc
import os
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import tensorrt as trt
import torch
from filelock import FileLock

from tensorrt_llm.functional import DimRange, Tensor
from tensorrt_llm.logger import logger
from tensorrt_llm.network import Network, net_guard

from .config import AutoParallelConfig
from .device_mesh import LogicalDeviceMesh, PhysicalDeviceMesh
from .node_graph import NodeGraph
from .parallelization import ParallelConfig, parallelize
from .pipeline_graph import PipelineGraph
from .simplifier import GraphConfig, Simplifier, StageType
from .utils import current_flags


def to_network(graph: PipelineGraph, network: Network):
    logger.debug("Converting graph to network")
    trt_network = graph.as_trt()
    trt_network.name = network.trt_network.name
    new_network = Network()
    new_network._init(trt_network)
    new_network._dtype = network._dtype
    new_network._plugin_config = network._plugin_config
    new_network._unfilled_weights = graph._unfilled_weights
    new_network._auto_parallel_config = graph._auto_parallel_config
    with net_guard(network):
        for i in range(trt_network.num_inputs):
            input = trt_network.get_input(i)
            tensor = Tensor(is_network_input=False)
            if input.name in network._inputs:
                profiles = network._inputs[input.name].profiles
            elif len(network._inputs) == 0:
                profiles = []
            else:
                shape = input.shape
                num_profiles = len(list(network._inputs.values())[0].profiles)
                profile = DimRange(shape, [None] * len(shape))
                profiles = [profile] * num_profiles
            tensor.profiles = profiles
            tensor.trt_tensor = input
            new_network._inputs[input.name] = tensor
    return new_network


def find_solution(
    node_graph: NodeGraph,
    graph_config: GraphConfig,
    lmesh: LogicalDeviceMesh,
    memory_budget: int,
    flags: list,
    device: int,
    dump_path: str,
) -> ParallelConfig:
    torch.cuda.set_device(device)
    with current_flags(*flags):
        cost_graph = node_graph.get_cost_graph(lmesh)
        num_stages = graph_config.num_stages
        if num_stages == 1:
            stage_types = [None]
        elif num_stages == 2:
            stage_types = [StageType.START, StageType.END]
        else:
            stage_types = [StageType.START, StageType.BLOCK, StageType.END]

        best_config, best_solution = None, None
        for stage_type in stage_types:
            if stage_type is not None:
                node_graph.set_slowest_stage(stage_type, graph_config)
            solution = node_graph.find_solution(
                cost_graph,
                memory_budget,
            )
            cost = solution.total_cost
            if best_config is None or cost < best_config.cost:
                best_config = ParallelConfig()
                best_config.graph_config = graph_config
                best_config.lmesh = lmesh
                best_config.cost = cost
                best_config.graph_strategy = solution.node_best_strategy
                best_config.stage_type = stage_type
                best_solution = solution

        if dump_path is not None:
            lock = FileLock(f"{dump_path}/path.lock", thread_local=False)
            vlz_name = f"{dump_path}/solution."
            if graph_config.num_micro_batches != 1:
                vlz_name += f"mbs{graph_config.num_micro_batches}."
            if graph_config.num_stages != 1:
                vlz_name += f"stages{graph_config.num_stages}."
            vlz_name += lmesh.cluster_key
            with lock:
                node_graph.visualize_solution(
                    best_solution,
                    vlz_name,
                    ignore_shape_io=True,
                )
        return best_config


def infer_builder_flags(network):
    fp16_enabled = False
    bf16_enabled = False
    int8_enabled = False
    fp8_enabled = False

    def check_dtype(tensor):
        nonlocal fp16_enabled
        nonlocal bf16_enabled
        nonlocal int8_enabled
        nonlocal fp8_enabled
        if tensor.dtype == trt.DataType.HALF:
            fp16_enabled = True
        elif tensor.dtype == trt.DataType.BF16:
            bf16_enabled = True
        elif tensor.dtype == trt.DataType.INT8:
            int8_enabled = True
        elif tensor.dtype == trt.DataType.FP8:
            fp8_enabled = True

    trt_network = network.trt_network
    for i in range(trt_network.num_inputs):
        input = trt_network.get_input(i)
        check_dtype(input)
    for i in range(trt_network.num_layers):
        layer = trt_network.get_layer(i)
        for j in range(layer.num_outputs):
            output = layer.get_output(j)
            check_dtype(output)

    builder_flags = 0
    if fp16_enabled:
        builder_flags |= 1 << int(trt.BuilderFlag.FP16)
        builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
    if bf16_enabled:
        builder_flags |= 1 << int(trt.BuilderFlag.BF16)
        builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
    if int8_enabled:
        builder_flags |= 1 << int(trt.BuilderFlag.INT8)
    if fp8_enabled:
        builder_flags |= 1 << int(trt.BuilderFlag.FP8)
        builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
    return builder_flags


def auto_parallel(network: Network, config: AutoParallelConfig):
    logger.warning(
        "auto_parallel is deprecated, "
        "please use explicit parallelism like tp_size/pp_size instead.")
    debug_mode = config.debug_mode
    memory_budget = config.get_cluster_info(
    ).memory_budget_per_device * 1024 * 1024 * 1024
    enable_pipeline_parallelism = config.enable_pipeline_parallelism
    if config.world_size < config.gpus_per_node:
        num_hosts = 1
        num_devices_per_host = config.world_size
    else:
        assert config.world_size % config.gpus_per_node == 0
        num_hosts = config.world_size // config.gpus_per_node
        num_devices_per_host = config.gpus_per_node
    parallel_config_cache = config.parallel_config_cache
    dump_path = config.dump_path if debug_mode else None
    fill_weights = config.fill_weights

    if num_hosts == 1 and num_devices_per_host == 1:
        return [network]

    if dump_path is not None:
        if not os.path.exists(dump_path):
            os.makedirs(dump_path)

    builder_flags = config.builder_flags or infer_builder_flags(network)
    flags = [builder_flags, network.strongly_typed]
    with current_flags(*flags):
        simplifier = Simplifier(network, config)
        network_hash = simplifier.get_network_hash()

        best_config = None
        if parallel_config_cache is not None and Path(
                parallel_config_cache).exists():
            parallel_config = ParallelConfig.from_file(parallel_config_cache)
            if (ParallelConfig.VERSION == parallel_config.version
                    and network_hash == parallel_config.network_hash
                    and config == parallel_config.auto_parallel_config):
                logger.info(
                    f"use cache of parallel config from {parallel_config_cache}"
                )
                best_config = parallel_config

        if best_config is None:
            num_devices = num_hosts * num_devices_per_host
            phy_ids = [[
                i + j * num_devices_per_host
                for i in range(num_devices_per_host)
            ] for j in range(num_hosts)]
            phy_mesh = PhysicalDeviceMesh(phy_ids, config)
            if enable_pipeline_parallelism:
                num_micro_batches_list = simplifier.list_all_num_micro_batches()
            else:
                num_micro_batches_list = [1]

            jobs = []
            for num_micro_batches in num_micro_batches_list:
                simplifier.infer_shapes(num_micro_batches)
                if enable_pipeline_parallelism:
                    pipeline_configs = phy_mesh.list_all_pipeline_configs()
                else:
                    pipeline_configs = [(1, num_devices)]
                for num_stages, num_devices_per_stage in pipeline_configs:
                    # TODO: add fallback path that allows num_micro_batches >= num_stages
                    # if no solution satisfies memory budget
                    if num_micro_batches < num_stages:
                        continue
                    simplified_graph, graph_config = simplifier.simplify_graph(
                        phy_mesh,
                        num_stages,
                        num_devices_per_stage,
                    )
                    if simplified_graph is None:
                        continue
                    node_graph = NodeGraph(simplified_graph)
                    node_graph.assign_cost_weights(graph_config)
                    lmeshes = graph_config.stage_phy_meshes[
                        0].get_logical_meshes()
                    for lmesh in lmeshes:
                        jobs.append(
                            (node_graph, graph_config, lmesh, memory_budget *
                             (num_devices / num_devices_per_stage)))

            try:
                with ThreadPoolExecutor() as executor:
                    best_config = sorted(
                        executor.map(
                            lambda x: find_solution(
                                *x,
                                flags,
                                torch.cuda.current_device(),
                                dump_path,
                            ),
                            jobs,
                        ),
                        key=lambda x: x.cost,
                    )[0]
            finally:
                phy_mesh.close()

            if parallel_config_cache is not None:
                best_config.network_hash = network_hash
                best_config.auto_parallel_config = config
                best_config.save(parallel_config_cache)

        new_graphs = parallelize(simplifier, best_config)

    networks = [to_network(new_graph, network) for new_graph in new_graphs]
    if debug_mode and fill_weights:
        networks[0]._fill_weights()

    gc.collect()
    torch.cuda.empty_cache()

    return networks
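As a small illustration of the host/device layout built inside auto_parallel() above, here is the phy_ids comprehension evaluated for a hypothetical 2-host, 4-GPU-per-host setup (the counts are example values only):

num_hosts = 2
num_devices_per_host = 4
phy_ids = [[i + j * num_devices_per_host for i in range(num_devices_per_host)]
           for j in range(num_hosts)]
# Row j lists the global device ids assigned to host j.
assert phy_ids == [[0, 1, 2, 3], [4, 5, 6, 7]]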
@@ -1,539 +0,0 @@
import copy
import re
from dataclasses import dataclass, field
from typing import Dict, Tuple, Union

import pynvml
import torch

try:
    from cuda.bindings import runtime as cudart
except ImportError:
    from cuda import cudart

from tensorrt_llm._utils import DictConversion
from tensorrt_llm.logger import logger
from tensorrt_llm.profiler import PyNVMLContext, _device_get_memory_info_fn


@dataclass
class MathThroughput(DictConversion):
    int4: int = 0  # Tflops
    int8: int = 0  # Tflops
    fp8: int = 0  # Tflops
    float16: int = 0  # Tflops
    bfloat16: int = 0  # Tflops
    float32: int = 0  # Tflops

    @staticmethod
    def to_tflops(
        ipc_per_sm: "MathThroughput",
        sm_count: int,
        clock_mhz: int,
    ) -> "MathThroughput":
        tflops = MathThroughput()
        for name in ipc_per_sm.__dataclass_fields__:
            setattr(
                tflops, name,
                getattr(ipc_per_sm, name) * sm_count * clock_mhz // int(1e6))
        return tflops


@dataclass
class ClusterInfo(DictConversion):
    inter_node_bw_per_device: int = 25  # GBps
    intra_node_bw_per_device: int = 0  # GBps
    inter_node_latency: int = 10  # us
    intra_node_latency: int = 10  # us
    intra_node_sharp: bool = False
    inter_node_sharp: bool = True

    memory_bw: int = 0  # GBps
    memory_budget_per_device: int = 0  # GB

    math_throughput: MathThroughput = field(default_factory=MathThroughput)

    memory_efficiency: float = 1.0
    math_efficiency: float = 1.0
    communication_efficiency: float = 1.0


_math_throughputs = {
    "A100": MathThroughput(
        int8=624,
        float16=312,
        bfloat16=312,
        float32=156,
    ),
}

_bandwidths = {
    "PCIe-3": 16,
    "PCIe-4": 32,
    "PCIe-5": 64,
}

cluster_infos = {
    # from https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
    "A100-SXM-80GB":
    ClusterInfo(
        intra_node_bw_per_device=300,
        memory_bw=2039,
        memory_budget_per_device=80,
        math_throughput=_math_throughputs["A100"],
    ),
    "A100-SXM-40GB":
    ClusterInfo(
        intra_node_bw_per_device=300,
        memory_bw=1555,
        memory_budget_per_device=40,
        math_throughput=_math_throughputs["A100"],
    ),
    "A100-PCIe-80GB":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=1935,
        memory_budget_per_device=80,
        math_throughput=_math_throughputs["A100"],
    ),
    "A100-PCIe-40GB":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=1555,
        memory_budget_per_device=40,
        math_throughput=_math_throughputs["A100"],
    ),
    # from https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
    "H100-SXM":
    ClusterInfo(
        inter_node_bw_per_device=50,
        intra_node_bw_per_device=450,
        intra_node_sharp=True,
        memory_bw=3350,
        memory_budget_per_device=80,
        math_throughput=MathThroughput(
            int8=1979,
            fp8=1979,
            float16=989,
            bfloat16=989,
            float32=495,
        ),
    ),
    "H100-PCIe":
    ClusterInfo(
        inter_node_bw_per_device=50,
        intra_node_bw_per_device=_bandwidths["PCIe-5"],
        memory_bw=2000,
        memory_budget_per_device=80,
        math_throughput=MathThroughput(
            int8=1513,
            fp8=1513,
            float16=756,
            bfloat16=756,
            float32=378,
        ),
    ),
    "H20":
    ClusterInfo(
        inter_node_bw_per_device=50,
        intra_node_bw_per_device=450,
        memory_bw=4000,
        memory_budget_per_device=96,
        math_throughput=MathThroughput(
            int8=293,
            fp8=293,
            float16=147,
            bfloat16=147,
            float32=74,
        ),
    ),
    # from https://nvdam.widen.net/s/nb5zzzsjdf/hpc-datasheet-sc23-h200-datasheet-3002446
    "H200-SXM":
    ClusterInfo(
        inter_node_bw_per_device=50,
        intra_node_bw_per_device=450,
        memory_bw=4800,
        memory_budget_per_device=141,
        math_throughput=MathThroughput(
            int8=3958,
            fp8=3958,
            float16=1979,
            bfloat16=1979,
            float32=67,
        ),
    ),
    "H200-NVL":
    ClusterInfo(
        inter_node_bw_per_device=50,
        intra_node_bw_per_device=450,
        memory_bw=4800,
        memory_budget_per_device=141,
        math_throughput=MathThroughput(
            int8=3341,
            fp8=3341,
            float16=1671,
            bfloat16=1671,
            float32=60,
        ),
    ),
    # from https://images.nvidia.com/content/Solutions/data-center/a40/nvidia-a40-datasheet.pdf
    "A40":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=696,
        memory_budget_per_device=48,
        math_throughput=MathThroughput(
            int4=600,
            int8=300,
            float16=150,
            bfloat16=150,
            float32=75,
        ),
    ),
    # from https://www.nvidia.com/content/dam/en-zz/Solutions/data-center/products/a30-gpu/pdf/a30-datasheet.pdf
    "A30":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=933,
        memory_budget_per_device=24,
        math_throughput=MathThroughput(
            int4=661,
            int8=330,
            float16=165,
            bfloat16=165,
            float32=82,
        ),
    ),
    # from https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a10/pdf/datasheet-new/nvidia-a10-datasheet.pdf
    "A10":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=600,
        memory_budget_per_device=24,
        math_throughput=MathThroughput(
            int4=500,
            int8=250,
            float16=125,
            bfloat16=125,
            float32=62.5,
        ),
    ),
    "A10G":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=600,
        memory_budget_per_device=24,
        math_throughput=MathThroughput(
            int4=280,
            int8=140,
            float16=70,
            bfloat16=70,
            float32=35,
        ),
    ),
    # from https://resources.nvidia.com/en-us-l40s/l40s-datasheet-28413
    "L40S":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=864,
        memory_budget_per_device=48,
        math_throughput=MathThroughput(
            int4=733,
            int8=733,
            fp8=733,
            float16=362,
            bfloat16=362,
            float32=183,
        ),
    ),
    # from https://images.nvidia.cn/content/Solutions/data-center/vgpu-L40-datasheet.pdf
    "L40":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=864,
        memory_budget_per_device=48,
        math_throughput=MathThroughput(
            int4=724,
            int8=362,
            fp8=362,
            float16=181,
            bfloat16=181,
            float32=90,
        ),
    ),
    "L20":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=864,
        memory_budget_per_device=48,
        math_throughput=MathThroughput(
            int8=238,
            fp8=238,
            float16=119,
            bfloat16=119,
            float32=60,
        ),
    ),
    # from https://nvdam.widen.net/s/rvq98gbwsw/l4-datasheet-2595652
    "L4":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=300,
        memory_budget_per_device=24,
        math_throughput=MathThroughput(
            int8=242,
            fp8=242,
            float16=120,
            bfloat16=120,
            float32=60,
        ),
    ),
    "L2":
    ClusterInfo(
        intra_node_bw_per_device=_bandwidths["PCIe-4"],
        memory_bw=300,
        memory_budget_per_device=24,
        math_throughput=MathThroughput(
            int8=193,
            fp8=193,
            float16=97,
            bfloat16=97,
            float32=48,
        ),
    ),
}


def infer_cluster_key() -> str:

    def match(product, name):
        # Use A100 as example, the regex pattern matches for:
        # - NVIDIA A100 80GB
        # - NVIDIA A100-PCIE
        # - NVIDIA A100
        # And does not match A1000 etc.
        return re.match(f".*{product}([ -]|$).*", name) is not None

    def is_sxm():
        return "SXM" in device_name

    def is_80gb():
        return "80GB" in device_name

    def is_32gb():
        return "32GB" in device_name

    device_name = torch.cuda.get_device_name(torch.cuda.current_device())

    if match("A100", device_name):
        if is_sxm():
            if is_80gb():
                return "A100-SXM-80GB"
            else:
                return "A100-SXM-40GB"
        else:
            if is_80gb():
                return "A100-PCIe-80GB"
            else:
                return "A100-PCIe-40GB"
    elif match("A10G", device_name):
        return "A10G"
    elif match("A10", device_name):
        return "A10"
    elif match("A30", device_name):
        return "A30"
    elif match("A40", device_name):
        return "A40"
    elif match("H100", device_name):
        if is_sxm():
            return "H100-SXM"
        else:
            return "H100-PCIe"
    elif match("H200", device_name):
        if is_sxm():
            return "H200-SXM"
        else:
            return "H200-NVL"
    elif match("L40S", device_name):
        return "L40S"
    elif match("L40", device_name):
        return "L40"
    elif match("L4", device_name):
        return "L4"
    return None


def ipc_per_sm(compute_cap: Tuple[int, int]) -> MathThroughput:
    ipc_table = {
        (9, 0):
        MathThroughput(
            int8=16384,
            fp8=16384,
            float16=8192,
            bfloat16=8192,
            float32=4096,
        ),
        (8, 0):
        MathThroughput(
            int4=8192,
            int8=4096,
            float16=2048,
            bfloat16=2048,
            float32=1024,
        ),
        (8, 6):
        MathThroughput(
            int4=4096,
            int8=2048,
            float16=1024,
            bfloat16=1024,
            float32=512,
        ),
        (8, 9):
        MathThroughput(
            int4=2048,
            int8=1024,
            fp8=1024,
            float16=512,
            bfloat16=512,
            float32=256,
        ),
        (7, 0):
        MathThroughput(
            float16=1024,
            float32=128,
        ),
        (7, 5):
        MathThroughput(
            int4=4096,
            int8=2048,
            float16=1024,
            float32=128,
        ),
    }
    return ipc_table.get(compute_cap, MathThroughput())


def nvlink_version(version_enum: int) -> int:
    nvl_version_table = {
        1: 1,
        2: 2,
        3: 2,  # 2.2
        4: 3,
        5: 3,  # 3.1
        6: 4,
        7: 5,
    }
    return nvl_version_table[version_enum]


def nvlink_bandwidth(nvlink_version: int) -> int:
    nvl_bw_table = {
        1: 80,
        2: 150,
        3: 300,
        4: 450,
        5: 900,
    }
    return nvl_bw_table[nvlink_version]


def infer_cluster_info() -> ClusterInfo:
    device = torch.cuda.current_device()
    index = device.index if isinstance(device, torch.device) else device
    with PyNVMLContext():
        handle = pynvml.nvmlDeviceGetHandleByIndex(index)
        compute_cap = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
        logger.info(f"Compute capability: {compute_cap}")
        err, properties = cudart.cudaGetDeviceProperties(index)
        sm_count = properties.multiProcessorCount
        logger.info(f"SM count: {sm_count}")
        sm_clock = pynvml.nvmlDeviceGetMaxClockInfo(
            handle,
            pynvml.NVML_CLOCK_SM,
        )
        logger.info(f"SM clock: {sm_clock} MHz")
        math_throughput = MathThroughput.to_tflops(
            ipc_per_sm(compute_cap),
            sm_count,
            sm_clock,
        )
        for name in math_throughput.__dataclass_fields__:
            tflops = getattr(math_throughput, name)
            logger.info(f"{name} TFLOPS: {tflops}")

        mem_info = _device_get_memory_info_fn(handle)
        memory_budget = mem_info.total // (1024**3)
        logger.info(f"Total Memory: {memory_budget} GiB")

        mem_clock = pynvml.nvmlDeviceGetMaxClockInfo(
            handle,
            pynvml.NVML_CLOCK_MEM,
        )
        logger.info(f"Memory clock: {mem_clock} MHz")
        mem_bus_width = pynvml.nvmlDeviceGetMemoryBusWidth(handle)
        logger.info(f"Memory bus width: {mem_bus_width}")
        memory_bw = mem_bus_width * mem_clock * 2 // int(8e3)
        logger.info(f"Memory bandwidth: {memory_bw} GB/s")

        try:
            is_nvl_active = bool(pynvml.nvmlDeviceGetNvLinkState(handle, 0))
            logger.info(f"NVLink is active: {is_nvl_active}")
        except pynvml.NVMLError:
            is_nvl_active = False

        intra_node_sharp = False
        if is_nvl_active:
            nvl_version_enum = pynvml.nvmlDeviceGetNvLinkVersion(handle, 0)
            nvl_version = nvlink_version(nvl_version_enum)
            logger.info(f"NVLink version: {nvl_version}")
            nvl_bw = nvlink_bandwidth(nvl_version)
            logger.info(f"NVLink bandwidth (unidirectional): {nvl_bw} GB/s")
            intra_node_bw = nvl_bw
            if nvl_version >= 4:
                intra_node_sharp = True
        else:
            pcie_speed = pynvml.nvmlDeviceGetPcieSpeed(handle)
            logger.info(f"PCIe speed: {pcie_speed} Mbps")
            pcie_link_width = pynvml.nvmlDeviceGetCurrPcieLinkWidth(handle)
            logger.info(f"PCIe link width: {pcie_link_width}")
            pcie_bw = pcie_speed * pcie_link_width // int(8e3)
            logger.info(f"PCIe bandwidth: {pcie_bw} GB/s")
            intra_node_bw = pcie_bw

        cluster_info = ClusterInfo(
            math_throughput=math_throughput,
            memory_bw=memory_bw,
            memory_budget_per_device=memory_budget,
            intra_node_bw_per_device=intra_node_bw,
            intra_node_sharp=intra_node_sharp,
        )
    return cluster_info


def infer_cluster_config() -> Dict[str, Union[str, ClusterInfo]]:
    device_name = torch.cuda.get_device_name(torch.cuda.current_device())
    cluster_key = infer_cluster_key()
    if cluster_key is not None:
        return dict(cluster_key=cluster_key)
    else:
        try:
            cluster_info = infer_cluster_info()
        except pynvml.NVMLError:
            fallback_cluster_key = "L40"
            cluster_info = copy.copy(cluster_infos[fallback_cluster_key])
            memory_budget = torch.cuda.mem_get_info()[1] // (1024**3)
            cluster_info.memory_budget_per_device = memory_budget
            logger.warning(
                f"Failed to infer cluster info for {device_name}, "
                f"treat it as a {fallback_cluster_key} node with {memory_budget} GB memory. "
                "This setting makes no effect if you do not use auto parallel.")
        return dict(
            cluster_key=device_name.replace(" ", "-"),
            cluster_info=cluster_info,
        )


if __name__ == "__main__":
    logger.set_level("info")
    infer_cluster_info()
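A worked instance of the MathThroughput.to_tflops arithmetic above, using round numbers that only loosely resemble an A100-class GPU (the SM count and clock here are illustrative values, not taken from a datasheet):

ipc_fp16 = 2048          # FP16 ops per SM per clock, from the (8, 0) entry above
sm_count = 108           # illustrative SM count
clock_mhz = 1410         # illustrative max SM clock in MHz
# ipc * SMs * MHz is ops/s scaled by 1e6, so // int(1e6) yields TFLOPS.
tflops_fp16 = ipc_fp16 * sm_count * clock_mhz // int(1e6)
assert tflops_fp16 == 311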
@@ -1,61 +0,0 @@
from dataclasses import dataclass, field
from enum import auto
from typing import Dict, List, Optional, Union

from strenum import LowercaseStrEnum

from tensorrt_llm._utils import BaseEnumMeta, DictConversion

from .cluster_info import ClusterInfo, cluster_infos


class CostModel(LowercaseStrEnum, metaclass=BaseEnumMeta):
    ALPHA_BETA = auto()
    PROFILE = auto()
    S_CURVE = auto()
    # Zero cost model is for test purpose.
    # Use zero cost model for communication will make solver prefer sharding
    # Use zero cost model for computation will make solver prefer replication
    ZERO = auto()


@dataclass
class AutoParallelConfig(DictConversion):
    # cluster configuration
    world_size: int = 1
    gpus_per_node: int = 8
    cluster_key: str = None
    cluster_info: Optional[ClusterInfo] = None

    # cost model configuration
    sharding_cost_model: str = CostModel.ALPHA_BETA
    comm_cost_model: str = CostModel.ALPHA_BETA

    # strategy configuration
    enable_pipeline_parallelism: bool = False
    enable_shard_unbalanced_shape: bool = False
    enable_shard_dynamic_shape: bool = False
    enable_reduce_scatter: bool = True

    # parallelization configuration
    builder_flags: Optional[int] = None
    debug_mode: bool = False
    infer_shape: bool = True
    validation_mode: bool = False
    same_buffer_io: Dict[str, str] = field(default_factory=dict)
    same_spec_io: Dict[str, str] = field(default_factory=dict)
    sharded_io_allowlist: List[str] = field(default_factory=list)
    fill_weights: bool = False

    # debug configuration
    parallel_config_cache: Optional[str] = None
    profile_cache: Optional[str] = None
    dump_path: Optional[str] = None
    debug_outputs: Union[List[str], str] = field(default_factory=list)

    def get_cluster_info(self) -> ClusterInfo:
        return self.cluster_info or cluster_infos[self.cluster_key]

    @property
    def enabled(self) -> bool:
        return self.world_size > 1
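A short sketch of how the removed AutoParallelConfig was typically built through its DictConversion base and queried; the field values here are illustrative:

config = AutoParallelConfig.from_dict({
    "world_size": 2,
    "gpus_per_node": 8,
    "cluster_key": "A100-PCIe-80GB",
})
assert config.enabled                                   # world_size > 1
assert config.sharding_cost_model == CostModel.ALPHA_BETA
cluster = config.get_cluster_info()                     # cluster_info is None, so this
                                                        # falls back to cluster_infos["A100-PCIe-80GB"]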
@ -1,612 +0,0 @@
|
||||
import os
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
from filelock import FileLock
|
||||
|
||||
from .config import AutoParallelConfig, CostModel
|
||||
from .tensor_parallel.shape_consistency import ShapeConsistencyManager
|
||||
|
||||
|
||||
class ProfileDB(ABC):
|
||||
"""A database that stores profiling results for multiple device mesh
|
||||
shapes."""
|
||||
|
||||
@abstractmethod
|
||||
def query(self, cluster_key, data_key):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def update(self, cluster_key, data_key, mesh_result):
|
||||
...
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
|
||||
class MemDB(ProfileDB):
|
||||
|
||||
def __init__(self):
|
||||
self.data = {}
|
||||
|
||||
def query(self, cluster_key, data_key):
|
||||
key = (cluster_key, data_key)
|
||||
mesh_result = self.data.get(key, None)
|
||||
if mesh_result is None:
|
||||
return None
|
||||
else:
|
||||
return mesh_result[0]
|
||||
|
||||
def update(self, cluster_key, data_key, mesh_result):
|
||||
key = (cluster_key, data_key)
|
||||
self.data[key] = mesh_result
|
||||
|
||||
|
||||
class Hdf5DB(ProfileDB):
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
lock_name = self.name + ".lock"
|
||||
self.lock = FileLock(lock_name, thread_local=False)
|
||||
|
||||
def query(self, cluster_key, data_key):
|
||||
file_name = f"{self.name}.hdf5"
|
||||
key = str((cluster_key, data_key))
|
||||
self.lock.acquire()
|
||||
mesh_result = None
|
||||
with h5py.File(file_name, 'a') as f:
|
||||
if key in f:
|
||||
self.lock.release()
|
||||
mesh_result = f[key]
|
||||
return mesh_result[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
def update(self, cluster_key, data_key, mesh_result):
|
||||
key = str((cluster_key, data_key))
|
||||
file_name = f"{self.name}.hdf5"
|
||||
with h5py.File(file_name, 'a') as f:
|
||||
f[key] = mesh_result
|
||||
|
||||
def close(self):
|
||||
self.lock.release(force=True)
|
||||
|
||||
|
||||
class LogicalDeviceMesh(object):
|
||||
|
||||
def __init__(self,
|
||||
phy_mesh_shape,
|
||||
mesh_shape,
|
||||
phy_ids,
|
||||
config: AutoParallelConfig,
|
||||
alpha,
|
||||
beta,
|
||||
sharp,
|
||||
prof_database=None,
|
||||
shape_consistency_manager=None,
|
||||
host_ips=None):
|
||||
self.phy_mesh_shape = phy_mesh_shape
|
||||
self.mesh_shape = mesh_shape
|
||||
self.phy_ids = phy_ids
|
||||
self.host_ips = host_ips
|
||||
self.cluster_key = config.cluster_key + '_mesh_shape{}'.format('_'.join(
|
||||
[str(i) for i in mesh_shape]))
|
||||
self.prof_min_max_size = [1, 2**34]
|
||||
self.prof_comm_dtypes = [
|
||||
"int8", "uint8", "int32", "uint32", "int64", "uint64", "float16",
|
||||
"float32", "float64", "bfloat16"
|
||||
]
|
||||
self.devices_group = {
|
||||
(0, ): [self.phy_ids.transpose(), self.mesh_shape[1] - 1],
|
||||
(1, ): [self.phy_ids, self.mesh_shape[1]],
|
||||
(0, 1): [self.phy_ids.reshape([1, self.phy_ids.size]), 0]
|
||||
}
|
||||
self.prof_database = prof_database
|
||||
self.shape_consistency_manager = shape_consistency_manager
|
||||
self.config = config
|
||||
self.cluster_info = config.get_cluster_info()
|
||||
self.hw_alpha = alpha
|
||||
self.hw_beta = beta
|
||||
self.hw_sharp = sharp
|
||||
self.algo_alpha_beta = self._estimate_algo_alpha_beta()
|
||||
self.comm_op_to_nccl_test_func_name = {
|
||||
'all_reduce': 'all_reduce_perf_mpi',
|
||||
'all_gather': 'all_gather_perf_mpi',
|
||||
'all_to_all': 'alltoall_perf_mpi',
|
||||
'reduce_scatter': 'reduce_scatter_perf_mpi',
|
||||
'split': 'split',
|
||||
}
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
return self.phy_ids.size
|
||||
|
||||
def _estimate_algo_alpha_beta(self):
|
||||
ret = {}
|
||||
ar_alpha, ar_beta = {}, {}
|
||||
ag_alpha, ag_beta = {}, {}
|
||||
rs_alpha, rs_beta = {}, {}
|
||||
a2a_alpha, a2a_beta = {}, {}
|
||||
phy_num_hosts, phy_num_devices_per_host = self.phy_mesh_shape
|
||||
if phy_num_hosts == 1 or phy_num_devices_per_host == 1:
|
||||
for dims in [(0, ), (1, ), (0, 1), (1, 0)]:
|
||||
num_devices = 1
|
||||
for dim in dims:
|
||||
num_devices = self.mesh_shape[dim] * num_devices
|
||||
if num_devices != 1:
|
||||
ar_alpha[dims] = self.hw_alpha[0] if self.hw_sharp[
|
||||
0] else self.hw_alpha[0] * num_devices / 2 / (
|
||||
num_devices - 1)
|
||||
ar_beta[dims] = self.hw_beta[0]
|
||||
ag_alpha[dims] = self.hw_alpha[0] * num_devices / (
|
||||
num_devices - 1)
|
||||
ag_beta[dims] = self.hw_beta[0]
|
||||
rs_alpha[dims] = self.hw_alpha[0] * num_devices / (
|
||||
num_devices - 1)
|
||||
rs_beta[dims] = self.hw_beta[0]
|
||||
a2a_alpha[dims] = self.hw_alpha[0] * num_devices / (
|
||||
num_devices - 1)
|
||||
a2a_beta[dims] = self.hw_beta[0]
|
||||
# phy and logical have the same mesh shape if num_hosts > 1 and num_devices_per_host > 1
|
||||
else:
|
||||
for dims in [(0, ), (1, ), (0, 1), (1, 0)]:
|
||||
num_devices = 1
|
||||
for dim in dims:
|
||||
num_devices = self.mesh_shape[dim] * num_devices
|
||||
if num_devices != 1:
|
||||
if len(dims) == 1:
|
||||
dim = dims[0]
|
||||
ar_alpha[dims] = self.hw_alpha[dim] if self.hw_sharp[
|
||||
dim] else self.hw_alpha[dim] * num_devices / 2 / (
|
||||
num_devices - 1)
|
||||
ar_beta[dims] = self.hw_beta[dim]
|
||||
ag_alpha[dims] = self.hw_alpha[dim] * num_devices / (
|
||||
num_devices - 1)
|
||||
ag_beta[dims] = self.hw_beta[dim]
|
||||
rs_alpha[dims] = self.hw_alpha[dim] * num_devices / (
|
||||
num_devices - 1)
|
||||
rs_beta[dims] = self.hw_beta[dim]
|
||||
a2a_alpha[dims] = self.hw_alpha[dim] * num_devices / (
|
||||
num_devices - 1)
|
||||
a2a_beta[dims] = self.hw_beta[dim]
|
||||
elif len(dims) == 2: # two level communication
|
||||
num_hosts, num_devices_per_host = phy_num_hosts, phy_num_devices_per_host
|
||||
inter_node_col_alpha = self.hw_alpha[
|
||||
0] * num_devices_per_host
|
||||
inter_node_ar_alpha = inter_node_col_alpha if self.hw_sharp[
|
||||
0] else inter_node_col_alpha * num_hosts / 2 / (
|
||||
num_hosts - 1)
|
||||
intra_node_ar_alpha = self.hw_alpha[1]
|
||||
intra_node_ar_alpha = intra_node_ar_alpha if self.hw_sharp[
|
||||
1] else intra_node_ar_alpha * num_devices_per_host / 2 / (
|
||||
num_devices_per_host - 1)
|
||||
ar_alpha[dims] = min(inter_node_ar_alpha,
|
||||
intra_node_ar_alpha)
|
||||
ar_beta[dims] = max(self.hw_beta)
|
||||
ag_alpha[dims] = min(
|
||||
inter_node_col_alpha * num_hosts / (num_hosts - 1),
|
||||
self.hw_alpha[1] * num_devices_per_host /
|
||||
(num_devices_per_host - 1))
|
||||
ag_beta[dims] = max(self.hw_beta)
|
||||
rs_alpha[dims] = ag_alpha[dims]
|
||||
rs_beta[dims] = ag_beta[dims]
|
||||
a2a_alpha[dims] = min(
|
||||
num_hosts * self.hw_alpha[0] / (num_hosts - 1),
|
||||
self.hw_alpha[1] * num_hosts)
|
||||
a2a_beta[dims] = max(self.hw_beta)
|
||||
else:
|
||||
pass
|
||||
ret['all_to_all'] = [a2a_alpha, a2a_beta]
|
||||
ret['all_reduce'] = [ar_alpha, ar_beta]
|
||||
ret['all_gather'] = [ag_alpha, ag_beta]
|
||||
ret['reduce_scatter'] = [rs_alpha, rs_beta]
|
||||
ret['p2p_cross_device'] = [
|
||||
self.cluster_info.intra_node_bw_per_device,
|
||||
self.cluster_info.intra_node_latency
|
||||
]
|
||||
ret['p2p_cross_host'] = [
|
||||
self.cluster_info.inter_node_bw_per_device,
|
||||
self.cluster_info.inter_node_latency
|
||||
]
|
||||
return ret
|
||||
|
||||
#[ToDo] stub functions here
|
||||
def _profile_split(self, min_max_comm_size):
|
||||
comm_size, elapsed_time = [], []
|
||||
size = min_max_comm_size[0]
|
||||
while size <= min_max_comm_size[1]:
|
||||
time = size * 2 / self.cluster_info.memory_bw
|
||||
comm_size.append(size)
|
||||
elapsed_time.append(time)
|
||||
size = size * 2
|
||||
return np.array([comm_size, elapsed_time])
|
||||
|
||||
def _prase_nccl_test_results(self, f_nccl_test_out_log):
|
||||
'''[ToDo] There is some dtye that may not been supported by nccl test, using default dtype (float)'''
|
||||
start_parse = False
|
||||
comm_size, elapsed_time = [], []
|
||||
try:
|
||||
with open(f_nccl_test_out_log, 'r') as lines:
|
||||
for line in lines:
|
||||
if start_parse:
|
||||
prof_data = re.split(r"[ ]+", line.strip())
|
||||
if len(prof_data) != 13:
|
||||
continue
|
||||
comm_size.append(float(prof_data[0]))
|
||||
elapsed_time.append(float(prof_data[5]))
|
||||
if 'GB/s' in line and 'us' in line:
|
||||
start_parse = True
|
||||
except Exception:
|
||||
print(f'failed to parse {f_nccl_test_out_log}')
|
||||
return comm_size, elapsed_time
|
||||
|
||||
def _profile_with_nccl_test(self, min_max_comm_size, dtype, device_group,
|
||||
func_name, step, workload_key):
|
||||
|
||||
if func_name == 'split':
|
||||
if 2 == step:
|
||||
return self._profile_split(min_max_comm_size)
|
||||
else:
|
||||
return None
|
||||
workspace_dir = self.config['profiling_workspace'] + f'/{workload_key}'
|
||||
os.makedirs(workspace_dir, exist_ok=True)
|
||||
outfile, errfile = workspace_dir + '/profile.out', workspace_dir + '/profile.err'
|
||||
if 1 == step:
|
||||
num_nodes = len(self.host_ips)
|
||||
num_gpus = self.mesh_shape[0] * self.mesh_shape[1]
|
||||
ntasks_per_node = num_gpus // num_nodes
|
||||
nccl_test_command = '"export NCCL_TESTS_SPLIT_MASK={} && export NCCL_COLLNET_ENABLE=1 && {} -b {} -e {} -g 1 -d {} -f {}"'.format(
|
||||
device_group[1], func_name, min_max_comm_size[0],
|
||||
min_max_comm_size[1], dtype, 2)
|
||||
sbatch_command = '#!/bin/bash\n'
|
||||
sbatch_command += '#SBATCH -p {}\n'.format(self.config['partition'])
|
||||
sbatch_command += '#SBATCH -A {}\n'.format(self.config['account'])
|
||||
sbatch_command += '#SBATCH -J {}\n'.format(self.config['jobname'])
|
||||
sbatch_command += '#SBATCH -N {}\n'.format(num_nodes)
|
||||
sbatch_command += '#SBATCH -t {}\n'.format(self.config['time'])
|
||||
sbatch_command += '#SBATCH --ntasks-per-node={}\n'.format(
|
||||
ntasks_per_node)
|
||||
sbatch_command += '#SBATCH --exclusive\n'
|
||||
sbatch_command += '#SBATCH --mem=0\n'
|
||||
sbatch_command += '#SBATCH --network=sharp\n'
|
||||
sbatch_command += '#SBATCH --mail-type=FAIL\n'
|
||||
srun_command = 'srun --nodes={} --mpi=pmix --ntasks-per-node={} --network=sharp -o {} -e {} --container-image={} bash -c '.format(
|
||||
num_nodes, ntasks_per_node, outfile, errfile,
|
||||
self.config['container'])
|
||||
command = sbatch_command + srun_command + nccl_test_command
|
||||
with open(workspace_dir + '/workload.sub', 'w') as f:
|
||||
f.write(command)
|
||||
with open('./preprofiling_step1.sh', 'a') as f:
|
||||
f.write(f'sbatch {workspace_dir}/workload.sub\n')
|
||||
return None
|
||||
|
||||
else:
|
||||
comm_size, elapsed_time = self._prase_nccl_test_results(outfile)
|
||||
if len(comm_size) < 2:
|
||||
assert 0, 'the profiling for {} was failed at step1, please try again'.format(
|
||||
workload_key)
|
||||
else:
|
||||
print(workload_key, comm_size, elapsed_time)
|
||||
return np.array([comm_size, elapsed_time])
|
||||
|
||||
def _profile_single_comm_perf(self, device_group, comm_op, step, data_key):
|
||||
results = {}
|
||||
func_name = self.comm_op_to_nccl_test_func_name[comm_op]
|
||||
for dtype in self.prof_comm_dtypes:
|
||||
size_time = self._profile_with_nccl_test(
|
||||
self.prof_min_max_size, dtype, device_group, func_name, step,
|
||||
data_key + f'_dtype{dtype}')
|
||||
results[dtype] = size_time
|
||||
return results
|
||||
|
||||
def profile_all_comms_perf(self, step):
|
||||
if self.mesh_shape == (1, 1):
|
||||
return None
|
||||
mesh_results = self.prof_database.query(self.cluster_key,
|
||||
self.mesh_shape)
|
||||
if mesh_results:
|
||||
return mesh_results
|
||||
|
||||
mesh_results = {}
|
||||
data_key = self.cluster_key + f'_mesh_shape{self.mesh_shape[0]}x{self.mesh_shape[1]}'
|
||||
for comm_op in [
|
||||
'all_reduce', 'all_to_all', 'all_gather', 'reduce_scatter',
|
||||
'split'
|
||||
]:
|
||||
comm_perf = {}
|
||||
for dim, device_group in self.devices_group.items():
|
||||
# don't need to profile for mesh dim == 1
|
||||
if len(dim) == 1 and self.mesh_shape[dim[0]] == 1:
|
||||
continue
|
||||
|
||||
comm_perf[dim] = self._profile_single_comm_perf(
|
||||
device_group, comm_op, step, data_key +
|
||||
'_comm_op{}_dim{}'.format(comm_op, ''.join(map(str, dim))))
|
||||
mesh_results[comm_op] = comm_perf
|
||||
if 2 == step:
|
||||
self.prof_database.update(self.cluster_key, self.mesh_shape,
|
||||
mesh_results)
|
||||
|
||||
return mesh_results
|
||||
|
||||
def _model_comm_cost_from_s_curve(self, size_time_array, realsize):
|
||||
assert size_time_array[0][0] <= realsize <= size_time_array[0][-1],\
|
||||
'the comm_size: {} is not in the profile range: [{}{}]'\
|
||||
.format(realsize, size_time_array[0][0], size_time_array[0][-1])
|
||||
return np.interp(realsize, size_time_array[0], size_time_array[1])
|
||||
|
||||
def _model_comm_cost_from_alpha_beta(self, comm_op, dim_key, size_in_bytes):
|
||||
elapsed_time = 0.0
|
||||
if 'split' == comm_op:
|
||||
elapsed_time = size_in_bytes * 2 / (
|
||||
self.cluster_info.memory_bw *
|
||||
self.cluster_info.memory_efficiency) * 1e-3
|
||||
else:
|
||||
dict_alpha, dict_beta = self.algo_alpha_beta[comm_op]
|
||||
alpha, beta = dict_alpha[dim_key], dict_beta[dim_key]
|
||||
elapsed_time = (size_in_bytes /
|
||||
(alpha * self.cluster_info.communication_efficiency)
|
||||
* 1e-3) + beta
|
||||
return elapsed_time
|
||||
|
||||
def _input_size_to_comm_size(self, comm_op, dims, input_size):
|
||||
ret = input_size
|
||||
if 'all_gather' == comm_op:
|
||||
for dim in dims:
|
||||
ret = ret * self.mesh_shape[dim]
|
||||
return ret
|
||||
|
||||
def estimate_comm_cost(self, comm_op, dim, input_size, dtype):
|
||||
|
||||
size = self._input_size_to_comm_size(comm_op, dim, input_size)
|
||||
if self.config.comm_cost_model == CostModel.S_CURVE:
|
||||
mesh_perf = self.prof_database.query(self.cluster_key,
|
||||
self.mesh_shape)
|
||||
assert mesh_perf is not None, 'the mesh is not profiled, mesh_shape = {}'.format(
|
||||
self.mesh_shape)
|
||||
comm_op_perf = mesh_perf.get(comm_op, None)
|
||||
assert comm_op_perf is not None, '{} is not profiled'.format(
|
||||
comm_op)
|
||||
elapsed_time = self._model_comm_cost_from_s_curve(
|
||||
comm_op_perf[tuple(dim)][dtype], size)
|
||||
return elapsed_time
|
||||
elif self.config.comm_cost_model == CostModel.ALPHA_BETA:
|
||||
elapsed_time = self._model_comm_cost_from_alpha_beta(
|
||||
comm_op, tuple(dim), size)
|
||||
elif self.config.comm_cost_model == CostModel.PROFILE:
|
||||
assert False, 'Unsupported profile based communication cost model now'
|
||||
elif self.config.comm_cost_model == CostModel.ZERO:
|
||||
elapsed_time = 0.0
|
||||
|
||||
return elapsed_time # us
|
||||
|
||||
|
||||
class PhysicalDeviceMesh(object):
|
||||
|
||||
def __init__(self,
|
||||
phy_devices_id,
|
||||
config: AutoParallelConfig,
|
||||
prof_database=None,
|
||||
shape_consistency_manager=None,
|
||||
host_ips=None):
|
||||
self.phy_devices_id = np.array(phy_devices_id)
|
||||
self.num_hosts, self.num_devices_per_host = self.phy_devices_id.shape
|
||||
self.host_ips = host_ips
|
||||
if host_ips is None:
|
||||
self.host_ips = [''] * self.num_hosts
|
||||
self.config = config
|
||||
self.cluster_info = config.get_cluster_info()
|
||||
self.prof_database: ProfileDB = prof_database
|
||||
self.shape_consistency_manager = shape_consistency_manager
|
||||
if self.config.comm_cost_model not in CostModel:
|
||||
raise ValueError(
|
||||
f'unsupported communication cost model: {self.config.comm_cost_model}'
|
||||
)
|
||||
if self.config.sharding_cost_model not in CostModel:
|
||||
raise ValueError(
|
||||
f'unsupported sharding cost model: {self.config.sharding_cost_model}'
|
||||
)
|
||||
if self.config.comm_cost_model == CostModel.S_CURVE or self.config.sharding_cost_model == CostModel.PROFILE:
|
||||
if self.prof_database is None:
|
||||
profile_cache = config.profile_cache
|
||||
if profile_cache is None:
|
||||
self.prof_database = MemDB()
|
||||
else:
|
||||
self.prof_database = Hdf5DB(profile_cache)
|
||||
elif self.config.comm_cost_model == CostModel.ALPHA_BETA:
|
||||
assert self.cluster_info.intra_node_bw_per_device > 0, 'intra_node_bw_per_device is needed for alpha_beta method'
|
||||
assert self.cluster_info.inter_node_bw_per_device > 0, 'inter_node_bw_per_device is needed for alpha_beta method'
|
||||
if self.config.sharding_cost_model == CostModel.ALPHA_BETA:
|
||||
assert self.cluster_info.memory_bw > 0, 'memory_bw is needed for alpha_beta method'
|
||||
|
||||
if not shape_consistency_manager:
|
||||
self.shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
return self.phy_devices_id.size
|
||||
|
||||
def close(self):
|
||||
if self.prof_database is not None:
|
||||
self.prof_database.close()
|
||||
|
||||
def split_pipeline_meshes(
|
||||
self, num_stages,
|
||||
num_devices_per_stage) -> List["PhysicalDeviceMesh"]:
|
||||
sub_meshes = []
|
||||
if num_devices_per_stage <= self.num_devices_per_host:
|
||||
assert self.num_devices_per_host % num_devices_per_stage == 0, \
|
||||
"num_devices_per_host ({}) % num_devices_per_stage ({}) != 0"\
|
||||
.format(self.num_devices_per_host, num_devices_per_stage)
|
||||
num_clusters_per_host = self.num_devices_per_host // num_devices_per_stage
|
||||
num_clusters = self.num_hosts * num_clusters_per_host
|
||||
assert num_stages % num_clusters == 0, \
|
||||
"num_stages({}) % num_clusters({}) !=0".format(num_stages, num_clusters)
|
||||
for mesh_id in range(num_stages):
|
||||
cluster_id = mesh_id % num_clusters
|
||||
cluster_col = cluster_id % num_clusters_per_host
|
||||
cluster_row = cluster_id // num_clusters_per_host
|
||||
sub_devices_id = [
|
||||
self.phy_devices_id[cluster_row][cluster_col *
|
||||
num_devices_per_stage:(
|
||||
(cluster_col + 1) *
|
||||
num_devices_per_stage)]
|
||||
]
|
||||
sub_meshes.append(
|
||||
PhysicalDeviceMesh(sub_devices_id, self.config,
|
||||
self.prof_database,
|
||||
self.shape_consistency_manager,
|
||||
[self.host_ips[cluster_row]]))
|
||||
else:
|
||||
assert num_devices_per_stage % self.num_devices_per_host == 0, \
|
||||
"num_devices_per_stage ({}) % num_devices_per_host ({}) != 0"\
|
||||
.format(num_devices_per_stage, self.num_devices_per_host)
|
||||
num_host_per_cluster = num_devices_per_stage // self.num_devices_per_host
|
||||
assert self.num_hosts % num_host_per_cluster == 0, \
|
||||
"num_hosts ({}) % num_host_per_cluster({}) != 0".format(self.num_hosts, num_host_per_cluster)
|
||||
num_clusters = self.num_hosts // num_host_per_cluster
|
||||
for mesh_id in range(num_stages):
|
||||
cluster_id = mesh_id % num_clusters
|
||||
cluster_row = cluster_id * num_host_per_cluster
|
||||
sub_devices_id = self.phy_devices_id[cluster_row:(
|
||||
cluster_row + num_host_per_cluster)]
|
||||
host_ips = self.host_ips[cluster_row:(cluster_row +
|
||||
num_host_per_cluster)]
|
||||
sub_meshes.append(
|
||||
PhysicalDeviceMesh(sub_devices_id, self.config,
|
||||
self.prof_database,
|
||||
self.shape_consistency_manager,
|
||||
host_ips))
|
||||
return sub_meshes
|
||||
|
||||
def _profile_logical_meshes(self, logical_meshes, step):
|
||||
for lmesh in logical_meshes:
|
||||
lmesh.profile_all_comms_perf(step)
|
||||
|
||||
def as_logical_mesh(self) -> LogicalDeviceMesh:
|
||||
alpha = [
|
||||
self.cluster_info.inter_node_bw_per_device,
|
||||
self.cluster_info.intra_node_bw_per_device
|
||||
]
|
||||
beta = [
|
||||
self.cluster_info.inter_node_latency,
|
||||
self.cluster_info.intra_node_latency
|
||||
]
|
||||
sharp = [
|
||||
self.cluster_info.inter_node_sharp,
|
||||
self.cluster_info.intra_node_sharp
|
||||
]
|
||||
return LogicalDeviceMesh(
|
||||
self.phy_devices_id.shape,
|
||||
self.phy_devices_id.shape,
|
||||
self.phy_devices_id,
|
||||
self.config,
|
||||
alpha,
|
||||
beta,
|
||||
sharp,
|
||||
self.prof_database,
|
||||
self.shape_consistency_manager,
|
||||
self.host_ips,
|
||||
)
|
||||
|
||||
def get_logical_meshes(self):
|
||||
logical_meshes = []
|
||||
# (1, 2) -> (1, 2)
|
||||
# (1, 4) -> (2, 2)
|
||||
# (1, 8) -> (2, 4)
|
||||
# (1, 16) -> (2, 8), (4, 4)
|
||||
# (1, 32) -> (2, 16), (4, 8)
|
||||
# (1, 48) -> (2, 24), (3, 16), (4, 12), (6, 8)
|
||||
# (1, 64) -> (2, 32), (4, 16), (8, 8)
|
||||
# we will traverse logical shape's axis in sharding spec, thus (2, 8) contains (8, 2)
|
||||
# we will merge logical shapes' axis, thus (2, 8) contains (1, 16) and (16, 1)
|
||||
if self.num_hosts == 1:
|
||||
alpha = [self.cluster_info.intra_node_bw_per_device]
|
||||
beta = [self.cluster_info.intra_node_latency]
|
||||
sharp = [self.cluster_info.intra_node_sharp]
|
||||
for i in range(2, self.num_devices_per_host):
|
||||
if self.num_devices_per_host % i == 0 and i * i <= self.num_devices_per_host:
|
||||
lmesh_shape = (i, self.num_devices_per_host // i)
|
||||
lmesh_phy_ids = self.phy_devices_id.reshape(lmesh_shape)
|
||||
logical_meshes.append(
|
||||
LogicalDeviceMesh(self.phy_devices_id.shape,
|
||||
lmesh_shape, lmesh_phy_ids,
|
||||
self.config, alpha, beta, sharp,
|
||||
self.prof_database,
|
||||
self.shape_consistency_manager,
|
||||
self.host_ips))
|
||||
# (8, 1) -> (2, 4)
|
||||
# (16, 1) -> (2, 8), (4, 4)
|
||||
elif self.num_devices_per_host == 1:
|
||||
alpha = [self.cluster_info.inter_node_bw_per_device]
|
||||
beta = [self.cluster_info.inter_node_latency]
|
||||
sharp = [self.cluster_info.inter_node_sharp]
|
||||
for i in range(2, self.num_hosts):
|
||||
if self.num_hosts % i == 0 and i * i <= self.num_hosts:
|
||||
lmesh_shape = (i, self.num_hosts // i)
|
||||
lmesh_phy_ids = self.phy_devices_id.reshape(lmesh_shape)
|
||||
logical_meshes.append(
|
||||
LogicalDeviceMesh(self.phy_devices_id.shape,
|
||||
lmesh_phy_ids, self.config, alpha,
|
||||
beta, sharp, self.prof_database,
|
||||
self.shape_consistency_manager,
|
||||
self.host_ips))
|
||||
# (2, 1) -> (2, 1)
|
||||
# (2, 8) -> (2, 8)
|
||||
# (1, 2) -> (1, 2)
|
||||
# (1, 3) -> (1, 3)
|
||||
# (1, 5) -> (1, 5)
|
||||
if 0 == len(logical_meshes):
|
||||
logical_meshes.append(self.as_logical_mesh())
|
||||
return logical_meshes
|
||||
|
||||
'''
|
||||
we assume we can evenly split the pipeline and deviceMesh
|
||||
'''
|
||||
|
||||
def _list_all_sub_meshes(self):
|
||||
sub_meshes = []
|
||||
for num_devices_per_stage in range(1, self.num_devices_per_host + 1):
|
||||
if self.num_devices_per_host % num_devices_per_stage == 0:
|
||||
num_stages = self.num_hosts * self.num_devices_per_host // num_devices_per_stage
|
||||
sub_meshes.append(
|
||||
self.split_pipeline_meshes(num_stages,
|
||||
num_devices_per_stage)[0])
|
||||
for num_hosts_per_stage in range(2, self.num_hosts + 1):
|
||||
if self.num_hosts % num_hosts_per_stage == 0:
|
||||
num_stages = self.num_hosts // num_hosts_per_stage
|
||||
sub_meshes.append(
|
||||
self.split_pipeline_meshes(
|
||||
num_stages,
|
||||
num_hosts_per_stage * self.num_devices_per_host)[0])
|
||||
return sub_meshes
|
||||
|
||||
def list_all_pipeline_configs(self):
|
||||
configs = []
|
||||
for num_devices_per_stage in range(1, self.num_devices_per_host + 1):
|
||||
if self.num_devices_per_host % num_devices_per_stage == 0:
|
||||
num_stages = self.num_hosts * self.num_devices_per_host // num_devices_per_stage
|
||||
configs.append((num_stages, num_devices_per_stage))
|
||||
for num_hosts_per_stage in range(2, self.num_hosts + 1):
|
||||
if self.num_hosts % num_hosts_per_stage == 0:
|
||||
num_stages = self.num_hosts // num_hosts_per_stage
|
||||
configs.append(
|
||||
(num_stages,
|
||||
num_hosts_per_stage * self.num_devices_per_host))
|
||||
return configs
|
||||
|
||||
def profile_s_curve(self, step):
sub_phy_device_meshes = self._list_all_sub_meshes()
for phy_mesh in sub_phy_device_meshes:
lmeshes = phy_mesh.get_logical_meshes()
self._profile_logical_meshes(lmeshes, step)
if 2 == step:
self.save_profile_database()

def profile_alpha_beta(self):
alpha = [250, 25]
beta = [100, 100]
return alpha, beta
@ -1,347 +0,0 @@
from typing import List

import pandas as pd
import tensorrt as trt

from .pipeline_graph import PipelineGraph
from .runtime_profiling import RuntimeProfiler
from .simplifier import GraphConfig, StageType
from .solver import CostGraph, Solver
from .tensor_parallel.activation_node import Activation
from .tensor_parallel.assertion_node import Assertion
from .tensor_parallel.cast_node import Cast
from .tensor_parallel.concatenation_node import Concatenation
from .tensor_parallel.constant_node import Constant
from .tensor_parallel.elementwise_node import ElementWise
from .tensor_parallel.fill_node import Fill
from .tensor_parallel.gather_node import Gather
from .tensor_parallel.identity_node import Identity
from .tensor_parallel.input_node import InputNode
from .tensor_parallel.matmul_node import MatrixMultiply
from .tensor_parallel.node import Node
from .tensor_parallel.normalization_node import Normalization
from .tensor_parallel.output_node import OuputNode
from .tensor_parallel.p2p_node import P2PNode, P2PType
from .tensor_parallel.plugin_node import PluginNode
from .tensor_parallel.plugin_nodes.gemm_node import GemmPlugin
from .tensor_parallel.plugin_nodes.gpt_attention_node import GPTAttentionPlugin
from .tensor_parallel.plugin_nodes.identity_node import IdentityPlugin
from .tensor_parallel.plugin_nodes.look_up_node import LookupPlugin
from .tensor_parallel.plugin_nodes.normalization_node import (LayernormPlugin,
RMSnormPlugin)
from .tensor_parallel.reduce_node import Reduce
from .tensor_parallel.select_node import Select
from .tensor_parallel.shape_node import Shape
from .tensor_parallel.shuffle_node import Shuffle
from .tensor_parallel.slice_node import Slice
from .tensor_parallel.softmax_node import SoftMax
from .tensor_parallel.unary_node import Unary

LAYER_TYPE_2_NODE_TYPE = {
trt.LayerType.ACTIVATION: Activation,
trt.LayerType.ASSERTION: Assertion,
trt.LayerType.CAST: Cast,
trt.LayerType.CONCATENATION: Concatenation,
trt.LayerType.CONSTANT: Constant,
trt.LayerType.ELEMENTWISE: ElementWise,
trt.LayerType.FILL: Fill,
trt.LayerType.GATHER: Gather,
trt.LayerType.IDENTITY: Identity,
trt.LayerType.MATRIX_MULTIPLY: MatrixMultiply,
trt.LayerType.NORMALIZATION: Normalization,
trt.LayerType.PLUGIN_V2: PluginNode,
trt.LayerType.REDUCE: Reduce,
trt.LayerType.SELECT: Select,
trt.LayerType.SHAPE: Shape,
trt.LayerType.SHUFFLE: Shuffle,
trt.LayerType.SLICE: Slice,
trt.LayerType.SOFTMAX: SoftMax,
trt.LayerType.UNARY: Unary,
}
# TODO: BertAttention/All Quant plugins
PLUGIN_LAYER_TYPE_2_NODE_TYPE = {
'GPTAttention': GPTAttentionPlugin,
'Gemm': GemmPlugin,
'Layernorm': LayernormPlugin,
'Rmsnorm': RMSnormPlugin,
'Lookup': LookupPlugin,
'Identity': IdentityPlugin,
}


class NodeGraph:
|
||||
|
||||
def __init__(self, graph: PipelineGraph):
|
||||
self._nodes = {}
|
||||
|
||||
# construct nodes
|
||||
for input in graph.inputs:
|
||||
self._nodes[input.name] = InputNode(input)
|
||||
for layer in graph.layers:
|
||||
layer.to_base_class()
|
||||
if "p2p_type" in layer.attrs:
|
||||
self._nodes[layer.name] = P2PNode(layer)
|
||||
elif layer.type == trt.LayerType.PLUGIN_V2:
|
||||
layer.to_subclass()
|
||||
plugin_type = layer.as_trt().plugin.plugin_type
|
||||
layer.to_base_class()
|
||||
if plugin_type in PLUGIN_LAYER_TYPE_2_NODE_TYPE:
|
||||
node = PLUGIN_LAYER_TYPE_2_NODE_TYPE[plugin_type](layer)
|
||||
else:
|
||||
node = LAYER_TYPE_2_NODE_TYPE[layer.type](layer)
|
||||
self._nodes[layer.name] = node
|
||||
else:
|
||||
node = LAYER_TYPE_2_NODE_TYPE[layer.type](layer)
|
||||
self._nodes[layer.name] = node
|
||||
for output in graph.outputs:
|
||||
self._nodes[output.name] = OuputNode(output)
|
||||
for node in self.nodes:
|
||||
node.post_init(self)
|
||||
node.node_runtime_profiler = RuntimeProfiler()
|
||||
|
||||
def get_node(self, name):
|
||||
return self._nodes[name]
|
||||
|
||||
@property
|
||||
def nodes(self) -> List[Node]:
|
||||
return [*self._nodes.values()]
|
||||
|
||||
def assign_cost_weights(self, graph_config: GraphConfig):
|
||||
layer_mapping = graph_config.graph_mapping.layer_mapping
|
||||
for layer_name in layer_mapping.values():
|
||||
node = self.get_node(layer_name)
|
||||
node.sharding_weight += 1
|
||||
node.resharding_weight += 1
|
||||
same_spec_layer_mapping = graph_config.graph_mapping.same_spec_layer_mapping
|
||||
for same_spec_layer_name, layer_name in same_spec_layer_mapping.items():
|
||||
node = self.get_node(layer_name)
|
||||
same_spec_node = self.get_node(same_spec_layer_name)
|
||||
same_spec_node.sharding_weight = node.sharding_weight
|
||||
same_spec_node.resharding_weight = node.resharding_weight
|
||||
|
||||
def set_slowest_stage(self, stage_type: StageType,
|
||||
graph_config: GraphConfig):
|
||||
num_micro_batches = graph_config.num_micro_batches
|
||||
block_per_stage = graph_config.num_blocks // graph_config.num_stages
|
||||
block_pipeline_weight = block_per_stage * (num_micro_batches - 1)
|
||||
for node in self.nodes:
|
||||
node.pipeline_weight = 0
|
||||
node.cost_level = -1
|
||||
if node.stage_type == StageType.START:
|
||||
if stage_type == StageType.START:
|
||||
node.pipeline_weight = num_micro_batches - 1
|
||||
node.cost_level = 1
|
||||
else:
|
||||
node.cost_level = 0
|
||||
if stage_type == StageType.START and node.in_start_block:
|
||||
node.pipeline_weight = block_pipeline_weight
|
||||
if node.stage_type == StageType.END:
|
||||
if stage_type == StageType.END:
|
||||
node.pipeline_weight = num_micro_batches - 1
|
||||
node.cost_level = 1
|
||||
else:
|
||||
node.cost_level = 0
|
||||
if stage_type == StageType.END and node.in_end_block:
|
||||
node.pipeline_weight = block_pipeline_weight
|
||||
if isinstance(node, P2PNode):
|
||||
if (graph_config.has_cross_host
|
||||
and node.p2p_type == P2PType.CROSS_HOST) or (
|
||||
not graph_config.has_cross_host
|
||||
and node.p2p_type == P2PType.CROSS_DEVICE):
|
||||
if stage_type == StageType.BLOCK:
|
||||
node.pipeline_weight += num_micro_batches - 1
|
||||
node.cost_level = 1
|
||||
else:
|
||||
node.cost_level = 0
|
||||
elif (graph_config.has_cross_device
|
||||
and node.p2p_type == P2PType.CROSS_DEVICE) or (
|
||||
not graph_config.has_cross_device
|
||||
and node.p2p_type == P2PType.CROSS_HOST):
|
||||
node.pipeline_weight += num_micro_batches - 1
|
||||
if stage_type == StageType.BLOCK and node.in_slowest_block:
|
||||
node.pipeline_weight = block_pipeline_weight
|
||||
|
||||
def get_cost_graph(self, lmesh):
|
||||
leaf_strategies = []
|
||||
for node in self.nodes:
|
||||
if node.is_replicated:
|
||||
node.set_strategy(None, lmesh)
|
||||
else:
|
||||
node.collect_strategies(lmesh)
|
||||
for node in self.nodes:
|
||||
strategies_vector = node.update_resharding_cost()
|
||||
if len(strategies_vector) != 0:
|
||||
leaf_strategies.append(strategies_vector)
|
||||
cost_graph = CostGraph(leaf_strategies)
|
||||
return cost_graph
|
||||
|
||||
def find_solution(self, cost_graph, memory_budget):
|
||||
solver = Solver(cost_graph, memory_budget=memory_budget)
|
||||
solution = solver.find_solution()[1]
|
||||
|
||||
graph_strategy = solution.node_best_strategy
|
||||
for node_name, strategy in graph_strategy.items():
|
||||
node = self._nodes[node_name]
|
||||
for idx, pre_node in enumerate(node.predecessor_nodes):
|
||||
if pre_node is None:
|
||||
continue
|
||||
if pre_node.node_name not in strategy.best_resharding_cost:
|
||||
continue
|
||||
strategy.best_resharding_cost[
|
||||
idx] = strategy.best_resharding_cost[pre_node.node_name]
|
||||
strategy.node_names[idx] = pre_node.node_name
|
||||
for key in list(strategy.best_resharding_cost.keys()):
|
||||
if isinstance(key, str):
|
||||
del strategy.best_resharding_cost[key]
|
||||
|
||||
return solution
|
||||
|
||||
def visualize(self, name='pp_graph'):
|
||||
with open(name + '.dot', 'w') as f:
|
||||
f.write("digraph {\n")
|
||||
'''
|
||||
f.write(" // Value Nodes\n")
|
||||
for name, tensor in self._tensors.items():
|
||||
f.write(" \"{}\" [fillcolor = \"green\", label = \"{}\", shape = \"box\", style = \"filled\"];\n".format(name, tensor.shape))
|
||||
'''
|
||||
f.write(" // Operation Nodes\n")
|
||||
for name, node in self._nodes.items():
|
||||
fillcolor = 'white'
|
||||
if 'MATRIX_MULTIPLY' in name:
|
||||
fillcolor = 'green'
|
||||
label = name
|
||||
if len(node.outputs) > 0:
|
||||
label = name + '\\n' + str(node.outputs[0].shape)
|
||||
f.write(
|
||||
" \"{}\" [fillcolor = \"{}\", label = \"{}\", shape = \"box\", style = \"filled\"];\n"
|
||||
.format(name, fillcolor, label))
|
||||
f.write(" // Edges\n")
|
||||
for name, node in self._nodes.items():
|
||||
for successor_node in node.successor_nodes:
|
||||
if successor_node:
|
||||
f.write(" \"{}\" ->\"{}\";\n".format(
|
||||
name, successor_node.node_name))
|
||||
f.write(" }\n")
|
||||
|
||||
def visualize_solution(self,
|
||||
solution,
|
||||
fname='pp_graph_solution',
|
||||
ignore_shape_io=True):
|
||||
with open(fname + '.dot', 'w') as f:
|
||||
names, costs, block_ids = [], [], []
|
||||
f.write("digraph {\n")
|
||||
f.write(" // Operation Nodes\n")
|
||||
for name, node in self._nodes.items():
|
||||
if ignore_shape_io and node.layer is not None and node.layer.is_shape_io:
|
||||
continue
|
||||
cost = 0.0
|
||||
fillcolor = 'white'
|
||||
if 'MATRIX_MULTIPLY' in name or 'PLUGIN_V2_Gemm' in name:
|
||||
fillcolor = 'orange'
|
||||
elif '_same_spec' in name:
|
||||
fillcolor = 'gray'
|
||||
elif 'p2p_block' in name:
|
||||
fillcolor = 'blue'
|
||||
elif 'PLUGIN' in name:
|
||||
fillcolor = 'yellow'
|
||||
|
||||
shape = 'box'
|
||||
if 'output_node' == node.node_type or 'input_node' == node.node_type:
|
||||
shape = 'ellipse'
|
||||
fillcolor = 'green'
|
||||
|
||||
label = name + f'_block{node.building_block_id}_weight{node.sharding_weight}'
|
||||
if len(node.inputs) > 0:
|
||||
for idx, input in enumerate(node.inputs):
|
||||
if not input:
|
||||
continue
|
||||
label = label + f'\\ninput{idx}_' + str(
|
||||
input.shape) + f'_{input.dtype_str_size[0]}_'
|
||||
if node.node_name in solution.node_best_strategy:
|
||||
best_strategy = solution.node_best_strategy[
|
||||
node.node_name]
|
||||
shard_seq = str(
|
||||
best_strategy.sharding_specs[f'input{idx}'].
|
||||
sharding_sequence)
|
||||
label = label + shard_seq
|
||||
if idx not in best_strategy.best_resharding_cost:
|
||||
continue
|
||||
rcosts = best_strategy.best_resharding_cost[idx][0]
|
||||
comm_action_sequence, resharding_cost = rcosts[
|
||||
1], rcosts[2]
|
||||
if len(comm_action_sequence) > 0:
|
||||
label = label + '|'
|
||||
for commspec in comm_action_sequence:
|
||||
comm = [
|
||||
commspec.comm_pattern, commspec.gather_dim,
|
||||
commspec.shard_dim,
|
||||
commspec.logical_process_axis
|
||||
]
|
||||
label = label + '->' + str(comm)
|
||||
if resharding_cost > 0:
|
||||
label = label + '_rcost{:.2}'.format(
|
||||
resharding_cost)
|
||||
cost = cost + resharding_cost
|
||||
if len(node.outputs) > 0:
|
||||
best_strategy = None
|
||||
for idx, output in enumerate(node.outputs):
|
||||
label = label + f'\\noutput{idx}_' + str(
|
||||
output.shape) + f'_{output.dtype_str_size[0]}'
|
||||
if node.node_name in solution.node_best_strategy:
|
||||
best_strategy = solution.node_best_strategy[
|
||||
node.node_name]
|
||||
shard_seq = str(
|
||||
best_strategy.sharding_specs[f'output{idx}'].
|
||||
sharding_sequence)
|
||||
comm = None
|
||||
if f'output{idx}' in best_strategy.communication_actions:
|
||||
commspec = best_strategy.communication_actions[
|
||||
f'output{idx}']
|
||||
comm = [
|
||||
commspec.comm_pattern, commspec.gather_dim,
|
||||
commspec.shard_dim,
|
||||
commspec.logical_process_axis
|
||||
]
|
||||
label = label + '_' + shard_seq
|
||||
if comm:
|
||||
label = label + f' | {comm}'
|
||||
if best_strategy:
|
||||
cost = cost + best_strategy.sharding_cost + best_strategy.communication_cost
|
||||
label = label + '| scost{:.2}'.format(
|
||||
best_strategy.sharding_cost)
|
||||
if best_strategy.communication_cost > 0:
|
||||
label = label + ' | ccost{:.2}'.format(
|
||||
best_strategy.communication_cost)
|
||||
names.append(name)
|
||||
costs.append(cost)
|
||||
block_ids.append([
|
||||
node.building_block_id, node.cost_level,
|
||||
node.sharding_weight + node.pipeline_weight,
|
||||
node.same_spec_id
|
||||
])
|
||||
f.write(
|
||||
" \"{}\" [fillcolor = \"{}\", label = \"{}\", shape = \"{}\", style = \"filled\"];\n"
|
||||
.format(name, fillcolor, label, shape))
|
||||
f.write(" // Edges\n")
|
||||
for name, node in self._nodes.items():
|
||||
if ignore_shape_io and node.layer is not None and node.layer.is_shape_io:
|
||||
continue
|
||||
for successor_node in node.successor_nodes:
|
||||
if successor_node:
|
||||
if ignore_shape_io and successor_node.layer is not None and successor_node.layer.is_shape_io:
|
||||
continue
|
||||
f.write(" \"{}\" ->\"{}\";\n".format(
|
||||
name, successor_node.node_name))
|
||||
f.write(" }\n")
|
||||
df = pd.DataFrame.from_dict({
|
||||
'node':
|
||||
names,
|
||||
'cost':
|
||||
costs,
|
||||
'block_id': [block[0] for block in block_ids],
|
||||
'cost_level': [block[1] for block in block_ids],
|
||||
'sharding_weight': [block[2] for block in block_ids],
|
||||
'same_spec_id': [block[3] for block in block_ids]
|
||||
})
|
||||
df['weight_cost'] = df['sharding_weight'] * df['cost']
|
||||
df.to_csv(fname + '.csv')
|
||||
File diff suppressed because it is too large
File diff suppressed because it is too large
@ -1,150 +0,0 @@
import numpy as np
import tensorrt as trt
import torch

from tensorrt_llm.logger import logger
from tensorrt_llm.network import get_plugin_info

from .shape_info import get_per_layer_graph
from .utils import get_cache_key, get_trt_network, get_updated_plugin


class NvtxProfiler(object):
|
||||
|
||||
def __init__(self, nvtx_name, enable=True):
|
||||
self.nvtx_name = nvtx_name
|
||||
self.enable = enable
|
||||
|
||||
def __enter__(self):
|
||||
if self.enable:
|
||||
torch.cuda.nvtx.range_push(self.nvtx_name)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.enable:
|
||||
torch.cuda.nvtx.range_pop()
|
||||
|
||||
|
||||
class LayerProfiler(trt.IProfiler):
|
||||
|
||||
def __init__(self):
|
||||
trt.IProfiler.__init__(self)
|
||||
self.layer_count = 0
|
||||
self.time = 0
|
||||
|
||||
def report_layer_time(self, layer_name, ms):
|
||||
logger.debug(f'{layer_name=}, {self.layer_count=}, time = {ms} ms')
|
||||
self.time += ms
|
||||
self.layer_count += 1
|
||||
|
||||
|
||||
class RuntimeProfiler(object):
|
||||
|
||||
def __init__(self):
|
||||
self.timing_cache = None
|
||||
|
||||
def _profile(self, layer, layer_attrs, shapes, values, io_buffer_mapping):
|
||||
is_plugin = layer.type == trt.LayerType.PLUGIN_V2
|
||||
if is_plugin and len(layer_attrs) > 0:
|
||||
plugin_info = get_plugin_info(
|
||||
get_trt_network(layer),
|
||||
layer.name,
|
||||
)
|
||||
new_plugin, _ = get_updated_plugin(plugin_info, layer_attrs)
|
||||
layer_attrs = {"plugin": new_plugin}
|
||||
graph, output_mapping = get_per_layer_graph(layer, shapes, values,
|
||||
layer_attrs)
|
||||
graph._io_buffer_mapping = io_buffer_mapping
|
||||
network = graph.as_trt()
|
||||
if network.num_outputs > 0 and np.all([
|
||||
network.get_output(i).is_shape_tensor
|
||||
for i in range(network.num_outputs)
|
||||
]):
|
||||
return 0.0
|
||||
for proxy_output, output in output_mapping.items():
|
||||
shapes[proxy_output] = shapes[output]
|
||||
if not self.timing_cache:
|
||||
self.timing_cache = network.builder.create_builder_config(
|
||||
).create_timing_cache(b"")
|
||||
runner = graph.get_runner(
|
||||
shapes,
|
||||
values,
|
||||
timing_cache=self.timing_cache,
|
||||
)
|
||||
context = runner.session.context
|
||||
context.profiler = LayerProfiler()
|
||||
runner.run()
|
||||
profiler_time_first_run = context.profiler.time
|
||||
runner.run()
|
||||
return (context.profiler.time - profiler_time_first_run) * 1000.0
|
||||
|
||||
def runtime_profile(self, layer, layer_attrs, input_values, strategy,
|
||||
device_mesh):
|
||||
logger.debug(f"start to profile layer {layer.name}")
|
||||
shapes = {}
|
||||
values = {}
|
||||
dtypes = {}
|
||||
trt_layer = layer.as_trt()
|
||||
|
||||
sharding_sequences = ()
|
||||
for i in range(layer.num_inputs):
|
||||
input = trt_layer.get_input(i)
|
||||
if input is not None:
|
||||
shapes[input.name] = strategy.sharding_specs[
|
||||
f'input{i}'].get_sharded_shape_per_device()
|
||||
dtypes[input.name] = input.dtype
|
||||
sharding_sequences += (str(
|
||||
strategy.sharding_specs[f"input{i}"].sharding_sequence), )
|
||||
if i in input_values:
|
||||
values[input.name] = input_values[i]
|
||||
else:
|
||||
value = layer.get_input(i).value
|
||||
if value is not None:
|
||||
values[input.name] = value
|
||||
else:
|
||||
sharding_sequences += (None, )
|
||||
|
||||
for i in range(layer.num_outputs):
|
||||
output = trt_layer.get_output(i)
|
||||
if f'output{i}' in strategy.communication_actions:
|
||||
shapes[output.name] = strategy.communication_actions[
|
||||
f'output{i}'].sharding_spec.get_sharded_shape_per_device()
|
||||
else:
|
||||
shapes[output.name] = strategy.sharding_specs[
|
||||
f'output{i}'].get_sharded_shape_per_device()
|
||||
dtypes[output.name] = output.dtype
|
||||
sharding_sequences += (str(
|
||||
strategy.sharding_specs[f"output{i}"].sharding_sequence), )
|
||||
data_key = get_cache_key(
|
||||
trt_layer,
|
||||
shapes,
|
||||
values,
|
||||
dtypes=dtypes,
|
||||
updated_attrs=layer_attrs,
|
||||
)
|
||||
data_key += (sharding_sequences, )
|
||||
elapsed_time = device_mesh.prof_database.query(
|
||||
device_mesh.cluster_key,
|
||||
data_key,
|
||||
)
|
||||
if elapsed_time:
|
||||
logger.debug(
|
||||
f'runtime profiling cache hit {data_key}: {elapsed_time} us')
|
||||
return elapsed_time
|
||||
with NvtxProfiler(f'{layer.name}_{data_key}', enable=True):
|
||||
elapsed_time = self._profile(
|
||||
layer.as_trt(),
|
||||
layer_attrs,
|
||||
shapes,
|
||||
values,
|
||||
layer.graph._io_buffer_mapping,
|
||||
)
|
||||
logger.debug(
|
||||
f'runtime profiling cache miss {data_key}: {elapsed_time} us')
|
||||
|
||||
device_mesh.prof_database.update(
|
||||
device_mesh.cluster_key,
|
||||
data_key,
|
||||
(elapsed_time, strategy.alpha_beta_cost),
|
||||
)
|
||||
|
||||
return elapsed_time
|
||||
@ -1,362 +0,0 @@
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Set

import numpy as np
import tensorrt as trt
import torch

from tensorrt_llm._common import _is_building
from tensorrt_llm._utils import (trt_dtype_to_np, trt_dtype_to_str,
trt_dtype_to_torch)
from tensorrt_llm.logger import logger

from .pipeline_graph import PipelineGraph
from .utils import (get_builder_flags, get_cache_key, get_sorted_layer_ids,
set_trt_network, to_base_class_layer, to_subclass_layer,
to_trt_weights)


class ShapeType(Enum):
|
||||
MIN = 0
|
||||
OPT = 1
|
||||
MAX = 2
|
||||
|
||||
|
||||
_trt_to_type_dict = {
|
||||
trt.int64: int,
|
||||
trt.bool: bool,
|
||||
}
|
||||
|
||||
|
||||
def get_shape_layers(trt_network):
|
||||
shape_layers = set()
|
||||
for i in range(trt_network.num_layers):
|
||||
layer = trt_network.get_layer(i)
|
||||
if (layer.num_inputs > 0 and np.all([
|
||||
layer.get_input(j).is_shape_tensor
|
||||
for j in range(layer.num_inputs)
|
||||
if layer.get_input(j) is not None
|
||||
])) or (layer.num_outputs > 0 and np.all([
|
||||
layer.get_output(j).is_shape_tensor
|
||||
for j in range(layer.num_outputs)
|
||||
])):
|
||||
shape_layers.add(layer.name)
|
||||
return shape_layers
|
||||
|
||||
|
||||
def get_layers_in_shape_network(trt_network, shape_layers, sorted_layer_ids):
|
||||
layers = set()
|
||||
shape_tensors = set()
|
||||
for layer_id in reversed(sorted_layer_ids):
|
||||
layer = trt_network.get_layer(layer_id)
|
||||
in_shape_network = False
|
||||
if layer.name in shape_layers:
|
||||
in_shape_network = True
|
||||
else:
|
||||
for j in range(layer.num_outputs):
|
||||
output = layer.get_output(j)
|
||||
if output.name in shape_tensors:
|
||||
in_shape_network = True
|
||||
break
|
||||
if in_shape_network:
|
||||
layers.add(layer.name)
|
||||
for j in range(layer.num_inputs):
|
||||
input = layer.get_input(j)
|
||||
if input is not None:
|
||||
shape_tensors.add(input.name)
|
||||
return layers
|
||||
|
||||
|
||||
def get_shape_network(trt_network,
|
||||
shapes,
|
||||
values,
|
||||
sorted_layer_ids,
|
||||
profile=None,
|
||||
shape_type: ShapeType = ShapeType.OPT):
|
||||
shape_layers = get_shape_layers(trt_network)
|
||||
layers_in_shape_network = get_layers_in_shape_network(
|
||||
trt_network, shape_layers, sorted_layer_ids)
|
||||
shape_graph = PipelineGraph.create_graph()
|
||||
shape_network = shape_graph.as_trt()
|
||||
shape_builder = shape_network.builder
|
||||
shape_profile = shape_builder.create_optimization_profile()
|
||||
for i in range(trt_network.num_inputs):
|
||||
input = trt_network.get_input(i)
|
||||
shapes[input.name] = input.shape
|
||||
new_input = shape_graph.add_input(input)
|
||||
if profile is not None:
|
||||
if -1 in input.shape:
|
||||
shape = profile.get_shape(input.name)
|
||||
shape = shape[shape_type.value]
|
||||
shapes[input.name] = shape
|
||||
new_input.raw_shape = shape
|
||||
if input.is_shape_tensor:
|
||||
shape_values = profile.get_shape_input(input.name)
|
||||
value = shape_values[shape_type.value]
|
||||
values[input.name] = value
|
||||
shape_profile.set_shape_input(input.name, value, value, value)
|
||||
output_mapping = {}
|
||||
for layer_id in sorted_layer_ids:
|
||||
layer = trt_network.get_layer(layer_id)
|
||||
if layer.name in shape_layers:
|
||||
new_layer = shape_graph.add_layer(layer)
|
||||
for i in range(layer.num_outputs):
|
||||
output = layer.get_output(i)
|
||||
if output.dtype == trt.DataType.BOOL:
|
||||
proxy_layer = shape_network.add_cast(
|
||||
new_layer.as_trt().get_output(i),
|
||||
trt.DataType.INT32,
|
||||
)
|
||||
proxy_output = proxy_layer.get_output(0)
|
||||
shape_graph.register_layer(proxy_layer)
|
||||
shape_graph.add_output_shape(proxy_output)
|
||||
output_mapping[proxy_output.name] = (output.name,
|
||||
output.dtype)
|
||||
else:
|
||||
shape_graph.add_output_shape(output)
|
||||
elif layer.name in layers_in_shape_network:
|
||||
if layer.type == trt.LayerType.CONSTANT:
|
||||
shape_graph.add_input(layer.get_output(0))
|
||||
else:
|
||||
shape_graph.add_layer(layer)
|
||||
return shape_network, shape_profile, shape_layers, output_mapping
|
||||
|
||||
|
||||
def get_per_layer_graph(
|
||||
layer,
|
||||
shapes,
|
||||
values,
|
||||
updated_attrs=None,
|
||||
is_shape_io: bool = None,
|
||||
):
|
||||
graph = PipelineGraph.create_graph()
|
||||
network = graph.as_trt()
|
||||
is_shape_layer = layer.num_inputs != 0
|
||||
for i in range(layer.num_inputs):
|
||||
input = layer.get_input(i)
|
||||
if input is not None:
|
||||
shape = shapes[input.name]
|
||||
if (values.get(input.name) is not None
|
||||
and not isinstance(values[input.name], torch.Tensor)):
|
||||
value = values[input.name]
|
||||
weights = np.asarray(value, dtype=trt_dtype_to_np(input.dtype))
|
||||
weights = to_trt_weights(weights)
|
||||
input_layer = network.add_constant(shape, weights)
|
||||
new_input = input_layer.get_output(0)
|
||||
new_input.name = input.name
|
||||
graph.register_layer(input_layer)
|
||||
elif graph.get_input(input.name) is None:
|
||||
new_input = graph.add_input(input)
|
||||
new_input.raw_shape = shapes[input.name]
|
||||
is_shape_layer = False
|
||||
new_layer = graph.add_layer(
|
||||
layer,
|
||||
updated_attrs=updated_attrs,
|
||||
)
|
||||
output_mapping = {}
|
||||
if layer.type == trt.LayerType.SHAPE:
|
||||
is_shape_layer = True
|
||||
if layer.num_inputs == 0:
|
||||
is_shape_layer = False
|
||||
if is_shape_io is not None:
|
||||
is_shape_layer = is_shape_io
|
||||
for i in range(layer.num_outputs):
|
||||
output = layer.get_output(i)
|
||||
value = values.get(output.name)
|
||||
if value is not None and isinstance(value, torch.Tensor):
|
||||
is_output_shape = False
|
||||
elif is_shape_layer:
|
||||
is_output_shape = True
|
||||
else:
|
||||
is_output_shape = False
|
||||
if is_output_shape:
|
||||
if output.dtype == trt.DataType.BOOL:
|
||||
proxy_layer = network.add_cast(
|
||||
new_layer.as_trt().get_output(i),
|
||||
trt.DataType.INT32,
|
||||
)
|
||||
proxy_output = proxy_layer.get_output(0)
|
||||
graph.register_layer(proxy_layer)
|
||||
output_mapping[proxy_output.name] = (output.name, output.dtype)
|
||||
output = proxy_output
|
||||
graph.add_output_shape(output)
|
||||
else:
|
||||
graph.add_output(output)
|
||||
return graph, output_mapping
|
||||
|
||||
|
||||
@_is_building
|
||||
def infer_shapes(network, shapes, values, profile=None):
|
||||
if network.num_outputs == 0:
|
||||
return
|
||||
builder = network.builder
|
||||
config = builder.create_builder_config()
|
||||
config.builder_optimization_level = 0
|
||||
config.flags = get_builder_flags()
|
||||
profile = profile or builder.create_optimization_profile()
|
||||
config.add_optimization_profile(profile)
|
||||
plan = builder.build_serialized_network(network, config)
|
||||
if plan is None:
|
||||
raise RuntimeError(
|
||||
'Engine building failed when inferring shapes, please check the error log.'
|
||||
)
|
||||
runtime = trt.Runtime(logger.trt_logger)
|
||||
engine = runtime.deserialize_cuda_engine(plan)
|
||||
context = engine.create_execution_context()
|
||||
for i in range(network.num_inputs):
|
||||
input = network.get_input(i)
|
||||
if input.is_shape_tensor:
|
||||
value = values[input.name]
|
||||
context.set_shape_input(engine[input.name], value)
|
||||
for i in range(network.num_outputs):
|
||||
output = network.get_output(i)
|
||||
shape = context.get_tensor_shape(output.name)
|
||||
shapes[output.name] = shape
|
||||
if output.is_shape_tensor:
|
||||
if shape == [0]:
|
||||
values[output.name] = []
|
||||
else:
|
||||
if shape == []:
|
||||
shape = [1]
|
||||
value = torch.empty(
|
||||
list(shape),
|
||||
dtype=trt_dtype_to_torch(output.dtype),
|
||||
device="cpu",
|
||||
)
|
||||
values[output.name] = value
|
||||
context.set_tensor_address(output.name, value.data_ptr())
|
||||
context.infer_shapes()
|
||||
assert context.all_binding_shapes_specified
|
||||
for i in range(network.num_outputs):
|
||||
output = network.get_output(i)
|
||||
if isinstance(values.get(output.name), torch.Tensor):
|
||||
values[output.name] = values[output.name].tolist()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShapeInfo:
|
||||
shapes: Dict[str, trt.Dims]
|
||||
values: Dict[str, List[int]]
|
||||
shape_layers: Set[str]
|
||||
max_shapes: Dict[str, trt.Dims] = None
|
||||
|
||||
|
||||
def set_constant_value(layer, values):
|
||||
to_subclass_layer(layer)
|
||||
output_name = layer.get_output(0).name
|
||||
weights = layer.weights
|
||||
if isinstance(weights, trt.Weights):
|
||||
weights = weights.numpy()
|
||||
values[output_name] = list(weights)
|
||||
to_base_class_layer(layer)
|
||||
|
||||
|
||||
def infer_per_layer_shapes(
|
||||
layer: trt.ILayer,
|
||||
shapes,
|
||||
values,
|
||||
cache=None,
|
||||
is_shape_io=False,
|
||||
):
|
||||
if layer.type == trt.LayerType.CONSTANT:
|
||||
to_subclass_layer(layer)
|
||||
output_name = layer.get_output(0).name
|
||||
shape = layer.shape
|
||||
shapes[output_name] = shape
|
||||
if is_shape_io:
|
||||
set_constant_value(layer, values)
|
||||
to_base_class_layer(layer)
|
||||
return
|
||||
elif layer.type == trt.LayerType.SHAPE:
|
||||
input_name = layer.get_input(0).name
|
||||
output_name = layer.get_output(0).name
|
||||
shape = [*shapes[input_name]]
|
||||
shapes[output_name] = trt.Dims([len(shape)])
|
||||
values[output_name] = shape
|
||||
return
|
||||
if cache is not None:
|
||||
cache_key = get_cache_key(layer, shapes, values)
|
||||
if cache_key in cache:
|
||||
output_shapes, output_values = cache[cache_key]
|
||||
for i in range(layer.num_outputs):
|
||||
output = layer.get_output(i)
|
||||
shapes[output.name] = output_shapes[i]
|
||||
if output_values[i] is not None:
|
||||
values[output.name] = output_values[i]
|
||||
return
|
||||
graph, output_mapping = get_per_layer_graph(layer, shapes, values)
|
||||
dtypes = [
|
||||
trt_dtype_to_str(layer.get_input(i).dtype)
|
||||
for i in range(layer.num_inputs)
|
||||
]
|
||||
layer_info = (f"type={cache_key[0]}, "
|
||||
f"attrs={dict(cache_key[1])}, "
|
||||
f"dtypes={dtypes}, "
|
||||
f"shapes={list(cache_key[2])}, "
|
||||
f"values={list(cache_key[3])}")
|
||||
logger.debug(f"infer shapes for layer {layer.name} ({layer_info})")
|
||||
try:
|
||||
infer_shapes(graph.as_trt(), shapes, values)
|
||||
except RuntimeError as e:
|
||||
raise RuntimeError(
|
||||
f"infer shapes failed for layer {layer.name} ({layer_info})") from e
|
||||
for proxy_output, (output, dtype) in output_mapping.items():
|
||||
shapes[output] = shapes[proxy_output]
|
||||
del shapes[proxy_output]
|
||||
if proxy_output in values:
|
||||
values[output] = [
|
||||
*map(_trt_to_type_dict[dtype], values[proxy_output])
|
||||
]
|
||||
del values[proxy_output]
|
||||
if cache is not None:
|
||||
logger.debug(
|
||||
f"shape inference cache miss, layer: {layer.name}, cache key: {cache_key}"
|
||||
)
|
||||
output_shapes = []
|
||||
output_values = []
|
||||
for i in range(layer.num_outputs):
|
||||
output = layer.get_output(i)
|
||||
output_shapes.append(shapes[output.name])
|
||||
output_values.append(values.get(output.name))
|
||||
cache[cache_key] = (output_shapes, output_values)
|
||||
|
||||
|
||||
def get_shape_info(trt_network, profile, shape_type: ShapeType = ShapeType.OPT):
|
||||
shapes = {}
|
||||
values = {}
|
||||
sorted_layer_ids = get_sorted_layer_ids(trt_network)
|
||||
infer_shape_layers = False
|
||||
|
||||
shape_network, shape_profile, shape_layers, output_mapping = get_shape_network(
|
||||
trt_network,
|
||||
shapes,
|
||||
values,
|
||||
sorted_layer_ids,
|
||||
profile=profile,
|
||||
shape_type=shape_type)
|
||||
try:
|
||||
infer_shapes(shape_network, shapes, values, shape_profile)
|
||||
for proxy_output, (output, dtype) in output_mapping.items():
|
||||
shapes[output] = shapes[proxy_output]
|
||||
values[output] = [
|
||||
*map(_trt_to_type_dict[dtype], values[proxy_output])
|
||||
]
|
||||
del shapes[proxy_output]
|
||||
del values[proxy_output]
|
||||
except RuntimeError:
|
||||
infer_shape_layers = True
|
||||
|
||||
cache = {}
|
||||
for layer_id in sorted_layer_ids:
|
||||
layer = trt_network.get_layer(layer_id)
|
||||
is_shape_io = layer.name in shape_layers
|
||||
if is_shape_io and not infer_shape_layers:
|
||||
continue
|
||||
set_trt_network(layer, trt_network)
|
||||
infer_per_layer_shapes(layer,
|
||||
shapes,
|
||||
values,
|
||||
cache,
|
||||
is_shape_io=is_shape_io)
|
||||
return ShapeInfo(shapes, values, shape_layers)
|
||||
@ -1,837 +0,0 @@
import math
import re
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Tuple

import numpy as np

from tensorrt_llm.network import Network

from .config import AutoParallelConfig
from .device_mesh import PhysicalDeviceMesh
from .pipeline_graph import PipelineGraph
from .shape_info import ShapeInfo, ShapeType, get_shape_info
from .tensor_parallel.p2p_node import P2PType
from .utils import get_cache_key, get_sorted_layer_ids, silent_trt_logger


class StageType(Enum):
|
||||
START = 0
|
||||
BLOCK = 1
|
||||
END = 2
|
||||
|
||||
|
||||
class BuildingBlock:
|
||||
|
||||
def __init__(self, graph, layer_range) -> None:
|
||||
self.graph = graph
|
||||
self.layer_range = layer_range
|
||||
self.network = graph.as_trt()
|
||||
self.owned_inputs = {}
|
||||
self.is_edges_collected = False
|
||||
self.intra_edges = []
|
||||
self.src_inter_edges = []
|
||||
self.dst_inter_edges = []
|
||||
self.relative_src_inter_edges = []
|
||||
self.relative_dst_inter_edges = []
|
||||
self.relative_inter_edges = set()
|
||||
self.edge_hash = None
|
||||
self.outputs = None
|
||||
self.type_id = -1
|
||||
self.block_id = -1
|
||||
self.p2p_type = None
|
||||
self.is_superset = False
|
||||
self.is_subset = False
|
||||
self.sorted_layer_ids = []
|
||||
|
||||
def collect_edges(self):
|
||||
if self.is_edges_collected:
|
||||
return
|
||||
for layer_index in self.layer_range:
|
||||
trt_layer = self.network.get_layer(layer_index)
|
||||
layer = self.graph.get_layer(trt_layer.name)
|
||||
layer_offset = layer.index - self.layer_range.start
|
||||
for input_index, input in enumerate(layer.inputs):
|
||||
if input is not None:
|
||||
if input.is_graph_input:
|
||||
is_owned = input.graph_input_index in self.owned_inputs
|
||||
if not is_owned and np.all([
|
||||
layer.index in self.layer_range or np.all([
|
||||
output.as_trt().is_shape_tensor
|
||||
for output in layer.outputs
|
||||
]) for layer, _ in input.consumers
|
||||
]):
|
||||
self.owned_inputs[input.graph_input_index] = len(
|
||||
self.owned_inputs)
|
||||
is_owned = True
|
||||
if is_owned:
|
||||
self.intra_edges.append(
|
||||
(-1, self.owned_inputs[input.graph_input_index],
|
||||
layer_offset, input_index))
|
||||
else:
|
||||
self.dst_inter_edges.append(
|
||||
(-1, input.graph_input_index, layer_offset,
|
||||
input_index))
|
||||
else:
|
||||
src_layer_index = input.producer.index
|
||||
if src_layer_index < self.layer_range.start or src_layer_index >= self.layer_range.stop:
|
||||
self.dst_inter_edges.append(
|
||||
(src_layer_index, input.output_index,
|
||||
layer_offset, input_index))
|
||||
else:
|
||||
src_layer_offset = src_layer_index - self.layer_range.start
|
||||
self.intra_edges.append(
|
||||
(src_layer_offset, input.output_index,
|
||||
layer_offset, input_index))
|
||||
for output_index, output in enumerate(layer.outputs):
|
||||
for dst_layer, dst_input_index in output.consumers:
|
||||
dst_layer_index = dst_layer.index
|
||||
if dst_layer_index < self.layer_range.start or dst_layer_index >= self.layer_range.stop:
|
||||
self.src_inter_edges.append(
|
||||
(layer_offset, output_index, dst_layer_index,
|
||||
dst_input_index))
|
||||
self.edge_hash = tuple(self.intra_edges)
|
||||
self.outputs = sorted(
|
||||
set((edge[0], edge[1]) for edge in self.src_inter_edges))
|
||||
self.is_edges_collected = True
|
||||
|
||||
def collect_relative_inter_edges(self, layer_to_block):
|
||||
self.collect_edges()
|
||||
for src_layer_index, src_output_index, dst_layer_index, dst_input_index in self.dst_inter_edges:
|
||||
if src_layer_index in layer_to_block:
|
||||
src_block = layer_to_block[src_layer_index]
|
||||
src_layer_offset = src_layer_index - src_block.layer_range.start
|
||||
dst = (self.type_id, dst_layer_index, dst_input_index)
|
||||
self.relative_dst_inter_edges.append(
|
||||
(src_block.type_id, src_layer_offset, src_output_index,
|
||||
*dst))
|
||||
else:
|
||||
self.relative_dst_inter_edges.append(
|
||||
(-1, src_layer_index, src_output_index, self.type_id,
|
||||
dst_layer_index, dst_input_index))
|
||||
self.relative_inter_edges = set(self.relative_dst_inter_edges +
|
||||
self.outputs)
|
||||
|
||||
def get_input_names(self):
|
||||
self.collect_edges()
|
||||
input_tensor_names = []
|
||||
for edge in self.dst_inter_edges:
|
||||
layer_index = edge[0]
|
||||
output_index = edge[1]
|
||||
if layer_index == -1:
|
||||
tensor_name = self.network.get_input(output_index).name
|
||||
else:
|
||||
tensor_name = self.network.get_layer(layer_index).get_output(
|
||||
output_index).name
|
||||
input_tensor_names.append(tensor_name)
|
||||
return input_tensor_names
|
||||
|
||||
def get_input_mapping(self, last_blocks):
|
||||
input_mapping = {}
|
||||
for tensor_name, relative_edge in zip(self.get_input_names(),
|
||||
self.relative_dst_inter_edges):
|
||||
type_id = relative_edge[0]
|
||||
output_index = relative_edge[2]
|
||||
if type_id >= 0:
|
||||
last_block = last_blocks[type_id]
|
||||
layer_offset = relative_edge[1]
|
||||
mapped_layer_index = last_block.layer_range.start + layer_offset
|
||||
mapped_tensor_name = self.network.get_layer(
|
||||
mapped_layer_index).get_output(output_index).name
|
||||
input_mapping[tensor_name] = mapped_tensor_name
|
||||
else:
|
||||
input_mapping[tensor_name] = tensor_name
|
||||
return input_mapping
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphMapping:
|
||||
layer_mapping: Dict[int, int] = None
|
||||
block_mapping: Dict[int, int] = None
|
||||
p2p_types: Dict[int, P2PType] = None
|
||||
p2p_tensors: Dict[int, List[str]] = None
|
||||
block_to_stage: Dict[int, int] = None
|
||||
same_spec_layer_mapping: Dict[str, str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphConfig:
|
||||
num_micro_batches: int = 1
|
||||
num_blocks: int = 1
|
||||
num_stages: int = 1
|
||||
has_cross_device: bool = False
|
||||
has_cross_host: bool = False
|
||||
graph_mapping: GraphMapping = None
|
||||
phy_mesh: PhysicalDeviceMesh = None
|
||||
stage_phy_meshes: List[PhysicalDeviceMesh] = None
|
||||
|
||||
|
||||
class Simplifier:
|
||||
|
||||
def __init__(self, network: Network, config: AutoParallelConfig):
|
||||
self.config = config
|
||||
self.sharded_io_allowlist = config.sharded_io_allowlist
|
||||
self.same_buffer_io = config.same_buffer_io
|
||||
self.same_spec_io = config.same_spec_io.copy()
|
||||
for key, value in self.same_buffer_io.items():
|
||||
if key not in self.same_spec_io:
|
||||
self.same_spec_io[key] = value
|
||||
|
||||
self.llm_network = network
|
||||
self.network = network.trt_network
|
||||
self.module_to_layer_range_map = network._module_call_stack.module_to_layer_range_map
|
||||
self.graph = self.get_graph()
|
||||
self.init_layer_hash()
|
||||
|
||||
module_tree = self.get_module_tree()
|
||||
building_blocks = self.collect_building_blocks(module_tree)
|
||||
blocks_by_module_hash = self.get_blocks_by_module_hash(building_blocks)
|
||||
self.blocks_by_edge_hash = self.get_blocks_by_edge_hash(
|
||||
blocks_by_module_hash)
|
||||
self.layer_to_block = self.get_layer_to_block()
|
||||
self.blocks = self.get_all_blocks()
|
||||
self.backbone_blocks = self.get_backbone_blocks()
|
||||
self.graph_mapping_for_shape = self.get_graph_mapping_for_shape()
|
||||
self.graph_for_shape = self.create_simplified_graph_for_shape()
|
||||
self.shape_info = None
|
||||
self.num_micro_batches = None
|
||||
|
||||
def infer_shapes(self, num_micro_batches):
|
||||
if self.num_micro_batches == num_micro_batches:
|
||||
return
|
||||
with silent_trt_logger():
|
||||
self.shape_info = self.get_full_shape_info(num_micro_batches)
|
||||
self.graph.assign_shapes(self.shape_info)
|
||||
self.num_micro_batches = num_micro_batches
|
||||
|
||||
def list_all_num_micro_batches(self):
|
||||
opt_batch_size = self.get_opt_batch_size()
|
||||
candidates = []
|
||||
for num_micro_batches in range(1, self.get_opt_batch_size() + 1):
|
||||
if opt_batch_size % num_micro_batches == 0:
|
||||
candidates.append(num_micro_batches)
|
||||
return candidates
|
||||
|
||||
def get_graph(self):
|
||||
graph = PipelineGraph.from_trt(self.network)
|
||||
graph._unfilled_weights = self.llm_network._unfilled_weights.copy()
|
||||
graph._io_buffer_mapping
|
||||
for input in graph.inputs:
|
||||
input_name = input.name
|
||||
for pattern, repl in self.same_buffer_io.items():
|
||||
if re.match(pattern, input_name):
|
||||
output_name = re.sub(pattern, repl, input_name)
|
||||
output = graph.get_output(output_name)
|
||||
if output is not None:
|
||||
graph._io_buffer_mapping[output_name] = input_name
|
||||
return graph
|
||||
|
||||
def get_opt_batch_size(self):
|
||||
input_tensors = self.llm_network._inputs
|
||||
num_profiles = len(list(input_tensors.values())[0].profiles)
|
||||
opt_batch_sizes = []
|
||||
for i in range(num_profiles):
|
||||
for input_tensor in input_tensors.values():
|
||||
shape_profile = input_tensor.profiles[i]
|
||||
opt_shape = shape_profile.opt
|
||||
for j in range(len(input_tensor.shape)):
|
||||
name = input_tensor.trt_tensor.get_dimension_name(j)
|
||||
if name == 'batch_size':
|
||||
opt_batch_sizes.append(opt_shape[j])
|
||||
return min(opt_batch_sizes)
|
||||
|
||||
def get_module_hash(self, layer_range):
|
||||
module_hash = ()
|
||||
for i in layer_range:
|
||||
assert i < self.network.num_layers, f"layer index {i} in {layer_range} out of range of {self.network.num_layers}"
|
||||
layer_name = self.network.get_layer(i).name
|
||||
layer = self.graph.get_layer(layer_name)
|
||||
module_hash += (layer.attrs["hash"], )
|
||||
return module_hash
|
||||
|
||||
def get_network_hash(self) -> str:
|
||||
return str(self.get_module_hash(range(self.network.num_layers)))
|
||||
|
||||
def collect_building_blocks(self, module_tree):
|
||||
building_blocks = {}
|
||||
queue = []
|
||||
for tree in module_tree["children"].values():
|
||||
queue.append(tree)
|
||||
while len(queue) > 0:
|
||||
while len(queue) > 0:
|
||||
tree = queue.pop(0)
|
||||
module_name = tree["name"]
|
||||
if module_name is None:
|
||||
for child in tree["children"].values():
|
||||
queue.append(child)
|
||||
continue
|
||||
layer_range = self.module_to_layer_range_map[module_name]
|
||||
module_hash = self.get_module_hash(layer_range)
|
||||
if module_hash in building_blocks:
|
||||
building_blocks[module_hash].append(tree)
|
||||
else:
|
||||
building_blocks[module_hash] = [tree]
|
||||
for module_hash in [*building_blocks.keys()]:
|
||||
if len(building_blocks[module_hash]) == 1:
|
||||
tree = building_blocks[module_hash][0]
|
||||
for child in tree["children"].values():
|
||||
queue.append(child)
|
||||
del building_blocks[module_hash]
|
||||
blocks_by_module_hash = {
|
||||
module_hash: [
|
||||
BuildingBlock(self.graph,
|
||||
self.module_to_layer_range_map[tree["name"]])
|
||||
for tree in trees
|
||||
]
|
||||
for module_hash, trees in building_blocks.items()
|
||||
}
|
||||
building_blocks = []
|
||||
for block_list in blocks_by_module_hash.values():
|
||||
for block in block_list:
|
||||
building_blocks.append(block)
|
||||
building_blocks = sorted(building_blocks,
|
||||
key=lambda x: x.layer_range.start)
|
||||
if len(building_blocks) >= 2:
|
||||
for block, next_block in zip(building_blocks[:-1],
|
||||
building_blocks[1:]):
|
||||
block.layer_range = range(block.layer_range.start,
|
||||
next_block.layer_range.start)
|
||||
return building_blocks
|
||||
|
||||
def get_all_blocks(self):
|
||||
building_blocks = []
|
||||
for block_list in self.blocks_by_edge_hash.values():
|
||||
for block in block_list:
|
||||
building_blocks.append(block)
|
||||
building_blocks = sorted(building_blocks,
|
||||
key=lambda x: x.layer_range.start)
|
||||
all_blocks = []
|
||||
current_layer_index = 0
|
||||
block_id = 0
|
||||
for block in building_blocks:
|
||||
assert current_layer_index <= block.layer_range.start
|
||||
if current_layer_index < block.layer_range.start:
|
||||
new_block = BuildingBlock(
|
||||
self.graph,
|
||||
range(current_layer_index, block.layer_range.start))
|
||||
new_block.block_id = block_id
|
||||
block_id += 1
|
||||
all_blocks.append(new_block)
|
||||
block.block_id = block_id
|
||||
block_id += 1
|
||||
all_blocks.append(block)
|
||||
current_layer_index = block.layer_range.stop
|
||||
if current_layer_index < self.graph.num_layers:
|
||||
new_block = BuildingBlock(
|
||||
self.graph, range(current_layer_index, self.graph.num_layers))
|
||||
new_block.block_id = block_id
|
||||
all_blocks.append(new_block)
|
||||
sorted_layer_ids = get_sorted_layer_ids(self.network)
|
||||
for block in all_blocks:
|
||||
block.collect_relative_inter_edges(self.layer_to_block)
|
||||
for layer_id in sorted_layer_ids:
|
||||
if layer_id in block.layer_range:
|
||||
block.sorted_layer_ids.append(layer_id)
|
||||
return all_blocks
|
||||
|
||||
def get_backbone_blocks(self):
|
||||
sorted_blocks = sorted(
|
||||
self.blocks_by_edge_hash.values(),
|
||||
key=lambda blocks: (len(blocks), len(blocks[0].layer_range)),
|
||||
)
|
||||
if len(sorted_blocks) == 0:
|
||||
return []
|
||||
else:
|
||||
return sorted_blocks[-1]
|
||||
|
||||
def get_blocks_by_module_hash(self, blocks):
|
||||
blocks_by_module_hash = {}
|
||||
for block in blocks:
|
||||
module_hash = self.get_module_hash(block.layer_range)
|
||||
if module_hash not in blocks_by_module_hash:
|
||||
blocks_by_module_hash[module_hash] = []
|
||||
blocks_by_module_hash[module_hash].append(block)
|
||||
for module_hash in [*blocks_by_module_hash.keys()]:
|
||||
if len(blocks_by_module_hash[module_hash]) == 1:
|
||||
del blocks_by_module_hash[module_hash]
|
||||
return blocks_by_module_hash
|
||||
|
||||
def get_module_tree(self):
|
||||
module_tree = {"children": {}, "name": None}
|
||||
for module_name in self.module_to_layer_range_map.keys():
|
||||
full_name = module_name.split('.')
|
||||
current_tree = module_tree["children"]
|
||||
for depth, name in enumerate(full_name):
|
||||
if name not in current_tree:
|
||||
current_tree[name] = {"children": {}, "name": None}
|
||||
if depth == len(full_name) - 1:
|
||||
current_tree[name]["name"] = module_name
|
||||
else:
|
||||
current_tree = current_tree[name]["children"]
|
||||
return module_tree
|
||||
|
||||
def get_blocks_by_edge_hash(self, blocks_by_module_hash):
|
||||
blocks_by_edge_hash = {}
|
||||
for block_list in blocks_by_module_hash.values():
|
||||
for block in block_list:
|
||||
block.collect_edges()
|
||||
edge_hash = block.edge_hash
|
||||
if edge_hash not in blocks_by_edge_hash:
|
||||
blocks_by_edge_hash[edge_hash] = []
|
||||
blocks_by_edge_hash[edge_hash].append(block)
|
||||
for edge_hash in [*blocks_by_edge_hash.keys()]:
|
||||
if len(blocks_by_edge_hash[edge_hash]) == 1:
|
||||
del blocks_by_edge_hash[edge_hash]
|
||||
else:
|
||||
block_list = blocks_by_edge_hash[edge_hash]
|
||||
blocks_by_edge_hash[edge_hash] = sorted(
|
||||
block_list, key=lambda x: x.layer_range.start)
|
||||
for type_id, block_list in enumerate(blocks_by_edge_hash.values()):
|
||||
for block in block_list:
|
||||
block.type_id = type_id
|
||||
return blocks_by_edge_hash
|
||||
|
||||
def get_layer_to_block(self):
|
||||
layer_to_block = {}
|
||||
for block_list in self.blocks_by_edge_hash.values():
|
||||
for block in block_list:
|
||||
for layer_index in block.layer_range:
|
||||
layer_to_block[layer_index] = block
|
||||
return layer_to_block
|
||||
|
||||
def clean_blocks(self):
|
||||
for block in self.blocks:
|
||||
block.p2p_type = None
|
||||
block.is_superset = False
|
||||
block.is_subset = False
|
||||
|
||||
def mark_p2p_type(self, phy_mesh, stage_phy_meshes,
|
||||
graph_config: GraphConfig):
|
||||
if len(self.backbone_blocks) == 0 or len(stage_phy_meshes) == 1:
|
||||
return
|
||||
assert len(self.backbone_blocks) % len(stage_phy_meshes) == 0
|
||||
block_per_stage = len(self.backbone_blocks) // len(stage_phy_meshes)
|
||||
|
||||
for block in self.backbone_blocks:
|
||||
block.p2p_type = None
|
||||
for stage_index, stage_phy_mesh in enumerate(stage_phy_meshes[:-1]):
|
||||
next_stage_phy_mesh = stage_phy_meshes[stage_index + 1]
|
||||
last_device_id = stage_phy_mesh.phy_devices_id.flatten()[-1]
|
||||
next_first_device_id = next_stage_phy_mesh.phy_devices_id.flatten(
|
||||
)[0]
|
||||
num_devices_per_host = phy_mesh.num_devices_per_host
|
||||
next_block = self.backbone_blocks[(stage_index + 1) *
|
||||
block_per_stage]
|
||||
if last_device_id // num_devices_per_host != next_first_device_id // num_devices_per_host:
|
||||
next_block.p2p_type = P2PType.CROSS_HOST
|
||||
graph_config.has_cross_host = True
|
||||
else:
|
||||
next_block.p2p_type = P2PType.CROSS_DEVICE
|
||||
graph_config.has_cross_device = True
|
||||
|
||||
def get_graph_mapping(self):
|
||||
layer_mapping = {}
|
||||
block_mapping = {}
|
||||
p2p_types = {}
|
||||
p2p_tensors = {}
|
||||
for block_list in self.blocks_by_edge_hash.values():
|
||||
superset_blocks = []
|
||||
superset_block_index = {}
|
||||
for block in block_list:
|
||||
block_added = False
|
||||
for index, superset_block in enumerate(list(superset_blocks)):
|
||||
if block.p2p_type == superset_block.p2p_type:
|
||||
if block.relative_inter_edges.issubset(
|
||||
superset_block.relative_inter_edges):
|
||||
block.is_subset = True
|
||||
block.is_superset = False
|
||||
superset_block_index[id(block)] = index
|
||||
block_added = True
|
||||
break
|
||||
elif superset_block.relative_inter_edges.issubset(
|
||||
block.relative_inter_edges):
|
||||
superset_block.is_subset = True
|
||||
superset_block.is_superset = False
|
||||
block.is_subset = False
|
||||
block.is_superset = True
|
||||
superset_blocks[index] = block
|
||||
superset_block_index[id(block)] = index
|
||||
block_added = True
|
||||
break
|
||||
if not block_added:
|
||||
block.is_subset = False
|
||||
block.is_superset = True
|
||||
superset_blocks.append(block)
|
||||
superset_block_index[id(block)] = len(superset_blocks) - 1
|
||||
for block in block_list:
|
||||
assert not (block.is_subset and block.is_superset)
|
||||
if block.is_subset:
|
||||
superset_block = superset_blocks[superset_block_index[id(
|
||||
block)]]
|
||||
block_mapping[block.block_id] = superset_block.block_id
|
||||
owned_inputs = map(
|
||||
lambda x: x[0],
|
||||
sorted(block.owned_inputs.items(), key=lambda x: x[1]))
|
||||
superset_owned_inputs = map(
|
||||
lambda x: x[0],
|
||||
sorted(superset_block.owned_inputs.items(),
|
||||
key=lambda x: x[1]))
|
||||
for from_input_id, to_input_id in zip(
|
||||
owned_inputs, superset_owned_inputs):
|
||||
from_input_name = self.network.get_input(
|
||||
from_input_id).name
|
||||
to_input_name = self.network.get_input(to_input_id).name
|
||||
layer_mapping[from_input_name] = to_input_name
|
||||
for from_layer_id, to_layer_id in zip(
|
||||
block.layer_range, superset_block.layer_range):
|
||||
from_layer = self.network.get_layer(from_layer_id)
|
||||
to_layer = self.network.get_layer(to_layer_id)
|
||||
layer_mapping[from_layer.name] = to_layer.name
|
||||
for i in range(from_layer.num_outputs):
|
||||
from_output = from_layer.get_output(i)
|
||||
if from_output.is_network_output:
|
||||
to_output = to_layer.get_output(i)
|
||||
layer_mapping[from_output.name] = to_output.name
|
||||
if block.p2p_type is not None:
|
||||
p2p_types[block.block_id] = block.p2p_type
|
||||
p2p_tensors[block.block_id] = [
|
||||
*set(block.get_input_names())
|
||||
]
|
||||
for from_name, to_name in zip(
|
||||
block.get_input_names(),
|
||||
superset_block.get_input_names()):
|
||||
layer_mapping[
|
||||
f"p2p_block{block.block_id}_{from_name}"] = f"p2p_block{superset_block.block_id}_{to_name}"
|
||||
stage_id = 0
|
||||
block_to_stage = {}
|
||||
for block in self.blocks:
|
||||
if block.p2p_type is not None:
|
||||
stage_id += 1
|
||||
block_to_stage[block.block_id] = stage_id
|
||||
return GraphMapping(
|
||||
layer_mapping,
|
||||
block_mapping,
|
||||
p2p_types,
|
||||
p2p_tensors,
|
||||
block_to_stage,
|
||||
)
|
||||
|
||||
def create_simplified_graph(self, graph_config: GraphConfig):
|
||||
new_graph = PipelineGraph.create_graph()
|
||||
new_graph._io_buffer_mapping = self.graph._io_buffer_mapping
|
||||
layer_mapping = graph_config.graph_mapping.layer_mapping
|
||||
|
||||
for i in range(self.network.num_inputs):
|
||||
trt_input = self.network.get_input(i)
|
||||
if trt_input.name not in layer_mapping:
|
||||
new_graph.add_input(trt_input)
|
||||
|
||||
last_blocks = {}
|
||||
same_spec_mapping = {}
|
||||
same_spec_layer_mapping = {}
|
||||
shape_mapping = {}
|
||||
building_block_id = 0
|
||||
same_spec_ids = {}
|
||||
same_spec_count = 0
|
||||
for block in self.blocks:
|
||||
if not block.is_subset:
|
||||
stage_type = None
|
||||
if not block.is_superset:
|
||||
if block.block_id == 0:
|
||||
stage_type = StageType.START
|
||||
elif block.block_id == len(self.blocks) - 1:
|
||||
stage_type = StageType.END
|
||||
input_mapping = block.get_input_mapping(last_blocks)
|
||||
for from_name, to_name in [*input_mapping.items()]:
|
||||
if to_name in same_spec_mapping:
|
||||
input_mapping[from_name] = same_spec_mapping[to_name]
|
||||
if to_name in layer_mapping:
|
||||
input_mapping[from_name] = layer_mapping[to_name]
|
||||
if block.is_superset and block.p2p_type is not None:
|
||||
for from_name, to_name in [*input_mapping.items()]:
|
||||
output_tensor = new_graph.get_tensor(to_name)
|
||||
p2p_layer = new_graph.as_trt().add_identity(
|
||||
output_tensor.as_trt())
|
||||
p2p_layer.name = f"p2p_block{block.block_id}_{from_name}"
|
||||
p2p_layer.metadata = p2p_layer.name
|
||||
p2p_tensor = p2p_layer.get_output(0)
|
||||
p2p_tensor.name = f"{p2p_layer.name}_output"
|
||||
wrapped_layer = new_graph.register_layer(p2p_layer)
|
||||
wrapped_layer.attrs[
|
||||
"building_block_id"] = building_block_id
|
||||
wrapped_layer.attrs["p2p_type"] = block.p2p_type
|
||||
input_mapping[from_name] = p2p_tensor.name
|
||||
shape_mapping[p2p_tensor.name] = from_name
|
||||
building_block_id += 1
|
||||
for i in block.sorted_layer_ids:
|
||||
layer = self.network.get_layer(i)
|
||||
wrapped_layer = new_graph.add_layer(
|
||||
layer,
|
||||
input_mapping=input_mapping,
|
||||
)
|
||||
wrapped_layer.attrs["building_block_id"] = building_block_id
|
||||
wrapped_layer.attrs["stage_type"] = stage_type
|
||||
if block.is_superset:
|
||||
last_blocks[block.type_id] = block
|
||||
|
||||
if block.type_id in same_spec_ids:
|
||||
same_spec_id = same_spec_ids[block.type_id]
|
||||
update_same_spec_count = False
|
||||
else:
|
||||
same_spec_id = same_spec_count
|
||||
same_spec_ids[block.type_id] = same_spec_id
|
||||
update_same_spec_count = True
|
||||
count = same_spec_id
|
||||
for i, (layer_offset,
|
||||
output_index) in enumerate(block.outputs):
|
||||
layer = self.network.get_layer(block.layer_range.start +
|
||||
layer_offset)
|
||||
tensor_name = layer.get_output(output_index).name
|
||||
output_tensor = new_graph.get_tensor(tensor_name)
|
||||
same_spec_layer = new_graph.as_trt().add_identity(
|
||||
output_tensor.as_trt())
|
||||
same_spec_layer.name = f"{tensor_name}_same_spec"
|
||||
same_spec_layer.metadata = same_spec_layer.name
|
||||
same_spec_tensor = same_spec_layer.get_output(0)
|
||||
same_spec_tensor.name = f"{same_spec_layer.name}_output"
|
||||
wrapped_layer = new_graph.register_layer(
|
||||
same_spec_layer)
|
||||
wrapped_layer.attrs[
|
||||
"building_block_id"] = building_block_id
|
||||
wrapped_layer.attrs["same_spec_id"] = count
|
||||
count += 1
|
||||
same_spec_mapping[tensor_name] = same_spec_tensor.name
|
||||
same_spec_layer_mapping[
|
||||
same_spec_layer.name] = layer.name
|
||||
shape_mapping[same_spec_tensor.name] = tensor_name
|
||||
for i, graph_input_index in enumerate(
|
||||
block.owned_inputs.keys()):
|
||||
input_name = self.network.get_input(
|
||||
graph_input_index).name
|
||||
input_tensor = new_graph.get_input(input_name)
|
||||
input_tensor.attrs["same_spec_id"] = count
|
||||
count += 1
|
||||
if update_same_spec_count:
|
||||
same_spec_count = count
|
||||
building_block_id += 1
|
||||
graph_config.graph_mapping.same_spec_layer_mapping = same_spec_layer_mapping
|
||||
|
||||
if len(self.backbone_blocks) >= 2:
|
||||
start_block = self.backbone_blocks[0]
|
||||
if start_block.is_subset:
|
||||
start_block = self.blocks[graph_config.graph_mapping.
|
||||
block_mapping[start_block.block_id]]
|
||||
for i in start_block.layer_range:
|
||||
layer_name = self.network.get_layer(i).name
|
||||
layer = new_graph.get_layer(layer_name)
|
||||
layer.attrs["in_start_block"] = True
|
||||
end_block = self.backbone_blocks[-1]
|
||||
if end_block.is_subset:
|
||||
end_block = self.blocks[graph_config.graph_mapping.
|
||||
block_mapping[end_block.block_id]]
|
||||
for i in end_block.layer_range:
|
||||
layer_name = self.network.get_layer(i).name
|
||||
layer = new_graph.get_layer(layer_name)
|
||||
layer.attrs["in_end_block"] = True
|
||||
slowest_p2p_type = None
|
||||
if graph_config.has_cross_host:
|
||||
slowest_p2p_type = P2PType.CROSS_HOST
|
||||
elif graph_config.has_cross_device:
|
||||
slowest_p2p_type = P2PType.CROSS_DEVICE
|
||||
if slowest_p2p_type is not None:
|
||||
for block in self.blocks:
|
||||
if block.is_superset and block.p2p_type == slowest_p2p_type:
|
||||
for i in block.layer_range:
|
||||
layer_name = self.network.get_layer(i).name
|
||||
layer = new_graph.get_layer(layer_name)
|
||||
layer.attrs["in_slowest_block"] = True
|
||||
|
||||
for i in range(self.network.num_outputs):
|
||||
trt_output = self.network.get_output(i)
|
||||
output = self.graph.get_output(trt_output.name)
|
||||
if output.producer is not None and output.producer.index in self.layer_to_block and self.layer_to_block[
|
||||
output.producer.index].is_subset:
|
||||
continue
|
||||
if trt_output.is_shape_tensor:
|
||||
new_output = new_graph.add_output_shape(trt_output)
|
||||
else:
|
||||
new_output = new_graph.add_output(trt_output)
|
||||
sharded_io = False
|
||||
for pattern in self.sharded_io_allowlist:
|
||||
if re.match(pattern, new_output.name):
|
||||
sharded_io = True
|
||||
break
|
||||
if not sharded_io:
|
||||
new_output.producer.attrs["is_replicated"] = True
|
||||
|
||||
for input in new_graph.inputs:
|
||||
input_name = input.name
|
||||
sharded_io = False
|
||||
for pattern in self.sharded_io_allowlist:
|
||||
if re.match(pattern, input_name):
|
||||
sharded_io = True
|
||||
break
|
||||
if not sharded_io:
|
||||
input.attrs["is_replicated"] = True
|
||||
for pattern, repl in self.same_spec_io.items():
|
||||
if re.match(pattern, input_name):
|
||||
output_name = re.sub(pattern, repl, input_name)
|
||||
output = new_graph.get_output(output_name)
|
||||
if output is not None:
|
||||
if "same_spec_id" in input.attrs:
|
||||
same_spec_id = input.attrs["same_spec_id"]
|
||||
else:
|
||||
same_spec_id = same_spec_count
|
||||
same_spec_count += 1
|
||||
input.attrs["same_spec_id"] = same_spec_id
|
||||
output.attrs["same_spec_id"] = same_spec_id
|
||||
if math.prod(self.graph.get_input(
|
||||
input_name).shape) < math.prod(
|
||||
self.graph.get_output(output_name).shape):
|
||||
input.attrs["no_memory_footprint"] = True
|
||||
else:
|
||||
output.attrs["no_memory_footprint"] = True
|
||||
|
||||
return new_graph, shape_mapping
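# --- illustrative aside (not part of the original file) ---------------------
# A minimal sketch of the regex-driven I/O marking above: names that miss
# every allowlist pattern are marked as replicated, and each same_spec_io
# pattern pairs an input with the output that must share its sharding spec.
# All patterns, tensor names, and shapes below are hypothetical.
import math
import re

sharded_io_allowlist = [r"state_(in|out)_\d+"]        # hypothetical pattern
same_spec_io = {r"state_in_(\d+)": r"state_out_\1"}   # hypothetical pair

inputs = {"state_in_0": (2, 8), "token_mask": (1, 128)}   # name -> shape
outputs = {"state_out_0": (2, 16)}

attrs = {name: {} for name in list(inputs) + list(outputs)}
for name in attrs:
    if not any(re.match(p, name) for p in sharded_io_allowlist):
        attrs[name]["is_replicated"] = True

same_spec_count = 0
for pattern, repl in same_spec_io.items():
    for in_name, in_shape in inputs.items():
        if not re.match(pattern, in_name):
            continue
        out_name = re.sub(pattern, repl, in_name)
        if out_name not in outputs:
            continue
        attrs[in_name]["same_spec_id"] = attrs[out_name]["same_spec_id"] = same_spec_count
        same_spec_count += 1
        # the smaller tensor of the pair carries no extra memory footprint,
        # mirroring the math.prod comparison above (ties go to the output)
        if math.prod(in_shape) < math.prod(outputs[out_name]):
            attrs[in_name]["no_memory_footprint"] = True
        else:
            attrs[out_name]["no_memory_footprint"] = True

print(attrs)
# ---------------------------------------------------------------------------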
|
||||
|
||||
def enrich_shape_info(self, shape_mapping):
|
||||
shapes = self.shape_info.shapes.copy()
|
||||
max_shapes = self.shape_info.max_shapes.copy()
|
||||
values = self.shape_info.values.copy()
|
||||
shape_layers = self.shape_info.shape_layers
|
||||
for from_name, to_name in shape_mapping.items():
|
||||
if to_name in shapes:
|
||||
shapes[from_name] = shapes[to_name]
|
||||
if to_name in max_shapes:
|
||||
max_shapes[from_name] = max_shapes[to_name]
|
||||
if to_name in values:
|
||||
values[from_name] = values[to_name]
|
||||
shape_info = ShapeInfo(shapes, values, shape_layers, max_shapes)
|
||||
return shape_info
|
||||
|
||||
def simplify_graph(
|
||||
self, phy_mesh: PhysicalDeviceMesh, num_stages: int,
|
||||
num_devices_per_stage: int) -> Tuple[PipelineGraph, GraphConfig]:
|
||||
num_blocks = len(self.backbone_blocks)
|
||||
if num_blocks % num_stages != 0:
|
||||
return None, None
|
||||
graph_config = GraphConfig()
|
||||
graph_config.num_micro_batches = self.num_micro_batches
|
||||
graph_config.num_blocks = num_blocks
|
||||
graph_config.num_stages = num_stages
|
||||
graph_config.phy_mesh = phy_mesh
|
||||
stage_phy_meshes = phy_mesh.split_pipeline_meshes(
|
||||
num_stages, num_devices_per_stage)
|
||||
graph_config.stage_phy_meshes = stage_phy_meshes
|
||||
with silent_trt_logger():
|
||||
self.clean_blocks()
|
||||
self.mark_p2p_type(phy_mesh, stage_phy_meshes, graph_config)
|
||||
graph_config.graph_mapping = self.get_graph_mapping()
|
||||
new_graph, shape_mapping = self.create_simplified_graph(
|
||||
graph_config)
|
||||
shape_info = self.enrich_shape_info(shape_mapping)
|
||||
new_graph.assign_shapes(shape_info)
|
||||
return new_graph, graph_config
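# --- illustrative aside (not part of the original file) ---------------------
# A minimal sketch of the divisibility check above: pipeline simplification
# only proceeds when the repeated backbone blocks split evenly across stages,
# so each stage owns the same number of blocks. Numbers here are made up.
def split_blocks(num_blocks: int, num_stages: int):
    if num_blocks % num_stages != 0:
        return None  # mirrors the (None, None) early return above
    per_stage = num_blocks // num_stages
    return [list(range(s * per_stage, (s + 1) * per_stage))
            for s in range(num_stages)]

print(split_blocks(32, 4))   # 4 stages x 8 blocks
print(split_blocks(30, 4))   # None: cannot split evenly
# ---------------------------------------------------------------------------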
|
||||
|
||||
def get_graph_mapping_for_shape(self):
|
||||
layer_mapping = {}
|
||||
tensor_mapping = {}
|
||||
for block_list in self.blocks_by_edge_hash.values():
|
||||
head_block = block_list[0]
|
||||
for block in block_list[1:]:
|
||||
for from_layer_id, to_layer_id in zip(block.layer_range,
|
||||
head_block.layer_range):
|
||||
from_layer = self.network.get_layer(from_layer_id)
|
||||
to_layer = self.network.get_layer(to_layer_id)
|
||||
layer_mapping[from_layer.name] = to_layer.name
|
||||
for i in range(from_layer.num_outputs):
|
||||
tensor_mapping[from_layer.get_output(
|
||||
i).name] = to_layer.get_output(i).name
|
||||
return layer_mapping, tensor_mapping
|
||||
|
||||
def create_simplified_graph_for_shape(self):
|
||||
new_graph = PipelineGraph.create_graph()
|
||||
|
||||
for i in range(self.network.num_inputs):
|
||||
trt_input = self.network.get_input(i)
|
||||
new_graph.add_input(trt_input)
|
||||
|
||||
head_blocks = {}
|
||||
removed_blocks = set()
|
||||
removed_layers = set()
|
||||
for block_list in self.blocks_by_edge_hash.values():
|
||||
head_block = block_list[0]
|
||||
head_blocks[head_block.type_id] = head_block
|
||||
for block in block_list[1:]:
|
||||
removed_blocks.add(id(block))
|
||||
for layer_index in block.layer_range:
|
||||
removed_layers.add(layer_index)
|
||||
|
||||
for block in self.blocks:
|
||||
if not id(block) in removed_blocks:
|
||||
input_mapping = block.get_input_mapping(head_blocks)
|
||||
for i in block.sorted_layer_ids:
|
||||
layer = self.network.get_layer(i)
|
||||
new_graph.add_layer(
|
||||
layer,
|
||||
input_mapping=input_mapping,
|
||||
)
|
||||
|
||||
for i in range(self.network.num_outputs):
|
||||
trt_output = self.network.get_output(i)
|
||||
output = self.graph.get_output(trt_output.name)
|
||||
if output.producer is not None and output.producer.index in removed_layers:
|
||||
continue
|
||||
if trt_output.is_shape_tensor:
|
||||
new_graph.add_output_shape(trt_output)
|
||||
else:
|
||||
new_graph.add_output(trt_output)
|
||||
|
||||
return new_graph
|
||||
|
||||
def get_full_shape_info(self, num_micro_batches):
|
||||
layer_mapping, tensor_mapping = self.graph_mapping_for_shape
|
||||
optimization_profiles = self.llm_network._generate_optimization_profiles(
|
||||
)
|
||||
if len(optimization_profiles) > 0:
|
||||
optimization_profile = optimization_profiles[-1]
|
||||
else:
|
||||
optimization_profile = None
|
||||
shape_info = get_shape_info(self.graph_for_shape.as_trt(),
|
||||
optimization_profile)
|
||||
max_shape_info = get_shape_info(self.graph_for_shape.as_trt(),
|
||||
optimization_profile,
|
||||
shape_type=ShapeType.MAX)
|
||||
shape_info.max_shapes = max_shape_info.shapes
|
||||
for removed_tensor_name, tensor_name in tensor_mapping.items():
|
||||
shape_info.shapes[removed_tensor_name] = shape_info.shapes[
|
||||
tensor_name]
|
||||
shape_info.max_shapes[removed_tensor_name] = shape_info.max_shapes[
|
||||
tensor_name]
|
||||
if tensor_name in shape_info.values:
|
||||
shape_info.values[removed_tensor_name] = shape_info.values[
|
||||
tensor_name]
|
||||
for removed_layer_name, layer_name in layer_mapping.items():
|
||||
if layer_name in shape_info.shape_layers:
|
||||
shape_info.shape_layers.add(removed_layer_name)
|
||||
return shape_info
|
||||
|
||||
def init_layer_hash(self):
|
||||
with silent_trt_logger():
|
||||
optimization_profiles = self.llm_network._generate_optimization_profiles(
|
||||
)
|
||||
if len(optimization_profiles) > 0:
|
||||
optimization_profile = optimization_profiles[-1]
|
||||
else:
|
||||
optimization_profile = None
|
||||
shape_info = get_shape_info(self.network, optimization_profile)
|
||||
dtypes = {tensor.name: tensor.dtype for tensor in self.graph.tensors}
|
||||
for layer in self.graph.layers:
|
||||
layer_hash = get_cache_key(
|
||||
layer.as_trt(),
|
||||
shape_info.shapes,
|
||||
shape_info.values,
|
||||
dtypes,
|
||||
)
|
||||
layer.attrs["hash"] = layer_hash
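# --- illustrative aside (not part of the original file) ---------------------
# A minimal sketch of per-layer cache keys in the spirit of the loop above:
# the key combines a layer's type with the shapes and dtypes of its inputs so
# identically-shaped layers can reuse profiling results. This is only an
# assumption about what get_cache_key captures, not its actual implementation.
def cache_key(layer_type: str, input_shapes, input_dtypes) -> str:
    parts = [layer_type]
    for shape, dtype in zip(input_shapes, input_dtypes):
        parts.append(f"{dtype}{tuple(shape)}")
    return "|".join(parts)

k1 = cache_key("MATRIX_MULTIPLY", [(16, 4096), (4096, 4096)], ["float16", "float16"])
k2 = cache_key("MATRIX_MULTIPLY", [(16, 4096), (4096, 4096)], ["float16", "float16"])
assert k1 == k2  # identical layers hash to the same key
# ---------------------------------------------------------------------------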
|
||||
@ -1,641 +0,0 @@
|
||||
"""This code is adapted from Alpa https://github.com/alpa-projects/alpa/ with some changes.
|
||||
"""
|
||||
import multiprocessing
|
||||
import time
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import pulp
|
||||
from pulp import LpMinimize, LpProblem, LpVariable, lpDot, lpSum
|
||||
|
||||
from ..logger import logger
|
||||
|
||||
|
||||
class Solution:
|
||||
|
||||
def __init__(self, leaf_strategies, s_val, e_val, edge_pairs,
|
||||
node_index_dict, total_cost):
|
||||
self.leaf_strategies = leaf_strategies
|
||||
self.nodes = [
|
||||
strategies_vector.node for strategies_vector in self.leaf_strategies
|
||||
]
|
||||
self.s_val = s_val
|
||||
self.e_val = e_val
|
||||
self.total_cost = total_cost
|
||||
self.edge_pairs = list(np.reshape(edge_pairs, (-1, 2)))
|
||||
self.node_index_dict = node_index_dict
|
||||
self.index_node_dict = {}
|
||||
for node, index in self.node_index_dict.items():
|
||||
self.index_node_dict[index] = node
|
||||
self.node_best_strategy = {}
|
||||
self._annotate_strategy()
|
||||
|
||||
def _annotate_strategy(self):
|
||||
self.node_best_strategy = {}
|
||||
for index, node in enumerate(self.nodes):
|
||||
best_strategy_id = self.s_val[index]
|
||||
best_strategy = self.leaf_strategies[index][best_strategy_id]
|
||||
self.node_best_strategy[node.node_name] = best_strategy
|
||||
|
||||
for edge_idx, edge_pair in enumerate(self.edge_pairs):
|
||||
src_node = self.index_node_dict[edge_pair[0]]
|
||||
dst_node = self.index_node_dict[edge_pair[1]]
|
||||
src_node_index = self.node_index_dict[src_node]
|
||||
for dst_pre_node in dst_node.predecessor_nodes:
|
||||
if dst_pre_node is None:
|
||||
continue
|
||||
if src_node.node_name == dst_pre_node.node_name:
|
||||
self.node_best_strategy[
|
||||
dst_node.node_name].best_resharding_cost[
|
||||
src_node.node_name] = [
|
||||
self.node_best_strategy[dst_node.node_name].
|
||||
resharding_costs[src_node.node_name][
|
||||
self.s_val[src_node_index]]
|
||||
]
|
||||
|
||||
def print_solution(self):
|
||||
for index, node in enumerate(self.nodes):
|
||||
best_strategy = self.node_best_strategy[node.node_name]
|
||||
print(f'\n[{index}]: node_name = {node.node_name}')
|
||||
best_strategy.print_strategy(best_resharding_cost_only=True)
|
||||
print(f'solution total cost = {self.total_cost}')
|
||||
|
||||
|
||||
class CostGraph:
    '''
    A graph data structure to simplify the edge cost graph. It has two main functions:
    1. To feed the quadratic resharding costs into the solver, we need to linearize them. We build edge_cost in
    CostGraph, and it stores every combination of strategies for a src-dst node pair in a 1D list.
    2. To reduce the search space, we merge computationally trivial operators, such as
    element-wise operators, transpose, and reduction, into their following nodes. The merging information is
    given by the StrategiesVector, depending on the type of the target node and its following nodes.

    Argument:
        leaf_strategies(List[StrategiesVector]): It stores the StrategiesVector of every node on the graph.
        simplify(bool, optional): The generated cost graph will be simplified if it is True. (defaults to True)
    '''
def __init__(self, leaf_strategies):
|
||||
self.leaf_strategies = leaf_strategies
|
||||
self.nodes = [
|
||||
strategies_vector.node for strategies_vector in leaf_strategies
|
||||
]
|
||||
        # maps each node to its StrategiesVector (one entry per candidate strategy)
|
||||
self.node_strategies_vector = {}
|
||||
for node, strategies_vector in zip(self.nodes, self.leaf_strategies):
|
||||
self.node_strategies_vector[node] = strategies_vector
|
||||
# extra_node_costs will store the extra costs introduced by merging nodes
|
||||
self.extra_node_costs = {}
|
||||
self.following_dict = {}
|
||||
self._build_cost_graph()
|
||||
|
||||
def _remove_invalid_node(self, node, attr_name):
|
||||
remove_list = []
|
||||
target_node_list = getattr(node, attr_name, [])
|
||||
for target_node in target_node_list:
|
||||
if target_node not in self.nodes:
|
||||
remove_list.append(target_node)
|
||||
for element in remove_list:
|
||||
target_node_list.remove(element)
|
||||
|
||||
def _build_cost_graph(self):
|
||||
        '''
        This method generates edge_cost for each adjacent node pair. Additionally, the 'parents' and 'children'
        attributes are set on each node.
        '''
|
||||
self.edge_costs = {}
|
||||
for dst_node, strategies_vector in zip(self.nodes,
|
||||
self.leaf_strategies):
|
||||
# build edge_cost
|
||||
for src_node in dst_node.predecessor_nodes:
|
||||
if src_node is None:
|
||||
continue
|
||||
if src_node not in self.nodes:
|
||||
continue
|
||||
node_pair = (src_node, dst_node)
|
||||
edge_cost = {}
|
||||
for i in range(len(strategies_vector)):
|
||||
for j in range(len(self.node_strategies_vector[src_node])):
|
||||
resharding_cost = strategies_vector[i].resharding_costs[
|
||||
src_node.node_name][j][-1]
|
||||
edge_cost[(j, i)] = resharding_cost
|
||||
self.edge_costs[node_pair] = edge_cost
|
||||
|
||||
def get_edge_cost(self, src_node, dst_node):
|
||||
return self.edge_costs[(src_node, dst_node)]
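# --- illustrative aside (not part of the original file) ---------------------
# A minimal sketch of the edge-cost linearization described in the CostGraph
# docstring: the quadratic table of resharding costs for a (src, dst) node
# pair is flattened src-major into a 1D list, so the solver can pair it with
# a 1D vector of binary edge variables. The cost values are made up.
num_src, num_dst = 3, 2
edge_cost = {(i, j): 10.0 * i + j for i in range(num_src) for j in range(num_dst)}

flat = [edge_cost[(i, j)] for i in range(num_src) for j in range(num_dst)]

def flat_index(src_strategy: int, dst_strategy: int) -> int:
    # same src-major layout used when the costs are handed to the ILP
    return src_strategy * num_dst + dst_strategy

assert flat[flat_index(2, 1)] == edge_cost[(2, 1)]
# ---------------------------------------------------------------------------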
|
||||
|
||||
|
||||
class Solver:
|
||||
INFINITY_COST = 1e13
|
||||
|
||||
def __init__(self,
|
||||
cost_graph: CostGraph,
|
||||
memory_budget: float = -1.0,
|
||||
solution_numbers: int = 1,
|
||||
memory_increasing_coefficient: float = 1.3,
|
||||
verbose=False):
        '''
        The Solver class integrates the information provided by the other components and uses an ILP solver to find
        an optimal combination of sharding strategies for the target computing graph.
        Argument:
            cost_graph: A graph data structure to simplify the edge cost graph.
            memory_budget: Memory constraint for the solution.
            solution_numbers: If solution_numbers is larger than one, the solver will produce a series of solutions based on different memory budgets.
            memory_increasing_coefficient: If solution_numbers is larger than one, this coefficient is used to generate the new memory budgets.
        '''
self.cost_graph = cost_graph
|
||||
self.leaf_strategies = cost_graph.leaf_strategies
|
||||
self.nodes = cost_graph.nodes
|
||||
self.memory_budget = memory_budget
|
||||
self.solution_numbers = solution_numbers
|
||||
if self.solution_numbers > 1:
|
||||
self.memory_increasing_coefficient = memory_increasing_coefficient
|
||||
else:
|
||||
self.memory_increasing_coefficient = 1
|
||||
        # Temporarily we use all nodes as the liveness list; the backward memory cost is counted together with
        # the forward memory cost in the node memory cost, and no activation checkpointing is used in this phase.
|
||||
# self.liveness_list = self.graph_analyser.liveness_analysis()
|
||||
self.liveness_list = self.nodes
|
||||
self.node_index_dict = self._generate_node_index_dict()
|
||||
# The last solution vector of auto sharding.
|
||||
self.last_s_val = None
|
||||
# The last objective value of the best ILP solution.
|
||||
self.last_objective = None
|
||||
self.verbose = verbose
|
||||
|
||||
def _generate_node_index_dict(self):
|
||||
node_index_dict = {}
|
||||
for index, node in enumerate(self.nodes):
|
||||
node_index_dict[node] = index
|
||||
return node_index_dict
|
||||
|
||||
def _prepare_data_for_solver(self):
|
||||
'''
|
||||
Extract information from components for solver.
|
||||
'''
|
||||
node_nums = len(self.leaf_strategies)
|
||||
memory_budget = self.memory_budget
|
||||
|
||||
# prepare strategies_len
|
||||
strategies_len = []
|
||||
for node in self.nodes:
|
||||
strategies_len.append(
|
||||
len(self.cost_graph.node_strategies_vector[node]))
|
||||
strategies_len = np.array(strategies_len)
|
||||
|
||||
# prepare edge_pairs and resharding costs
|
||||
edge_pairs = []
|
||||
resharding_costs = []
|
||||
edge_cost_level = []
|
||||
edge_resharding_weights = []
|
||||
for pairs, edge_cost in self.cost_graph.edge_costs.items():
|
||||
src_node = pairs[0]
|
||||
dst_node = pairs[1]
|
||||
src_node_index = self.node_index_dict[src_node]
|
||||
dst_node_index = self.node_index_dict[dst_node]
|
||||
edge_pairs.append(src_node_index)
|
||||
edge_pairs.append(dst_node_index)
|
||||
edge_cost_level.append(
|
||||
(dst_node.building_block_id, dst_node.cost_level))
|
||||
for i in range(strategies_len[src_node_index]):
|
||||
for j in range(strategies_len[dst_node_index]):
|
||||
resharding_costs.append(edge_cost[(i, j)])
|
||||
edge_resharding_weights.append(dst_node.resharding_weight +
|
||||
dst_node.pipeline_weight)
|
||||
edge_pairs = np.array(edge_pairs)
|
||||
resharding_costs = np.array(resharding_costs)
|
||||
edge_resharding_weights = np.array(edge_resharding_weights)
|
||||
# prepare compute_costs, communication_costs and memory_costs
|
||||
compute_costs = []
|
||||
communication_costs = []
|
||||
memory_costs = []
|
||||
peak_act_memory_costs, constant_memory_costs = [], []
|
||||
node_sharding_weights = []
|
||||
for node, strategies_vector in zip(self.nodes, self.leaf_strategies):
|
||||
for index, strategy in enumerate(strategies_vector):
|
||||
compute_cost = strategy.sharding_cost
|
||||
origin_communication_cost = strategy.communication_cost
|
||||
memory_cost = strategy.const_memory_footprint * node.sharding_weight
|
||||
peak_act_memory = strategy.peak_memory_footprint
|
||||
# extract the memory cost in float from MemoryCost item and sum them up
|
||||
compute_costs.append(compute_cost)
|
||||
                # a node in extra_node_costs means it has some extra communication
                # cost from node merging, so that extra communication cost would
                # also need to be added into the objective
|
||||
|
||||
communication_costs.append(origin_communication_cost)
|
||||
peak_act_memory_costs.append(peak_act_memory)
|
||||
constant_memory_costs.append(memory_cost)
|
||||
node_sharding_weights.append(node.sharding_weight +
|
||||
node.pipeline_weight)
|
||||
|
||||
compute_costs = np.array(compute_costs)
|
||||
communication_costs = np.array(communication_costs)
|
||||
memory_costs = np.array([constant_memory_costs, peak_act_memory_costs])
|
||||
node_sharding_weights = np.array(node_sharding_weights)
|
||||
same_spec_nodes_dict = defaultdict(list)
|
||||
node_cost_level = []
|
||||
for idx, node in enumerate(self.nodes):
|
||||
if node.same_spec_id >= 0:
|
||||
same_spec_nodes_dict[node.same_spec_id].append(idx)
|
||||
node_cost_level.append((node.building_block_id, node.cost_level))
|
||||
# omit initial value for nodes
|
||||
s_init_np = None
|
||||
following_nodes = [-1 for i in range(node_nums)]
|
||||
liveness_set = self.nodes
|
||||
alias_set = []
|
||||
alias_convert_costs = None
|
||||
return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, node_sharding_weights, edge_resharding_weights, same_spec_nodes_dict, node_cost_level, edge_cost_level, alias_convert_costs, s_init_np, self.verbose
|
||||
|
||||
def _call_solver_serialized_args(self,
|
||||
node_nums,
|
||||
memory_budget,
|
||||
strategies_len,
|
||||
following_nodes,
|
||||
edge_pairs,
|
||||
alias_set,
|
||||
liveness_set,
|
||||
compute_costs,
|
||||
communication_costs,
|
||||
memory_costs,
|
||||
resharding_costs,
|
||||
node_sharding_weights,
|
||||
edge_resharding_weights,
|
||||
same_spec_nodes_dict,
|
||||
node_cost_level,
|
||||
edge_cost_level,
|
||||
alias_convert_costs,
|
||||
s_init_np=None,
|
||||
verbose=True):
|
||||
"""
|
||||
Call the solver with serialized arguments.
|
||||
"""
|
||||
|
||||
time.time()
|
||||
|
||||
for x in [
|
||||
strategies_len, edge_pairs, compute_costs, communication_costs,
|
||||
memory_costs, resharding_costs, node_sharding_weights,
|
||||
edge_resharding_weights
|
||||
]:
|
||||
assert isinstance(x, np.ndarray)
|
||||
assert len(strategies_len) == node_nums, "strategies_len"
|
||||
|
||||
def get_non_zero_index(binary_vector):
|
||||
"""
|
||||
Get the index of non-zero item in a vector.
|
||||
"""
|
||||
ct = 0
|
||||
ret = None
|
||||
for i, elem in enumerate(binary_vector):
|
||||
if pulp.value(elem):
|
||||
ret = i
|
||||
ct += 1
|
||||
|
||||
assert ct == 1
|
||||
return ret
|
||||
|
||||
# 0. Unpack flatten numpy arrays
|
||||
s_follow = following_nodes
|
||||
s_alias = alias_set
|
||||
|
||||
E = edge_pairs.reshape((-1, 2)) # noqa
|
||||
r = []
|
||||
pt = 0
|
||||
edge_set = set()
|
||||
for (i, j) in E:
|
||||
prod_length = strategies_len[i] * strategies_len[j]
|
||||
|
||||
if (i, j) in edge_set:
|
||||
raise ValueError(f"Duplicated edges: {(i, j)}")
|
||||
|
||||
edge_set.add((i, j))
|
||||
r.append(resharding_costs[pt:pt + prod_length])
|
||||
pt += prod_length
|
||||
assert pt == len(resharding_costs)
|
||||
|
||||
######################
|
||||
# omit alias set now #
|
||||
######################
|
||||
|
||||
# A = alias_set.reshape((-1, 2)) # noqa
|
||||
# for (i, j) in A:
|
||||
# prod_length = strategies_len[i] * strategies_len[j]
|
||||
# v.append(alias_convert_costs[pt:pt + prod_length])
|
||||
# pt += prod_length
|
||||
# assert pt == len(alias_convert_costs)
|
||||
|
||||
# L = [] # noqa
|
||||
# pt = node_nums
|
||||
# for i in range(node_nums):
|
||||
# length = liveness_set[i]
|
||||
# L.append(liveness_set[pt:pt + length])
|
||||
# pt += length
|
||||
# assert pt == len(liveness_set)
|
||||
pt = 0
|
||||
|
||||
c = []
|
||||
d = []
|
||||
m = []
|
||||
peak_m = []
|
||||
pt = 0
|
||||
for i in range(node_nums):
|
||||
length = strategies_len[i]
|
||||
c.append(compute_costs[pt:pt + length])
|
||||
d.append(communication_costs[pt:pt + length])
|
||||
m.append(memory_costs[0][pt:pt + length])
|
||||
peak_m.append(memory_costs[1][pt:pt + length])
|
||||
pt += length
|
||||
assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
|
||||
assert pt == len(
|
||||
communication_costs), f"{pt} == {len(communication_costs)}"
|
||||
assert pt == len(memory_costs[0]), f"{pt} == {len(memory_costs[0])}"
|
||||
|
||||
# 1. Create variables
|
||||
|
||||
#############################
|
||||
# create variables for node #
|
||||
#############################
|
||||
s = []
|
||||
num_nodes = 0
|
||||
reverse_follow_backpatch = []
|
||||
for i in range(node_nums):
|
||||
if s_follow[i] < 0:
|
||||
if strategies_len[i] == 1:
|
||||
s.append([1])
|
||||
else:
|
||||
if i not in s_alias:
|
||||
num_nodes += 1
|
||||
s.append(
|
||||
LpVariable.matrix(f"s[{i}]",
|
||||
(range(strategies_len[i]), ),
|
||||
cat="Binary"))
|
||||
else:
|
||||
s.append(s[s_alias[i]])
|
||||
else:
|
||||
if s_follow[i] < len(s):
|
||||
s.append(s[s_follow[i]])
|
||||
else:
|
||||
s.append(None)
|
||||
reverse_follow_backpatch.append(i)
|
||||
|
||||
for i in reverse_follow_backpatch:
|
||||
s[i] = s[s_follow[i]]
|
||||
|
||||
#############################
|
||||
# create variables for edge #
|
||||
#############################
|
||||
e = []
|
||||
num_edges = 0
|
||||
map_edge_to_idx = {}
|
||||
for (idx, (i, j)) in enumerate(E):
|
||||
if len(s[i]) == 1:
|
||||
e.append(s[j])
|
||||
elif len(s[j]) == 1:
|
||||
e.append(s[i])
|
||||
else:
|
||||
if i in s_alias and j in s_alias and (
|
||||
s_alias[i], s_alias[j]) in map_edge_to_idx:
|
||||
e.append(e[map_edge_to_idx[(s_alias[i], s_alias[j])]])
|
||||
else:
|
||||
num_edges += 1
|
||||
e.append(
|
||||
LpVariable.matrix(f"e[{i},{j}]",
|
||||
(range(len(s[i]) * len(s[j])), ),
|
||||
cat="Binary"))
|
||||
assert len(e[idx]) == len(r[idx])
|
||||
map_edge_to_idx[(i, j)] = idx
|
||||
for element in s:
|
||||
assert len(element) > 0
|
||||
# 2. Set initial value
|
||||
######################################
|
||||
        # set an initial value for warm start #
|
||||
######################################
|
||||
if s_init_np is not None:
|
||||
s_init = s_init_np.reshape((-1, 3))
|
||||
for (idx, value, fix) in s_init:
|
||||
for i in range(len(s[idx])):
|
||||
s[idx][i].setInitialValue(i == value)
|
||||
if fix:
|
||||
s[idx][i].fixValue()
|
||||
|
||||
# 3. Objective
|
||||
prob = LpProblem("myProblem", LpMinimize)
|
||||
###################################################################
|
||||
# computing the node cost(computing cost and communication cost) #
|
||||
###################################################################
|
||||
obj = 0
|
||||
block_cost_level_dict = {}
|
||||
for i in range(node_nums):
|
||||
assert len(s[i]) == len(c[i])
|
||||
assert len(s[i]) == len(d[i])
|
||||
obj += (lpDot(s[i], c[i]) +
|
||||
lpDot(s[i], d[i])) * node_sharding_weights[i]
|
||||
cost_level = node_cost_level[i]
|
||||
if -1 != cost_level[1]:
|
||||
if cost_level in block_cost_level_dict:
|
||||
block_cost_level_dict[cost_level] += lpDot(
|
||||
s[i], c[i]) + lpDot(s[i], d[i])
|
||||
else:
|
||||
block_cost_level_dict[cost_level] = lpDot(
|
||||
s[i], c[i]) + lpDot(s[i], d[i])
|
||||
|
||||
#############################################
|
||||
# computing the edge cost(resharding cost) #
|
||||
#############################################
|
||||
|
||||
for i in range(len(E)):
|
||||
assert len(e[i]) == len(r[i])
|
||||
obj += lpDot(e[i], r[i]) * edge_resharding_weights[i]
|
||||
cost_level = edge_cost_level[i]
|
||||
if -1 != cost_level[1]:
|
||||
if cost_level in block_cost_level_dict:
|
||||
block_cost_level_dict[cost_level] += lpDot(e[i], r[i])
|
||||
else:
|
||||
block_cost_level_dict[cost_level] = lpDot(e[i], r[i])
|
||||
prob += obj
|
||||
if len(block_cost_level_dict) >= 2:
|
||||
block_cost_levels = [key for key in block_cost_level_dict.keys()]
|
||||
for i in range(len(block_cost_levels)):
|
||||
for j in range(i + 1, len(block_cost_levels)):
|
||||
if block_cost_levels[i][1] > block_cost_levels[j][1]:
|
||||
prob += block_cost_level_dict[
|
||||
block_cost_levels[i]] >= block_cost_level_dict[
|
||||
block_cost_levels[j]] + 1e-6
|
||||
elif block_cost_levels[i][1] < block_cost_levels[j][1]:
|
||||
prob += block_cost_level_dict[
|
||||
block_cost_levels[j]] >= block_cost_level_dict[
|
||||
block_cost_levels[i]] + 1e-6
|
||||
# 4. Constraints
|
||||
# (a). specified by `cat="Binary"`
|
||||
|
||||
# (b)
|
||||
#################################################
|
||||
# make sure each node only choose one strategy #
|
||||
#################################################
|
||||
for i in range(node_nums):
|
||||
if s_follow[i] < 0:
|
||||
prob += lpSum(s[i]) == 1
|
||||
|
||||
# (c)
|
||||
#################################################
|
||||
        # force some nodes to have the same sharding specs #
|
||||
#################################################
|
||||
for spec_id, same_spec_nodes_id in same_spec_nodes_dict.items():
|
||||
num_same_spec_nodes = len(same_spec_nodes_id)
|
||||
if num_same_spec_nodes >= 2:
|
||||
src_node_s = s[same_spec_nodes_id[0]]
|
||||
num_specs = len(src_node_s)
|
||||
for i in range(1, num_same_spec_nodes):
|
||||
dst_node_s = s[same_spec_nodes_id[i]]
|
||||
assert len(
|
||||
dst_node_s
|
||||
                    ) == num_specs, f'unmatched num_specs when forcing node {same_spec_nodes_id[0]} and {same_spec_nodes_id[i]} to have the same specs'
|
||||
for j in range(num_specs):
|
||||
prob += (src_node_s[j] == dst_node_s[j])
|
||||
|
||||
# (c)
|
||||
#################################################
|
||||
# compute memory consumption with liveness set #
|
||||
#################################################
|
||||
if memory_budget > 0:
|
||||
# calculate the constant memory
|
||||
mem = 0
|
||||
for node in liveness_set:
|
||||
if node not in self.node_index_dict:
|
||||
continue
|
||||
node_index = self.node_index_dict[node]
|
||||
mem += lpSum(s[node_index][j] * m[node_index][j]
|
||||
for j in range(len(s[node_index])))
|
||||
# calculate the peak activation memory
|
||||
for node in liveness_set:
|
||||
if node not in self.node_index_dict:
|
||||
continue
|
||||
node_index = self.node_index_dict[node]
|
||||
cur_peak_mem = lpSum(s[node_index][j] * peak_m[node_index][j]
|
||||
for j in range(len(s[node_index])))
|
||||
total_mem = mem + cur_peak_mem
|
||||
prob += total_mem <= memory_budget
|
||||
|
||||
# (d). specified by `cat="Binary"`
|
||||
|
||||
for (idx, (i, j)) in enumerate(E):
|
||||
if strategies_len[i] == 1 or strategies_len[j] == 1:
|
||||
continue
|
||||
|
||||
# (e)
|
||||
prob += lpSum(e[idx]) == 1
|
||||
|
||||
# (f)
|
||||
for row in range(len(s[i])):
|
||||
C = len(s[j]) # noqa
|
||||
prob += lpSum(e[idx][row * C + col]
|
||||
for col in range(0, C)) <= s[i][row]
|
||||
|
||||
# (g)
|
||||
for col in range(len(s[j])):
|
||||
R = len(s[i]) # noqa
|
||||
C = len(s[j]) # noqa
|
||||
prob += lpSum(e[idx][row * C + col]
|
||||
for row in range(0, R)) <= s[j][col]
|
||||
|
||||
if prob.objective.isNumericalConstant():
|
||||
objective = float(pulp.value(prob.objective))
|
||||
status = pulp.LpStatusOptimal
|
||||
else:
|
||||
msg = verbose
|
||||
time_limit = 600
|
||||
solver = pulp.PULP_CBC_CMD(
|
||||
mip=True,
|
||||
msg=msg,
|
||||
timeLimit=time_limit,
|
||||
threads=multiprocessing.cpu_count(),
|
||||
)
|
||||
prob.solve(solver)
|
||||
|
||||
status = prob.status
|
||||
objective = pulp.value(prob.objective)
|
||||
objective = float(
|
||||
objective) if objective is not None else self.INFINITY_COST
|
||||
|
||||
if prob.status in [pulp.LpStatusInfeasible]:
|
||||
objective = self.INFINITY_COST
|
||||
|
||||
# Get and check results
|
||||
s_val = np.full((node_nums, ), -1, dtype=np.int32)
|
||||
for i in range(node_nums):
|
||||
s_val[i] = get_non_zero_index(s[i])
|
||||
|
||||
e_val = np.full((len(E), ), -1, dtype=np.int32)
|
||||
for (idx, (i, j)) in enumerate(E):
|
||||
e_val[idx] = get_non_zero_index(e[idx])
|
||||
i_spec_index = e_val[idx] // len(s[j])
|
||||
j_spec_index = e_val[idx] % len(s[j])
|
||||
assert i_spec_index == s_val[i], f"e_val[{i}][{j}]"
|
||||
assert j_spec_index == s_val[j], f"e_val[{i}][{j}]"
|
||||
if verbose and r[idx][e_val[idx]] > 0:
|
||||
print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")
|
||||
|
||||
self.last_s_val = list(s_val)
|
||||
# self._recover_merged_node_strategy()
|
||||
self.last_objective = objective
|
||||
|
||||
if objective >= self.INFINITY_COST:
|
||||
            warnings.warn(
                f"Cannot find an optimized solution given memory budget {self.memory_budget}. Please consider:\n" + \
                "1. increasing the memory budget if possible\n" + \
                "2. enlarging the mesh shape if possible\n" + \
                "3. decreasing the maximum build parameters (e.g., max_batch_size, max_seq_len)")
|
||||
if memory_budget > 0:
|
||||
# calculate the constant memory
|
||||
mem = 0
|
||||
for node in liveness_set:
|
||||
if node not in self.node_index_dict:
|
||||
continue
|
||||
node_index = self.node_index_dict[node]
|
||||
j = self.last_s_val[node_index]
|
||||
mem += m[node_index][j]
|
||||
max_peak_mem = 0
|
||||
for node in liveness_set:
|
||||
if node not in self.node_index_dict:
|
||||
continue
|
||||
node_index = self.node_index_dict[node]
|
||||
j = self.last_s_val[node_index]
|
||||
cur_peak_mem = peak_m[node_index][j]
|
||||
max_peak_mem = max(max_peak_mem, cur_peak_mem)
|
||||
logger.debug(
|
||||
f'constant_mem = {mem}, peak_mem = {max_peak_mem}, memory_budget = {memory_budget}'
|
||||
)
|
||||
|
||||
solution = Solution(self.leaf_strategies, self.last_s_val, e_val,
|
||||
edge_pairs, self.node_index_dict,
|
||||
self.last_objective)
|
||||
return status, solution
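# --- illustrative aside (not part of the original file) ---------------------
# A minimal, self-contained sketch of the ILP structure built above for a
# single edge: one-hot strategy variables per node, a one-hot edge variable
# per strategy pair, and the row/column constraints (f) and (g) that tie the
# edge choice to the two node choices. Costs are made up.
import pulp
from pulp import LpMinimize, LpProblem, LpVariable, lpDot, lpSum

node_cost = [[1.0, 4.0], [3.0, 1.0]]   # node_cost[n][k]: cost of strategy k on node n
edge_cost = [5.0, 0.0, 0.0, 5.0]       # flattened (R=2) x (C=2) resharding costs, row-major

prob = LpProblem("tiny_sharding_ilp", LpMinimize)
s0 = LpVariable.matrix("s0", (range(2), ), cat="Binary")
s1 = LpVariable.matrix("s1", (range(2), ), cat="Binary")
e = LpVariable.matrix("e", (range(4), ), cat="Binary")

prob += lpDot(s0, node_cost[0]) + lpDot(s1, node_cost[1]) + lpDot(e, edge_cost)
prob += lpSum(s0) == 1                   # (b) one strategy per node
prob += lpSum(s1) == 1
prob += lpSum(e) == 1                    # (e) one strategy pair per edge
R, C = 2, 2
for row in range(R):                     # (f) edge row agrees with the src choice
    prob += lpSum(e[row * C + col] for col in range(C)) <= s0[row]
for col in range(C):                     # (g) edge column agrees with the dst choice
    prob += lpSum(e[row * C + col] for row in range(R)) <= s1[col]

prob.solve(pulp.PULP_CBC_CMD(msg=False))
print([pulp.value(v) for v in s0], [pulp.value(v) for v in s1], pulp.value(prob.objective))
# ---------------------------------------------------------------------------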
|
||||
|
||||
def find_solution(self):
|
||||
"""
|
||||
Call the solver with serialized arguments and handle python errors. Additionally,
|
||||
we could give a serious of solutions with different memory budget.
|
||||
"""
if self.solution_numbers == 1:
|
||||
args = self._prepare_data_for_solver()
|
||||
ret = self._call_solver_serialized_args(*args)
|
||||
|
||||
return ret
|
||||
|
||||
origin_memory_budget = self.memory_budget
|
||||
memory_budget_list = [
|
||||
origin_memory_budget * self.memory_increasing_coefficient**i
|
||||
for i in range(self.solution_numbers)
|
||||
]
|
||||
ret_list = []
|
||||
for memory_budget in memory_budget_list:
|
||||
self.memory_budget = memory_budget
|
||||
args = self._prepare_data_for_solver()
|
||||
ret = self._call_solver_serialized_args(*args)
|
||||
ret_list.append(ret)
|
||||
|
||||
return ret_list
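# --- illustrative aside (not part of the original file) ---------------------
# A minimal sketch of the budget schedule used when solution_numbers > 1:
# each retry multiplies the original memory budget by the increasing
# coefficient, so later solutions trade memory for lower cost. Values are
# made up.
origin_budget = 16 * 2**30            # 16 GiB, hypothetical
coefficient = 1.3
solution_numbers = 3
budgets = [origin_budget * coefficient**i for i in range(solution_numbers)]
print([f"{b / 2**30:.1f} GiB" for b in budgets])   # ['16.0 GiB', '20.8 GiB', '27.0 GiB']
# ---------------------------------------------------------------------------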
|
||||
@ -1,41 +0,0 @@
|
||||
import copy
|
||||
|
||||
from .node import Node
|
||||
from .sharding_strategy import StrategiesVector
|
||||
|
||||
|
||||
class Activation(Node):
|
||||
|
||||
def _collect_strategies(self, device_mesh):
|
||||
dim_partition_list = []
|
||||
dim_size = len(self.op_data['input0'].shape)
|
||||
dim_partition_list.append({})
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
|
||||
strategies_vector = StrategiesVector(self)
|
||||
for dim_partition_dict in dim_partition_list:
|
||||
in0_partition_dict = dim_partition_dict
|
||||
out_partition_dict = copy.deepcopy(dim_partition_dict)
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": in0_partition_dict,
|
||||
"output0": out_partition_dict,
|
||||
}
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if 0 == len(sharding_spec_mapping):
|
||||
continue
|
||||
name = '{} = <activation op> {}'.format(
|
||||
sharding_spec_mapping['output0'].sharding_sequence,
|
||||
sharding_spec_mapping['input0'].sharding_sequence)
|
||||
sharding_strategy = self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
strategies_vector.append(sharding_strategy)
|
||||
return strategies_vector
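# --- illustrative aside (not part of the original file) ---------------------
# A minimal sketch of the candidate enumeration pattern repeated by the node
# classes above, assuming the helpers map tensor dimensions to device-mesh
# axes: 1d sharding places one mesh axis (or both, fused) on a single tensor
# dimension, and 2d sharding places the two mesh axes on two different tensor
# dimensions. This mirrors the intent, not the helpers' exact code.
from itertools import permutations

def enumerate_1d(mesh_dims, dim_size):
    return [{d: list(mesh_dims)} for d in range(dim_size)]

def enumerate_2d(mesh_dim_a, mesh_dim_b, dim_size):
    return [{d0: list(mesh_dim_a), d1: list(mesh_dim_b)}
            for d0, d1 in permutations(range(dim_size), 2)]

dim_size = 2                      # rank-2 activation tensor
candidates = [{}]                 # fully replicated
candidates += enumerate_1d([0], dim_size)
candidates += enumerate_1d([1], dim_size)
candidates += enumerate_1d([0, 1], dim_size)
candidates += enumerate_2d([0], [1], dim_size)
print(candidates)
# e.g. {0: [0, 1]} shards tensor dim 0 across both mesh axes,
# and {0: [0], 1: [1]} shards dim 0 on mesh axis 0 and dim 1 on mesh axis 1.
# ---------------------------------------------------------------------------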
|
||||
@ -1,34 +0,0 @@
|
||||
import copy
|
||||
|
||||
from .node import Node
|
||||
from .sharding_strategy import StrategiesVector
|
||||
|
||||
|
||||
class Assertion(Node):
|
||||
|
||||
def _collect_strategies(self, device_mesh):
|
||||
        predecessor = self.predecessor_nodes[0]  # the assertion node has a single input
|
||||
strategies_vector = StrategiesVector(self)
|
||||
for idx, strategy in enumerate(predecessor.strategies_vector):
|
||||
global_input_name = self.op_data[
|
||||
'input0'].name # current node's local name input0 -> global name xxx
|
||||
prenode_local_name = predecessor.global_to_local_op_name[
|
||||
global_input_name] # global name xxx -> pre node local output name
|
||||
dim_partition_dict = copy.deepcopy(
|
||||
strategy.sharding_specs[prenode_local_name].dim_partition_dict)
|
||||
in0_partition_dict = dim_partition_dict
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": in0_partition_dict,
|
||||
}
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if 0 == len(sharding_spec_mapping):
|
||||
return strategies_vector
|
||||
name = '<assertion> {}'.format(
|
||||
sharding_spec_mapping['input0'].sharding_sequence)
|
||||
sharding_strategy = self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
strategies_vector.append(sharding_strategy)
|
||||
return strategies_vector
|
||||
@ -1,45 +0,0 @@
|
||||
import copy
|
||||
|
||||
from .node import Node
|
||||
from .sharding_strategy import StrategiesVector
|
||||
|
||||
|
||||
class Cast(Node):
|
||||
|
||||
def _collect_strategies(self, device_mesh):
|
||||
dim_partition_list = []
|
||||
dim_size = len(self.op_data['input0'].shape)
|
||||
dim_partition_list.append({})
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
|
||||
strategies_vector = StrategiesVector(self)
|
||||
for dim_partition_dict in dim_partition_list:
|
||||
in0_partition_dict = dim_partition_dict
|
||||
out_partition_dict = copy.deepcopy(dim_partition_dict)
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": in0_partition_dict,
|
||||
"output0": out_partition_dict,
|
||||
}
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if 0 == len(sharding_spec_mapping):
|
||||
continue
|
||||
name = '{} = <cast op> {}'.format(
|
||||
sharding_spec_mapping['output0'].sharding_sequence,
|
||||
sharding_spec_mapping['input0'].sharding_sequence)
|
||||
sharding_strategy = self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
strategies_vector.append(sharding_strategy)
|
||||
|
||||
return strategies_vector
|
||||
|
||||
def _update_memory_cost(self, strategies):
|
||||
pass
|
||||
@ -1,58 +0,0 @@
|
||||
__all__ = [
|
||||
'CommSpec',
|
||||
]
|
||||
|
||||
|
||||
class CommSpec:
|
||||
|
||||
def __init__(self,
|
||||
comm_pattern,
|
||||
sharding_spec,
|
||||
gather_dim=None,
|
||||
shard_dim=None,
|
||||
logical_process_axis=None,
|
||||
mix_gather=False,
|
||||
forward_only=True):
|
||||
self.comm_pattern = comm_pattern
|
||||
self.sharding_spec = sharding_spec
|
||||
self.gather_dim = gather_dim
|
||||
self.shard_dim = shard_dim
|
||||
self.logical_process_axis = logical_process_axis
|
||||
self.device_mesh = self.sharding_spec.device_mesh
|
||||
self.mix_gather = mix_gather
|
||||
self.forward_only = forward_only
|
||||
if self.gather_dim:
|
||||
assert len(self.gather_dim) == len(
|
||||
self.logical_process_axis
|
||||
), f'unmatched gather dim {self.gather_dim} and logical process axis {self.logical_process_axis}'
|
||||
if self.shard_dim:
|
||||
assert len(self.shard_dim) == len(
|
||||
self.logical_process_axis
|
||||
), f'unmatched shard dim {self.shard_dim} and logical process axis {self.logical_process_axis}'
|
||||
if self.gather_dim and self.shard_dim:
|
||||
assert len(self.shard_dim) == len(
|
||||
self.gather_dim
|
||||
), f'unmatched gather dim {self.gather_dim} and shard dim {self.shard_dim}'
|
||||
|
||||
def get_comm_cost(self):
        '''
        For the all_gather, all2all, and all_reduce operations, the formula provided by DeviceMesh's alpha-beta model
        is used to compute the communication cost.
        A shard operation is performed on-chip, so its communication cost is zero.
        '''
comm_size = self.sharding_spec.get_sharded_size_per_device()
|
||||
dtype = self.sharding_spec.dtype
|
||||
|
||||
# reduce list_of_list to list
|
||||
comm_dims = sum(self.logical_process_axis, [])
|
||||
comm_cost = self.device_mesh.estimate_comm_cost(self.comm_pattern,
|
||||
comm_dims, comm_size,
|
||||
dtype)
|
||||
return comm_cost
|
||||
|
||||
def get_mem_cost(self):
|
||||
return self.device_mesh.shape_consistency_manager.mem_cost([self])
|
||||
|
||||
def get_max_mem_cost(self):
|
||||
return self.device_mesh.shape_consistency_manager.mem_cost(
|
||||
[self], mem_pattern='max')
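# --- illustrative aside (not part of the original file) ---------------------
# A minimal sketch of an alpha-beta communication cost in the spirit of the
# docstring above: latency (alpha) plus message volume divided by bandwidth
# (beta), with the usual ring all-reduce volume factor. This is a generic
# textbook model, not necessarily what DeviceMesh.estimate_comm_cost computes.
def alpha_beta_all_reduce_cost(message_bytes: float, num_devices: int,
                               alpha_s: float = 1e-5,
                               beta_s_per_byte: float = 1 / 300e9) -> float:
    volume = 2 * (num_devices - 1) / num_devices * message_bytes
    return alpha_s + volume * beta_s_per_byte

# hypothetical 64 MiB tensor all-reduced across 8 GPUs at ~300 GB/s
print(f"{alpha_beta_all_reduce_cost(64 * 2**20, 8) * 1e3:.3f} ms")
# ---------------------------------------------------------------------------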
|
||||
@ -1,56 +0,0 @@
|
||||
import copy
|
||||
|
||||
from .node import Node
|
||||
from .sharding_strategy import StrategiesVector
|
||||
|
||||
|
||||
class Concatenation(Node):
|
||||
|
||||
def __init__(self, layer):
|
||||
super().__init__(layer)
|
||||
layer.to_subclass()
|
||||
batch_dims = [i for i in range(len(self.get_output(0).shape))]
|
||||
self.axis = layer.as_trt().axis
|
||||
batch_dims.remove(self.axis)
|
||||
self._generate_bcast_dims(batch_dims, self.get_output(0).shape)
|
||||
layer.to_base_class()
|
||||
|
||||
def _collect_strategies(self, device_mesh):
|
||||
dim_partition_list = []
|
||||
dim_size = len(self.op_data['output0'].shape)
|
||||
dim_partition_list.append({})
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
|
||||
|
||||
dim_partition_dict_mapping = {}
|
||||
strategies_vector = StrategiesVector(self)
|
||||
for dim_partition_dict in dim_partition_list:
|
||||
if self.axis in dim_partition_dict:
|
||||
dim_partition_dict.pop(self.axis)
|
||||
for idx in range(self.num_inputs):
|
||||
in_partition_dict = copy.deepcopy(dim_partition_dict)
|
||||
dim_partition_dict_mapping[f'input{idx}'] = in_partition_dict
|
||||
out_partition_dict = dim_partition_dict
|
||||
dim_partition_dict_mapping['output0'] = out_partition_dict
|
||||
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if 0 == len(sharding_spec_mapping):
|
||||
continue
|
||||
name = '{} = <concate along dim {}> {}'.format(
|
||||
sharding_spec_mapping['output0'].sharding_sequence, self.axis, [
|
||||
sharding_spec_mapping[f'input{idx}'].sharding_sequence
|
||||
for idx in range(self.num_inputs)
|
||||
])
|
||||
sharding_strategy = self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
strategies_vector.append(sharding_strategy)
|
||||
return strategies_vector
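# --- illustrative aside (not part of the original file) ---------------------
# A minimal sketch of the filtering done above: any candidate partition that
# would shard the concatenation axis is stripped of that entry, so inputs and
# the output are only ever sharded along non-concatenated dimensions. The
# axis and candidate list below are made up.
axis = 1                                   # hypothetical concat axis
candidates = [{}, {0: [0]}, {1: [0]}, {0: [0], 1: [1]}]
filtered = []
for dim_partition in candidates:
    dim_partition = dict(dim_partition)
    dim_partition.pop(axis, None)          # never shard the concat axis
    filtered.append(dim_partition)
print(filtered)                            # [{}, {0: [0]}, {}, {0: [0]}]
# ---------------------------------------------------------------------------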
|
||||
@ -1,45 +0,0 @@
|
||||
from .node import Node
|
||||
from .sharding_strategy import StrategiesVector
|
||||
|
||||
|
||||
class Constant(Node):
|
||||
|
||||
def _update_memory_cost(self, strategies):
|
||||
super()._update_memory_cost(strategies)
|
||||
for strategy in strategies:
|
||||
strategy.inout_memory_footprint = 0.0
|
||||
strategy.peak_memory_footprint = 0.0
|
||||
strategy.const_memory_footprint = strategy.sharding_specs[
|
||||
'output0'].get_max_sharded_size_per_device()
|
||||
|
||||
def _collect_strategies(self, device_mesh):
|
||||
dim_partition_list = []
|
||||
dim_size = len(self.op_data['output0'].shape)
|
||||
dim_partition_list.append({})
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([1], dim_size))
|
||||
|
||||
strategies_vector = StrategiesVector(self)
|
||||
for dim_partition_dict in dim_partition_list:
|
||||
dim_partition_dict_mapping = {'output0': dim_partition_dict}
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if 0 == len(sharding_spec_mapping):
|
||||
continue
|
||||
sharding_seq = sharding_spec_mapping['output0'].sharding_sequence
|
||||
sharding_strategy = self._get_sharding_strategy(
|
||||
name=f'constant-op {sharding_seq}',
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
strategies_vector.append(sharding_strategy)
|
||||
|
||||
return strategies_vector
|
||||
|
||||
def _profile_sharding_cost(self, strategy, device_mesh):
|
||||
return 0.0
|
||||
@ -1,49 +0,0 @@
|
||||
from .node import Node
|
||||
from .sharding_strategy import StrategiesVector
|
||||
|
||||
|
||||
class ElementWise(Node):
|
||||
|
||||
def __init__(self, layer):
|
||||
super().__init__(layer)
|
||||
batch_dims = [i for i in range(len(self.get_output(0).shape))]
|
||||
self._generate_bcast_dims(batch_dims, self.get_output(0).shape)
|
||||
|
||||
def _collect_strategies(self, device_mesh):
|
||||
dim_partition_list = []
|
||||
dim_size = len(self.op_data['output0'].shape)
|
||||
dim_partition_list.append({})
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
|
||||
strategies_vector = StrategiesVector(self)
|
||||
for dim_partition_dict in dim_partition_list:
|
||||
in0_partition_dict = self._recover_bcast_partition_dict(
|
||||
dim_partition_dict, self.op_data['input0'])
|
||||
in1_partition_dict = self._recover_bcast_partition_dict(
|
||||
dim_partition_dict, self.op_data['input1'])
|
||||
out_partition_dict = dim_partition_dict
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": in0_partition_dict,
|
||||
"input1": in1_partition_dict,
|
||||
"output0": out_partition_dict,
|
||||
}
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if 0 == len(sharding_spec_mapping):
|
||||
continue
|
||||
name = '{} = {} <elementwise> {}'.format(
|
||||
sharding_spec_mapping['output0'].sharding_sequence,
|
||||
sharding_spec_mapping['input0'].sharding_sequence,
|
||||
sharding_spec_mapping['input1'].sharding_sequence)
|
||||
sharding_strategy = self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
strategies_vector.append(sharding_strategy)
|
||||
return strategies_vector
|
||||
@ -1,59 +0,0 @@
|
||||
import tensorrt as trt
|
||||
|
||||
from .node import Node
|
||||
from .sharding_strategy import StrategiesVector
|
||||
|
||||
|
||||
class Fill(Node):
|
||||
|
||||
def __init__(self, layer):
|
||||
super().__init__(layer)
|
||||
layer.to_subclass()
|
||||
self.operation = layer.as_trt().operation
|
||||
layer.to_base_class()
|
||||
|
||||
def _collect_strategies(self, device_mesh):
|
||||
dim_partition_list = []
|
||||
dim_size = len(self.op_data['output0'].shape)
|
||||
dim_partition_list.append({})
|
||||
if self.num_inputs == 0 and self.operation != trt.FillOperation.LINSPACE:
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
|
||||
|
||||
strategies_vector = StrategiesVector(self)
|
||||
for dim_partition_dict in dim_partition_list:
|
||||
dim_partition_dict_mapping = {'output0': dim_partition_dict}
|
||||
for i in range(self.num_inputs):
|
||||
dim_partition_dict_mapping[f'input{i}'] = {}
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if 0 == len(sharding_spec_mapping):
|
||||
continue
|
||||
sharding_seq = sharding_spec_mapping['output0'].sharding_sequence
|
||||
sharding_strategy = self._get_sharding_strategy(
|
||||
name=f'fill-op {sharding_seq}',
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
strategies_vector.append(sharding_strategy)
|
||||
|
||||
return strategies_vector
|
||||
|
||||
def _profile_sharding_cost(self, strategy, device_mesh):
|
||||
updated_layer_attrs = {}
|
||||
updated_input_values = {}
|
||||
shape = strategy.sharding_specs['output0'].get_sharded_shape_per_device(
|
||||
)
|
||||
if self.layer.num_inputs >= 1:
|
||||
updated_input_values[0] = shape
|
||||
else:
|
||||
updated_layer_attrs['shape'] = shape
|
||||
elapsed_time = self.node_runtime_profiler.runtime_profile(
|
||||
self.layer, updated_layer_attrs, updated_input_values, strategy,
|
||||
device_mesh)
|
||||
return elapsed_time
|
||||
@ -1,196 +0,0 @@
|
||||
import copy
|
||||
|
||||
import tensorrt as trt
|
||||
|
||||
from .comm_spec import CommSpec
|
||||
from .node import Node
|
||||
from .sharding_spec import DimSpec
|
||||
from .sharding_strategy import StrategiesVector
|
||||
|
||||
|
||||
class Gather(Node):
|
||||
|
||||
def __init__(self, layer):
|
||||
super().__init__(layer)
|
||||
layer.to_subclass()
|
||||
self.mode = layer.as_trt().mode
|
||||
self.axis = layer.as_trt().axis
|
||||
self.num_elementwise_dims = layer.as_trt().num_elementwise_dims
|
||||
self.input_id = 0
|
||||
self.indice_id = 1
|
||||
self.support_vocab_tp = False
|
||||
layer.to_base_class()
|
||||
|
||||
def _update_memory_cost(self, strategies):
|
||||
for strategy in strategies:
|
||||
            # for a gather node, input0's read volume equals output0's write volume
|
||||
inout_memory_footprint = (
|
||||
strategy.sharding_specs['output0'].get_sharded_size_per_device(
|
||||
) * 2 +
|
||||
strategy.sharding_specs['input1'].get_sharded_size_per_device())
|
||||
strategy.inout_memory_footprint = inout_memory_footprint
|
||||
strategy.peak_memory_footprint = (
|
||||
strategy.sharding_specs['output0'].
|
||||
get_max_sharded_size_per_device() + strategy.
|
||||
sharding_specs['input0'].get_max_sharded_size_per_device() +
|
||||
strategy.sharding_specs['input1'].
|
||||
get_max_sharded_size_per_device())
|
||||
|
||||
def _collect_strategies(self, device_mesh):
|
||||
if self.mode == trt.GatherMode.DEFAULT:
|
||||
return self._default_gather_strategies(device_mesh)
|
||||
elif self.mode == trt.GatherMode.ELEMENT:
|
||||
return self._element_gather_strategies(device_mesh)
|
||||
        elif self.mode == trt.GatherMode.ND:
            assert 0, 'unsupported gatherND'
        else:
            assert 0, f'unsupported gather mode {self.mode}'
|
||||
|
||||
def _element_gather_strategies(self, device_mesh):
|
||||
dim_partition_list = []
|
||||
dim_size = len(self.op_data['output0'].shape)
|
||||
dim_partition_list.append({})
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
|
||||
|
||||
strategies_vector = StrategiesVector(self)
|
||||
for dim_partition_dict in dim_partition_list:
|
||||
if self.axis in dim_partition_dict:
|
||||
dim_partition_dict.pop(self.axis)
|
||||
|
||||
dim_partition_dict_mapping = {
|
||||
'input0': dim_partition_dict,
|
||||
'input1': copy.deepcopy(dim_partition_dict),
|
||||
'output0': copy.deepcopy(dim_partition_dict),
|
||||
}
|
||||
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if 0 == len(sharding_spec_mapping):
|
||||
continue
|
||||
name = '{} = {} <element gather op axis {}> {}'.format(
|
||||
sharding_spec_mapping['output0'].sharding_sequence,
|
||||
sharding_spec_mapping['input0'].sharding_sequence, self.axis,
|
||||
sharding_spec_mapping['input1'].sharding_sequence)
|
||||
sharding_strategy = self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
strategies_vector.append(sharding_strategy)
|
||||
|
||||
return strategies_vector
|
||||
|
||||
    # for the plugin, the indices are input0 and the weight is input1, which differs from the gather node
|
||||
def _default_gather_strategies(self, device_mesh):
|
||||
|
||||
def add_sharding_strategy(dim_partition_dict_mapping,
|
||||
vocab_tp_dim=None):
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if len(sharding_spec_mapping) > 0:
|
||||
name = '{} = {} <default gather op axis {}, num_elementwise_dims {}> {}'.format(
|
||||
sharding_spec_mapping['output0'].sharding_sequence,
|
||||
sharding_spec_mapping['input0'].sharding_sequence,
|
||||
self.axis, self.num_elementwise_dims,
|
||||
sharding_spec_mapping['input1'].sharding_sequence)
|
||||
communication_action_mapping = {}
|
||||
if vocab_tp_dim is not None:
|
||||
name += f'_allreduce{DimSpec(vocab_tp_dim)}'
|
||||
output0_comm_action = CommSpec(
|
||||
comm_pattern='all_reduce',
|
||||
sharding_spec=sharding_spec_mapping['output0'],
|
||||
logical_process_axis=[vocab_tp_dim],
|
||||
)
|
||||
communication_action_mapping[
|
||||
'output0'] = output0_comm_action
|
||||
sharding_strategy = self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
strategies_vector.append(sharding_strategy)
|
||||
|
||||
input_id, indice_id = self.input_id, self.indice_id
|
||||
strategies_vector = StrategiesVector(self)
|
||||
input_size = len(self.op_data[f'input{input_id}'].shape)
|
||||
indice_size = len(self.op_data[f'input{indice_id}'].shape)
|
||||
output_dim = input_size + indice_size - 1 - self.num_elementwise_dims
|
||||
for strategy in self.predecessor_nodes[input_id].strategies_vector:
|
||||
# current node's local name input0 -> global name xxx
|
||||
global_input_name = self.op_data[f'input{input_id}'].name
|
||||
# global name xxx -> pre node local output name
|
||||
prenode_local_name = self.predecessor_nodes[
|
||||
input_id].global_to_local_op_name[global_input_name]
|
||||
input_dim_partition_dict = copy.deepcopy(
|
||||
strategy.sharding_specs[prenode_local_name].dim_partition_dict)
|
||||
|
||||
vocab_tp_dim = input_dim_partition_dict.pop(self.axis, None)
|
||||
|
||||
input_mesh_dims = []
|
||||
for dim, mesh_dims in input_dim_partition_dict.items():
|
||||
input_mesh_dims += mesh_dims
|
||||
input_mesh_dims = set(input_mesh_dims)
|
||||
|
||||
for idx_strategy in self.predecessor_nodes[
|
||||
indice_id].strategies_vector:
|
||||
# current node's local name input0 -> global name xxx
|
||||
global_indice_name = self.op_data[f'input{indice_id}'].name
|
||||
# global name xxx -> pre node local output name
|
||||
prenode_local_name = self.predecessor_nodes[
|
||||
indice_id].global_to_local_op_name[global_indice_name]
|
||||
indice_dim_partition_dict = copy.deepcopy(
|
||||
idx_strategy.sharding_specs[prenode_local_name].
|
||||
dim_partition_dict)
|
||||
|
||||
for dim, indice_mesh_dims in idx_strategy.sharding_specs[
|
||||
prenode_local_name].dim_partition_dict.items():
|
||||
for indice_mesh_dim in indice_mesh_dims:
|
||||
if indice_mesh_dim in input_mesh_dims:
|
||||
indice_dim_partition_dict.pop(dim)
|
||||
break
|
||||
|
||||
out_partition_dict = {}
|
||||
|
||||
for dim in range(output_dim):
|
||||
if dim < self.axis:
|
||||
if dim in input_dim_partition_dict:
|
||||
out_partition_dict[dim] = \
|
||||
input_dim_partition_dict[dim]
|
||||
elif dim >= self.axis and dim < self.axis + indice_size - self.num_elementwise_dims:
|
||||
indice_dim = dim - self.axis + self.num_elementwise_dims
|
||||
if indice_dim in indice_dim_partition_dict:
|
||||
out_partition_dict[dim] = \
|
||||
indice_dim_partition_dict[indice_dim]
|
||||
else:
|
||||
input_dim = dim - (indice_size -
|
||||
self.num_elementwise_dims) + 1
|
||||
if input_dim in input_dim_partition_dict:
|
||||
out_partition_dict[dim] = \
|
||||
input_dim_partition_dict[input_dim]
|
||||
|
||||
dim_partition_dict_mapping = {
|
||||
f"input{input_id}": input_dim_partition_dict,
|
||||
f"input{indice_id}": indice_dim_partition_dict,
|
||||
"output0": out_partition_dict,
|
||||
}
|
||||
add_sharding_strategy(dim_partition_dict_mapping)
|
||||
|
||||
if self.support_vocab_tp and vocab_tp_dim is not None:
|
||||
vocab_tp_dim_partition_dict = {
|
||||
**input_dim_partition_dict,
|
||||
self.axis: vocab_tp_dim,
|
||||
}
|
||||
dim_partition_dict_mapping = {
|
||||
f"input{input_id}": vocab_tp_dim_partition_dict,
|
||||
f"input{indice_id}": indice_dim_partition_dict,
|
||||
"output0": out_partition_dict,
|
||||
}
|
||||
add_sharding_strategy(dim_partition_dict_mapping,
|
||||
vocab_tp_dim)
|
||||
|
||||
return strategies_vector
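# --- illustrative aside (not part of the original file) ---------------------
# A minimal worked example of the output-dimension bookkeeping in
# _default_gather_strategies: each output dimension inherits its sharding
# either from an input dimension (before or after the gather axis) or from an
# indices dimension (the dimensions the axis is replaced with). The ranks
# below are made up.
def gather_output_dim_sources(input_rank, indices_rank, axis, num_elementwise_dims):
    output_rank = input_rank + indices_rank - 1 - num_elementwise_dims
    sources = {}
    for dim in range(output_rank):
        if dim < axis:
            sources[dim] = ("input", dim)
        elif dim < axis + indices_rank - num_elementwise_dims:
            sources[dim] = ("indices", dim - axis + num_elementwise_dims)
        else:
            sources[dim] = ("input", dim - (indices_rank - num_elementwise_dims) + 1)
    return sources

# gather along axis 1 of a rank-3 input with rank-2 indices, no elementwise dims
print(gather_output_dim_sources(input_rank=3, indices_rank=2, axis=1,
                                num_elementwise_dims=0))
# {0: ('input', 0), 1: ('indices', 0), 2: ('indices', 1), 3: ('input', 2)}
# ---------------------------------------------------------------------------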
|
||||
@@ -1,56 +0,0 @@
import copy

from .node import Node
from .sharding_strategy import StrategiesVector


class Identity(Node):

    def _update_memory_cost(self, strategies):
        if not self.is_fake:
            super()._update_memory_cost(strategies)
        else:
            # fake nodes for building block/PP connection
            pass

    def _collect_strategies(self, device_mesh):
        dim_partition_list = []
        dim_size = len(self.op_data['input0'].shape)
        dim_partition_list.append({})
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
        strategies_vector = StrategiesVector(self)
        # dim_partition_dict can be the same as the previous node's if solver time becomes a problem
        for dim_partition_dict in dim_partition_list:
            in0_partition_dict = dim_partition_dict
            out_partition_dict = copy.deepcopy(dim_partition_dict)
            dim_partition_dict_mapping = {
                "input0": in0_partition_dict,
                "output0": out_partition_dict,
            }
            sharding_spec_mapping = self._to_sharding_spec_mapping(
                dim_partition_dict_mapping, device_mesh)
            if 0 == len(sharding_spec_mapping):
                continue
            name = '{} = <identity op> {}'.format(
                sharding_spec_mapping['output0'].sharding_sequence,
                sharding_spec_mapping['input0'].sharding_sequence)
            sharding_strategy = self._get_sharding_strategy(
                name=name,
                sharding_spec_mapping=sharding_spec_mapping,
                communication_action_mapping={})
            strategies_vector.append(sharding_strategy)
        return strategies_vector

    def _profile_sharding_cost(self, strategy, device_mesh):
        # if same_spec_id is not -1, this identity node serves as a same-spec-id node
        if self.same_spec_id == -1:
            return super()._profile_sharding_cost(strategy, device_mesh)
        else:
            return 0.0
@@ -1,79 +0,0 @@
from .node import Node
from .sharding_strategy import StrategiesVector


class InputNode(Node):

    def _update_memory_cost(self, strategies):
        for strategy in strategies:
            if not self.no_memory_footprint:
                strategy.const_memory_footprint = strategy.sharding_specs[
                    'output0'].get_max_sharded_size_per_device()

    def __init__(self, tensor):
        self._layer = None
        self.is_shape_io = False
        self._inputs = []
        self._outputs = []
        self.predecessor_nodes = []
        self.predecessor_nodes_out_index = {}
        self.successor_nodes = []
        self.op_data = {}
        self.global_to_local_op_name = {}
        self.is_replicated = tensor.attrs.get("is_replicated", False)
        self.same_spec_id = tensor.attrs.get("same_spec_id", -1)
        self.no_memory_footprint = tensor.attrs.get("no_memory_footprint",
                                                    False)
        self.building_block_id = -1
        self.cost_level = -1
        self.stage_type = None
        self.in_start_block = None
        self.in_end_block = None
        self.in_slowest_block = None
        output = tensor.copy()
        self._outputs.append(output)
        self.op_data['output0'] = output
        self.global_to_local_op_name[output.name] = 'output0'

        self.sharding_weight = 1.0
        self.resharding_weight = 1.0
        self.pipeline_weight = 0
        self.node_name = tensor.name
        self.node_type = 'input_node'
        self.num_inputs = 0
        self.num_outputs = 1
        self.dtype = tensor.dtype
        self.strategies_vector = []
        self.node_runtime_profiler = None

    def _collect_strategies(self, device_mesh):
        dim_partition_list = []
        dim_size = len(self.op_data['output0'].shape)
        dim_partition_list.append({})
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_2d_sharding([0], [1], dim_size))

        strategies_vector = StrategiesVector(self)
        for dim_partition_dict in dim_partition_list:
            dim_partition_dict_mapping = {'output0': dim_partition_dict}
            sharding_spec_mapping = self._to_sharding_spec_mapping(
                dim_partition_dict_mapping, device_mesh)
            if 0 == len(sharding_spec_mapping):
                continue
            sharding_seq = sharding_spec_mapping['output0'].sharding_sequence
            sharding_strategy = self._get_sharding_strategy(
                name=f'input-op {sharding_seq}',
                sharding_spec_mapping=sharding_spec_mapping,
                communication_action_mapping={})
            strategies_vector.append(sharding_strategy)

        return strategies_vector

    def _profile_sharding_cost(self, strategy, device_mesh):
        return 0.0
@ -1,798 +0,0 @@
|
||||
import copy
|
||||
import operator
|
||||
from functools import reduce
|
||||
|
||||
import tensorrt as trt
|
||||
|
||||
from ..device_mesh import LogicalDeviceMesh
|
||||
from ..utils import get_builder_flags
|
||||
from .comm_spec import CommSpec
|
||||
from .node import Node
|
||||
from .sharding_spec import DimSpec
|
||||
from .sharding_strategy import StrategiesVector
|
||||
|
||||
|
||||
class MatrixMultiply(Node):
|
||||
|
||||
def __init__(self, layer):
|
||||
super().__init__(layer)
|
||||
layer.to_subclass()
|
||||
batch_dims = [i for i in range(len(self.get_output(0).shape))][:-2]
|
||||
self._generate_bcast_dims(batch_dims, self.get_output(0).shape)
|
||||
self.op0_transpose = layer.as_trt().op0 == trt.MatrixOperation.TRANSPOSE
|
||||
self.op1_transpose = layer.as_trt().op1 == trt.MatrixOperation.TRANSPOSE
|
||||
self.num_out_dims = len(self.get_output(0).shape)
|
||||
dtypes_str = [
|
||||
self.get_input(0).dtype_str,
|
||||
self.get_input(1).dtype_str,
|
||||
self.get_output(0).dtype_str
|
||||
]
|
||||
dtypes_size = [
|
||||
self.get_input(0).dtype_size,
|
||||
self.get_input(1).dtype_size,
|
||||
self.get_output(0).dtype_size
|
||||
]
|
||||
min_idx = dtypes_size.index(min(dtypes_size))
|
||||
self.dtype = dtypes_str[min_idx]
|
||||
layer.to_base_class()
|
||||
|
||||
def _split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1, device_mesh):
|
||||
in0_split_dim = -1 if self.op0_transpose else -2
|
||||
in1_split_dim = -2 if self.op1_transpose else -1
|
||||
name = (f'{DimSpec(mesh_dim_0)}{DimSpec(mesh_dim_1)} = '
|
||||
f'{DimSpec(mesh_dim_0)}R x R{DimSpec(mesh_dim_1)}')
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": {
|
||||
in0_split_dim: mesh_dim_0
|
||||
},
|
||||
"input1": {
|
||||
in1_split_dim: mesh_dim_1
|
||||
},
|
||||
"output0": {
|
||||
-2: mesh_dim_0,
|
||||
-1: mesh_dim_1
|
||||
},
|
||||
}
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if len(sharding_spec_mapping) == 0:
|
||||
return None
|
||||
strategy = self._get_sharding_strategy(name = name, \
|
||||
sharding_spec_mapping = sharding_spec_mapping, \
|
||||
communication_action_mapping = {})
|
||||
return strategy
|
||||
|
||||
def _split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1,
|
||||
device_mesh):
|
||||
# handle the case SR = SS x SR
|
||||
name = (
|
||||
f'{DimSpec(mesh_dim_0)}R = '
|
||||
f'{DimSpec(mesh_dim_0)}{DimSpec(mesh_dim_1)} x {DimSpec(mesh_dim_1)}R'
|
||||
f'_allreduce{DimSpec(mesh_dim_1)}')
|
||||
in0_split_dim = [-1, -2] if self.op0_transpose else [-2, -1]
|
||||
in1_split_dim = -1 if self.op1_transpose else -2
|
||||
# get sharding spec mapping
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": {
|
||||
in0_split_dim[0]: mesh_dim_0,
|
||||
in0_split_dim[1]: mesh_dim_1
|
||||
},
|
||||
"input1": {
|
||||
in1_split_dim: mesh_dim_1
|
||||
},
|
||||
"output0": {
|
||||
-2: mesh_dim_0
|
||||
},
|
||||
}
|
||||
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if len(sharding_spec_mapping) == 0:
|
||||
return None
|
||||
# get communication action mapping
|
||||
communication_action_mapping = {}
|
||||
output0_comm_action = CommSpec(
|
||||
comm_pattern='all_reduce',
|
||||
sharding_spec=sharding_spec_mapping['output0'],
|
||||
logical_process_axis=[mesh_dim_1],
|
||||
)
|
||||
communication_action_mapping['output0'] = output0_comm_action
|
||||
return self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
def _split_both_contract_rs(self, name, rs_dim, rs_mesh_dim, src_spec,
|
||||
dim_partition_dict_mapping, device_mesh):
|
||||
output0_comm_action = CommSpec(
|
||||
comm_pattern='reduce_scatter',
|
||||
sharding_spec=src_spec,
|
||||
shard_dim=[rs_dim],
|
||||
logical_process_axis=[rs_mesh_dim],
|
||||
)
|
||||
rs_out_partition_dict_mapping = copy.deepcopy(
|
||||
dim_partition_dict_mapping)
|
||||
rs_out_partition_dict_mapping["output0"][rs_dim] = rs_mesh_dim
|
||||
rs_out_sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
rs_out_partition_dict_mapping, device_mesh)
|
||||
if len(rs_out_sharding_spec_mapping) == 0:
|
||||
return None
|
||||
|
||||
communication_action_mapping = {}
|
||||
communication_action_mapping['output0'] = output0_comm_action
|
||||
return self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=rs_out_sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
def _split_lhs_space_both_contract_rs(self, mesh_dim_0, mesh_dim_1,
|
||||
device_mesh):
|
||||
# handle the case SS = SS x SR -> reduce_scatter
|
||||
in0_split_dim = [-1, -2] if self.op0_transpose else [-2, -1]
|
||||
in1_split_dim = -1 if self.op1_transpose else -2
|
||||
# get sharding spec mapping
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": {
|
||||
in0_split_dim[0]: mesh_dim_0,
|
||||
in0_split_dim[1]: mesh_dim_1
|
||||
},
|
||||
"input1": {
|
||||
in1_split_dim: mesh_dim_1
|
||||
},
|
||||
"output0": {
|
||||
-2: mesh_dim_0,
|
||||
},
|
||||
}
|
||||
mm_out_sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if len(mm_out_sharding_spec_mapping) == 0:
|
||||
return []
|
||||
strategies = []
|
||||
for rs_dim in range(self.num_out_dims):
|
||||
if rs_dim != self.num_out_dims - 2:
|
||||
name_in0, name_in1, name_out0 = ['R'] * self.num_out_dims, [
|
||||
'R'
|
||||
] * self.num_out_dims, ['R'] * self.num_out_dims
|
||||
name_in0[-2], name_in0[-1] = str(DimSpec(mesh_dim_0)), str(
|
||||
DimSpec(mesh_dim_1))
|
||||
name_in1[-2] = str(DimSpec(mesh_dim_1))
|
||||
name_out0[-2], name_out0[rs_dim] = str(
|
||||
DimSpec(mesh_dim_0)), str(DimSpec(mesh_dim_1))
|
||||
name_in0, name_in1, name_out0 = ', '.join(name_in0), ', '.join(
|
||||
name_in1), ', '.join(name_out0)
|
||||
name = (f'[{name_out0}] = [{name_in0}] x [{name_in1}]'
|
||||
f'_reducescatter{(rs_dim, DimSpec(mesh_dim_1))}')
|
||||
ret = self._split_both_contract_rs(
|
||||
name, rs_dim, mesh_dim_1,
|
||||
mm_out_sharding_spec_mapping['output0'],
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if ret:
|
||||
strategies.append(ret)
|
||||
return strategies
|
||||
|
||||
def _split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1,
|
||||
device_mesh):
|
||||
name = (
|
||||
f'R{DimSpec(mesh_dim_1)} = '
|
||||
f'R{DimSpec(mesh_dim_0)} x {DimSpec(mesh_dim_0)}{DimSpec(mesh_dim_1)}'
|
||||
f'_allreduce{DimSpec(mesh_dim_0)}')
|
||||
in0_split_dim = -2 if self.op0_transpose else -1
|
||||
in1_split_dim = [-1, -2] if self.op1_transpose else [-2, -1]
|
||||
# get sharding specs
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": {
|
||||
in0_split_dim: mesh_dim_0
|
||||
},
|
||||
"input1": {
|
||||
in1_split_dim[0]: mesh_dim_0,
|
||||
in1_split_dim[1]: mesh_dim_1
|
||||
},
|
||||
"output0": {
|
||||
-1: mesh_dim_1
|
||||
},
|
||||
}
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if len(sharding_spec_mapping) == 0:
|
||||
return None
|
||||
# get communication actions
|
||||
communication_action_mapping = {}
|
||||
output0_comm_action = CommSpec(
|
||||
comm_pattern='all_reduce',
|
||||
sharding_spec=sharding_spec_mapping['output0'],
|
||||
logical_process_axis=[mesh_dim_0],
|
||||
)
|
||||
communication_action_mapping['output0'] = output0_comm_action
|
||||
return self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
def _split_rhs_space_both_contract_rs(self, mesh_dim_0, mesh_dim_1,
|
||||
device_mesh):
|
||||
in0_split_dim = -2 if self.op0_transpose else -1
|
||||
in1_split_dim = [-1, -2] if self.op1_transpose else [-2, -1]
|
||||
# get sharding specs
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": {
|
||||
in0_split_dim: mesh_dim_0
|
||||
},
|
||||
"input1": {
|
||||
in1_split_dim[0]: mesh_dim_0,
|
||||
in1_split_dim[1]: mesh_dim_1
|
||||
},
|
||||
"output0": {
|
||||
-1: mesh_dim_1
|
||||
},
|
||||
}
|
||||
mm_out_sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if len(mm_out_sharding_spec_mapping) == 0:
|
||||
return []
|
||||
strategies = []
|
||||
for rs_dim in range(self.num_out_dims):
|
||||
if rs_dim != self.num_out_dims - 1:
|
||||
name_in0, name_in1, name_out0 = ['R'] * self.num_out_dims, [
|
||||
'R'
|
||||
] * self.num_out_dims, ['R'] * self.num_out_dims
|
||||
name_in1[-2], name_in1[-1] = str(DimSpec(mesh_dim_0)), str(
|
||||
DimSpec(mesh_dim_1))
|
||||
name_in0[-1] = str(DimSpec(mesh_dim_0))
|
||||
name_out0[-1], name_out0[rs_dim] = str(
|
||||
DimSpec(mesh_dim_1)), str(DimSpec(mesh_dim_0))
|
||||
name_in0, name_in1, name_out0 = ', '.join(name_in0), ', '.join(
|
||||
name_in1), ', '.join(name_out0)
|
||||
name = (f'[{name_out0}] = [{name_in0}] x [{name_in1}]'
|
||||
f'_reducescatter{(rs_dim, DimSpec(mesh_dim_0))}')
|
||||
ret = self._split_both_contract_rs(
|
||||
name, rs_dim, mesh_dim_0,
|
||||
mm_out_sharding_spec_mapping['output0'],
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if ret:
|
||||
strategies.append(ret)
|
||||
return strategies
|
||||
|
||||
def _recompute_split_both_contract(self, mesh_dim, device_mesh):
|
||||
name = (f'RR = R{DimSpec(mesh_dim)} x {DimSpec(mesh_dim)}R'
|
||||
f'_allreduce{DimSpec(mesh_dim)}')
|
||||
in0_split_dim = -2 if self.op0_transpose else -1
|
||||
in1_split_dim = -1 if self.op1_transpose else -2
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": {
|
||||
in0_split_dim: mesh_dim
|
||||
},
|
||||
"input1": {
|
||||
in1_split_dim: mesh_dim
|
||||
},
|
||||
"output0": {},
|
||||
}
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if len(sharding_spec_mapping) == 0:
|
||||
return None
|
||||
|
||||
# get communication action
|
||||
communication_action_mapping = {}
|
||||
output0_comm_action = CommSpec(
|
||||
comm_pattern='all_reduce',
|
||||
sharding_spec=sharding_spec_mapping['output0'],
|
||||
logical_process_axis=[mesh_dim],
|
||||
)
|
||||
communication_action_mapping['output0'] = output0_comm_action
|
||||
return self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
def _recompute_split_both_contract_rs(self, mesh_dim, device_mesh):
|
||||
name = (f'{DimSpec(mesh_dim)}R = '
|
||||
f'R{DimSpec(mesh_dim)} x {DimSpec(mesh_dim)}R'
|
||||
f'_reducescatter0_{DimSpec(mesh_dim)}')
|
||||
in0_split_dim = -2 if self.op0_transpose else -1
|
||||
in1_split_dim = -1 if self.op1_transpose else -2
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": {
|
||||
in0_split_dim: mesh_dim
|
||||
},
|
||||
"input1": {
|
||||
in1_split_dim: mesh_dim
|
||||
},
|
||||
"output0": {},
|
||||
}
|
||||
mm_out_sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if len(mm_out_sharding_spec_mapping) == 0:
|
||||
return []
|
||||
|
||||
strategies = []
|
||||
for rs_dim in range(self.num_out_dims):
|
||||
name_in0, name_in1, name_out0 = ['R'] * self.num_out_dims, [
|
||||
'R'
|
||||
] * self.num_out_dims, ['R'] * self.num_out_dims
|
||||
name_in0[-1], name_in1[-2], name_out0[rs_dim] = str(
|
||||
DimSpec(mesh_dim)), str(DimSpec(mesh_dim)), str(
|
||||
DimSpec(mesh_dim))
|
||||
name_in0, name_in1, name_out0 = ', '.join(name_in0), ', '.join(
|
||||
name_in1), ', '.join(name_out0)
|
||||
name = f'[{name_out0}] = [{name_in0}] x [{name_in1}]_reducescatter{(rs_dim, DimSpec(mesh_dim))}'
|
||||
ret = self._split_both_contract_rs(
|
||||
name, rs_dim, mesh_dim, mm_out_sharding_spec_mapping['output0'],
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if ret:
|
||||
strategies.append(ret)
|
||||
return strategies
|
||||
|
||||
def _split_rhs_space_only(self, mesh_dim, device_mesh):
|
||||
name = f'R{DimSpec(mesh_dim)} = RR x R{DimSpec(mesh_dim)}'
|
||||
in1_split_dim = -2 if self.op1_transpose else -1
|
||||
# get sharding spec
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": {},
|
||||
"input1": {
|
||||
in1_split_dim: mesh_dim
|
||||
},
|
||||
"output0": {
|
||||
-1: mesh_dim
|
||||
},
|
||||
}
|
||||
# We don't have to do anything special for bias here, because
|
||||
# the bias is already the same sharding spec as the output0.
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if len(sharding_spec_mapping) == 0:
|
||||
return None
|
||||
return self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
|
||||
def _split_lhs_space_only(self, mesh_dim, device_mesh):
|
||||
name = f'{DimSpec(mesh_dim)}R = {DimSpec(mesh_dim)}R x RR'
|
||||
in0_split_dim = -1 if self.op0_transpose else -2
|
||||
# get sharding spec
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": {
|
||||
in0_split_dim: mesh_dim
|
||||
},
|
||||
"input1": {},
|
||||
"output0": {
|
||||
-2: mesh_dim
|
||||
},
|
||||
}
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if len(sharding_spec_mapping) == 0:
|
||||
return None
|
||||
return self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
|
||||
def _non_split(self, device_mesh):
|
||||
name = 'RR = RR x RR'
|
||||
# get sharding spec
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": {},
|
||||
"input1": {},
|
||||
"output0": {},
|
||||
}
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if len(sharding_spec_mapping) == 0:
|
||||
return None
|
||||
return self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
|
||||
def _split_one_batch_dim(self, batch_dim, mesh_dim, device_mesh):
|
||||
name = (
|
||||
f'{DimSpec(mesh_dim)}b{batch_dim}RR = '
|
||||
f'{DimSpec(mesh_dim)}b{batch_dim}RR x {DimSpec(mesh_dim)}b{batch_dim}RR'
|
||||
)
|
||||
in0_data = self.op_data['input0']
|
||||
in1_data = self.op_data['input1']
|
||||
|
||||
batch_partition_dict = {batch_dim: mesh_dim}
|
||||
in0_parition_dict = self._recover_bcast_partition_dict(
|
||||
batch_partition_dict, in0_data)
|
||||
in1_parition_dict = self._recover_bcast_partition_dict(
|
||||
batch_partition_dict, in1_data)
|
||||
out_partition_dict = {batch_dim: mesh_dim}
|
||||
# TODO: Double check if MatrixMultiplication's output has bcast in dim
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": in0_parition_dict,
|
||||
"input1": in1_parition_dict,
|
||||
"output0": out_partition_dict,
|
||||
}
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if len(sharding_spec_mapping) == 0:
|
||||
return None
|
||||
return self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
|
||||
def _split_two_batch_dims(self, batch_dim0, batch_dim1, mesh_dim0,
|
||||
mesh_dim1, device_mesh):
|
||||
name = (
|
||||
f'{DimSpec(mesh_dim0)}b{batch_dim0}{DimSpec(mesh_dim1)}b{batch_dim1}RR = '
|
||||
f'{DimSpec(mesh_dim0)}b{batch_dim0}RR x {DimSpec(mesh_dim1)}b{batch_dim1}RR'
|
||||
)
|
||||
in0_data = self.op_data['input0']
|
||||
in1_data = self.op_data['input1']
|
||||
|
||||
in0_parition_dict = {}
|
||||
if batch_dim0 not in in0_data.attrs["broadcast_dims"]:
|
||||
in0_parition_dict[batch_dim0] = mesh_dim0
|
||||
if batch_dim1 not in in0_data.attrs["broadcast_dims"]:
|
||||
in0_parition_dict[batch_dim1] = mesh_dim1
|
||||
|
||||
in1_parition_dict = {}
|
||||
if batch_dim0 not in in1_data.attrs["broadcast_dims"]:
|
||||
in1_parition_dict[batch_dim0] = mesh_dim0
|
||||
if batch_dim1 not in in1_data.attrs["broadcast_dims"]:
|
||||
in1_parition_dict[batch_dim1] = mesh_dim1
|
||||
|
||||
batch_partition_dict = {batch_dim0: mesh_dim0, batch_dim1: mesh_dim1}
|
||||
in0_parition_dict = self._recover_bcast_partition_dict(
|
||||
batch_partition_dict, in0_data)
|
||||
in1_parition_dict = self._recover_bcast_partition_dict(
|
||||
batch_partition_dict, in1_data)
|
||||
out_partition_dict = {batch_dim0: mesh_dim0, batch_dim1: mesh_dim1}
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": in0_parition_dict,
|
||||
"input1": in1_parition_dict,
|
||||
"output0": out_partition_dict,
|
||||
}
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if len(sharding_spec_mapping) == 0:
|
||||
return None
|
||||
return self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
|
||||
def _split_batch_dim_lhs_space(self, batch_dim, mesh_dim0, mesh_dim1,
|
||||
device_mesh):
|
||||
|
||||
name = (
|
||||
f'{DimSpec(mesh_dim0)}b{batch_dim}{DimSpec(mesh_dim1)}R = '
|
||||
f'{DimSpec(mesh_dim0)}b{batch_dim}{DimSpec(mesh_dim1)}R x {DimSpec(mesh_dim0)}b{batch_dim}RR'
|
||||
)
|
||||
in0_data = self.op_data['input0']
|
||||
in1_data = self.op_data['input1']
|
||||
in0_parition_dict = {batch_dim: mesh_dim0}
|
||||
in1_parition_dict = {batch_dim: mesh_dim0}
|
||||
in0_lhs_split_dim = -1 if self.op0_transpose else -2
|
||||
in0_parition_dict[in0_lhs_split_dim] = mesh_dim1
|
||||
|
||||
in0_parition_dict = self._recover_bcast_partition_dict(
|
||||
in0_parition_dict, in0_data)
|
||||
in1_parition_dict = self._recover_bcast_partition_dict(
|
||||
in1_parition_dict, in1_data)
|
||||
out_partition_dict = {batch_dim: mesh_dim0, -2: mesh_dim1}
|
||||
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": in0_parition_dict,
|
||||
"input1": in1_parition_dict,
|
||||
"output0": out_partition_dict,
|
||||
}
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if len(sharding_spec_mapping) == 0:
|
||||
return None
|
||||
return self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
|
||||
def _split_batch_dim_rhs_space(self, batch_dim, mesh_dim0, mesh_dim1,
|
||||
device_mesh):
|
||||
|
||||
name = (
|
||||
f'{DimSpec(mesh_dim0)}b{batch_dim}R{DimSpec(mesh_dim1)} = '
|
||||
f'{DimSpec(mesh_dim0)}b{batch_dim}RR x {DimSpec(mesh_dim0)}b{batch_dim}R{DimSpec(mesh_dim1)}'
|
||||
)
|
||||
in0_data = self.op_data['input0']
|
||||
in1_data = self.op_data['input1']
|
||||
in0_parition_dict = {batch_dim: mesh_dim0}
|
||||
in1_parition_dict = {batch_dim: mesh_dim0}
|
||||
|
||||
in1_rhs_split_dim = -2 if self.op1_transpose else -1
|
||||
in1_parition_dict[in1_rhs_split_dim] = mesh_dim1
|
||||
|
||||
in0_parition_dict = self._recover_bcast_partition_dict(
|
||||
in0_parition_dict, in0_data)
|
||||
in1_parition_dict = self._recover_bcast_partition_dict(
|
||||
in1_parition_dict, in1_data)
|
||||
out_partition_dict = {batch_dim: mesh_dim0, -1: mesh_dim1}
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": in0_parition_dict,
|
||||
"input1": in1_parition_dict,
|
||||
"output0": out_partition_dict,
|
||||
}
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if len(sharding_spec_mapping) == 0:
|
||||
return None
|
||||
return self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
|
||||
def _split_batch_dim_both_contract(self, batch_dim, mesh_dim0, mesh_dim1,
|
||||
device_mesh):
|
||||
|
||||
name = (
|
||||
f'{DimSpec(mesh_dim0)}b{batch_dim}RR = '
|
||||
f'{DimSpec(mesh_dim0)}b{batch_dim}R{DimSpec(mesh_dim1)} x '
|
||||
f'{DimSpec(mesh_dim0)}b{batch_dim}{DimSpec(mesh_dim1)}R_AR{mesh_dim1}'
|
||||
)
|
||||
in0_data = self.op_data['input0']
|
||||
in1_data = self.op_data['input1']
|
||||
in0_parition_dict = {batch_dim: mesh_dim0}
|
||||
in1_parition_dict = {batch_dim: mesh_dim0}
|
||||
|
||||
in0_contract_dim = -2 if self.op0_transpose else -1
|
||||
in1_contract_dim = -1 if self.op1_transpose else -2
|
||||
in0_parition_dict[in0_contract_dim] = mesh_dim1
|
||||
in1_parition_dict[in1_contract_dim] = mesh_dim1
|
||||
|
||||
in0_parition_dict = self._recover_bcast_partition_dict(
|
||||
in0_parition_dict, in0_data)
|
||||
in1_parition_dict = self._recover_bcast_partition_dict(
|
||||
in1_parition_dict, in1_data)
|
||||
out_partition_dict = {batch_dim: mesh_dim0}
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": in0_parition_dict,
|
||||
"input1": in1_parition_dict,
|
||||
"output0": out_partition_dict,
|
||||
}
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if len(sharding_spec_mapping) == 0:
|
||||
return None
|
||||
|
||||
# get communication actions
|
||||
communication_action_mapping = {}
|
||||
output0_comm_action = CommSpec(
|
||||
comm_pattern='all_reduce',
|
||||
sharding_spec=sharding_spec_mapping['output0'],
|
||||
logical_process_axis=[mesh_dim1],
|
||||
)
|
||||
communication_action_mapping['output0'] = output0_comm_action
|
||||
return self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
def _split_batch_dim_both_contract_rs(self, batch_dim, mesh_dim0, mesh_dim1,
|
||||
device_mesh):
|
||||
|
||||
name = (
|
||||
f'{DimSpec(mesh_dim0)}b{batch_dim}RR = '
|
||||
f'{DimSpec(mesh_dim0)}b{batch_dim}R{DimSpec(mesh_dim1)} x '
|
||||
f'{DimSpec(mesh_dim0)}b{batch_dim}{DimSpec(mesh_dim1)}R_AR{mesh_dim1}'
|
||||
)
|
||||
in0_data = self.op_data['input0']
|
||||
in1_data = self.op_data['input1']
|
||||
in0_parition_dict = {batch_dim: mesh_dim0}
|
||||
in1_parition_dict = {batch_dim: mesh_dim0}
|
||||
|
||||
in0_contract_dim = -2 if self.op0_transpose else -1
|
||||
in1_contract_dim = -1 if self.op1_transpose else -2
|
||||
in0_parition_dict[in0_contract_dim] = mesh_dim1
|
||||
in1_parition_dict[in1_contract_dim] = mesh_dim1
|
||||
|
||||
in0_parition_dict = self._recover_bcast_partition_dict(
|
||||
in0_parition_dict, in0_data)
|
||||
in1_parition_dict = self._recover_bcast_partition_dict(
|
||||
in1_parition_dict, in1_data)
|
||||
out_partition_dict = {batch_dim: mesh_dim0}
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": in0_parition_dict,
|
||||
"input1": in1_parition_dict,
|
||||
"output0": out_partition_dict,
|
||||
}
|
||||
mm_out_sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if len(mm_out_sharding_spec_mapping) == 0:
|
||||
return []
|
||||
|
||||
strategies = []
|
||||
for rs_dim in range(self.num_out_dims):
|
||||
if rs_dim != batch_dim:
|
||||
name_in0, name_in1, name_out0 = ['R'] * self.num_out_dims, [
|
||||
'R'
|
||||
] * self.num_out_dims, ['R'] * self.num_out_dims
|
||||
name_in0[batch_dim], name_in0[-1] = str(
|
||||
DimSpec(mesh_dim0)), str(DimSpec(mesh_dim1))
|
||||
name_in1[batch_dim], name_in1[-2] = str(
|
||||
DimSpec(mesh_dim0)), str(DimSpec(mesh_dim1))
|
||||
name_out0[batch_dim], name_out0[rs_dim] = str(
|
||||
DimSpec(mesh_dim0)), str(DimSpec(mesh_dim1))
|
||||
name_in0, name_in1, name_out0 = ', '.join(name_in0), ', '.join(
|
||||
name_in1), ', '.join(name_out0)
|
||||
name = f'[{name_out0}] = [{name_in0}] x [{name_in1}]_reducescatter{(rs_dim, DimSpec(mesh_dim1))}'
|
||||
ret = self._split_both_contract_rs(
|
||||
name, rs_dim, mesh_dim1,
|
||||
mm_out_sharding_spec_mapping['output0'],
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if ret:
|
||||
strategies.append(ret)
|
||||
return strategies
|
||||
|
||||
def _dp_strategies(self, device_mesh):
|
||||
strategies = []
|
||||
# S0R = S0R x RR
|
||||
strategies.append(self._split_lhs_space_only([0], device_mesh))
|
||||
# S1R = S1R x RR
|
||||
strategies.append(self._split_lhs_space_only([1], device_mesh))
|
||||
# S01R = S01R x RR
|
||||
strategies.append(self._split_lhs_space_only([0, 1], device_mesh))
|
||||
return strategies
|
||||
|
||||
def _tp_strategies(self, device_mesh: LogicalDeviceMesh):
|
||||
strategies = []
|
||||
# RR = RS x SR _ AR
|
||||
strategies.append(self._recompute_split_both_contract([0], device_mesh))
|
||||
strategies.append(self._recompute_split_both_contract([1], device_mesh))
|
||||
strategies.append(
|
||||
self._recompute_split_both_contract([0, 1], device_mesh))
|
||||
|
||||
if device_mesh.config.enable_reduce_scatter:
|
||||
# RS x SR _ reduce scatter
|
||||
strategies.extend(
|
||||
self._recompute_split_both_contract_rs([0], device_mesh))
|
||||
strategies.extend(
|
||||
self._recompute_split_both_contract_rs([1], device_mesh))
|
||||
strategies.extend(
|
||||
self._recompute_split_both_contract_rs([0, 1], device_mesh))
|
||||
|
||||
# RS = RR x RS
|
||||
strategies.append(self._split_rhs_space_only([0], device_mesh))
|
||||
strategies.append(self._split_rhs_space_only([1], device_mesh))
|
||||
strategies.append(self._split_rhs_space_only([0, 1], device_mesh))
|
||||
|
||||
# RS = RS x SS _ AR
|
||||
strategies.append(
|
||||
self._split_rhs_space_both_contract([0], [1], device_mesh))
|
||||
strategies.append(
|
||||
self._split_rhs_space_both_contract([1], [0], device_mesh))
|
||||
|
||||
if device_mesh.config.enable_reduce_scatter:
|
||||
# RS x SS _ reduce scatter
|
||||
strategies.extend(
|
||||
self._split_rhs_space_both_contract_rs([0], [1], device_mesh))
|
||||
strategies.extend(
|
||||
self._split_rhs_space_both_contract_rs([1], [0], device_mesh))
|
||||
|
||||
return strategies
|
||||
|
||||
def _mix_strategies(self, device_mesh):
|
||||
strategies = []
|
||||
|
||||
# SR = SS x SR_AR
|
||||
strategies.append(
|
||||
self._split_lhs_space_both_contract([0], [1], device_mesh))
|
||||
strategies.append(
|
||||
self._split_lhs_space_both_contract([1], [0], device_mesh))
|
||||
if device_mesh.config.enable_reduce_scatter:
|
||||
# RS x SS _ reduce scatter
|
||||
strategies.extend(
|
||||
self._split_lhs_space_both_contract_rs([0], [1], device_mesh))
|
||||
strategies.extend(
|
||||
self._split_lhs_space_both_contract_rs([1], [0], device_mesh))
|
||||
# SS = SR x RS
|
||||
strategies.append(self._split_lhs_space_rhs_space([0], [1],
|
||||
device_mesh))
|
||||
strategies.append(self._split_lhs_space_rhs_space([1], [0],
|
||||
device_mesh))
|
||||
|
||||
# RR = RR x RR
|
||||
strategies.append(self._non_split(device_mesh))
|
||||
return strategies
|
||||
|
||||
def _bmm_strategies(self, device_mesh: LogicalDeviceMesh):
|
||||
strategies = []
|
||||
bmm_dim = len(self.op_data['output0'].shape)
|
||||
if bmm_dim >= 3:
|
||||
for batch_dim in range(0, bmm_dim - 2):
|
||||
strategies.append(
|
||||
self._split_one_batch_dim(batch_dim, [0], device_mesh))
|
||||
strategies.append(
|
||||
self._split_one_batch_dim(batch_dim, [1], device_mesh))
|
||||
strategies.append(
|
||||
self._split_one_batch_dim(batch_dim, [0, 1], device_mesh))
|
||||
|
||||
strategies.append(
|
||||
self._split_batch_dim_lhs_space(batch_dim, [0], [1],
|
||||
device_mesh))
|
||||
strategies.append(
|
||||
self._split_batch_dim_lhs_space(batch_dim, [1], [0],
|
||||
device_mesh))
|
||||
|
||||
strategies.append(
|
||||
self._split_batch_dim_rhs_space(batch_dim, [0], [1],
|
||||
device_mesh))
|
||||
strategies.append(
|
||||
self._split_batch_dim_rhs_space(batch_dim, [1], [0],
|
||||
device_mesh))
|
||||
|
||||
strategies.append(
|
||||
self._split_batch_dim_both_contract(batch_dim, [0], [1],
|
||||
device_mesh))
|
||||
strategies.append(
|
||||
self._split_batch_dim_both_contract(batch_dim, [1], [0],
|
||||
device_mesh))
|
||||
if device_mesh.config.enable_reduce_scatter:
|
||||
strategies.extend(
|
||||
self._split_batch_dim_both_contract_rs(
|
||||
batch_dim, [0], [1], device_mesh))
|
||||
strategies.extend(
|
||||
self._split_batch_dim_both_contract_rs(
|
||||
batch_dim, [1], [0], device_mesh))
|
||||
if bmm_dim >= 4:
|
||||
for batch_dim0 in range(0, bmm_dim - 2):
|
||||
for batch_dim1 in range(0, bmm_dim - 2):
|
||||
if batch_dim0 != batch_dim1:
|
||||
strategies.append(
|
||||
self._split_two_batch_dims(
|
||||
batch_dim0, batch_dim1, [0], [1],
|
||||
device_mesh))
|
||||
|
||||
return strategies
|
||||
|
||||
def _collect_strategies(self, device_mesh):
|
||||
strategies_vector = StrategiesVector(self)
|
||||
dp_strategies = self._dp_strategies(device_mesh)
|
||||
tp_strategies = self._tp_strategies(device_mesh)
|
||||
mix_strategies = self._mix_strategies(device_mesh)
|
||||
bmm_strategies = self._bmm_strategies(device_mesh)
|
||||
strategies_vector.extend(dp_strategies)
|
||||
strategies_vector.extend(tp_strategies)
|
||||
strategies_vector.extend(mix_strategies)
|
||||
strategies_vector.extend(bmm_strategies)
|
||||
return strategies_vector
|
||||
|
||||
def is_fp16(self):
|
||||
builder_flags = get_builder_flags()
|
||||
return builder_flags & (1 << int(trt.BuilderFlag.FP16)) != 0
|
||||
|
||||
def _get_math_time(self, strategy, device_mesh):
|
||||
shape_in0 = strategy.sharding_specs[
|
||||
'input0'].get_sharded_shape_per_device()
|
||||
shape_out = strategy.sharding_specs[
|
||||
'output0'].get_sharded_shape_per_device()
|
||||
m, n = shape_out[-2], shape_out[-1]
|
||||
batches = shape_out[:-2]
|
||||
k = shape_in0[-2] if self.op0_transpose else shape_in0[-1]
|
||||
macs_shape = batches + [m, n, k]
|
||||
macs = reduce(operator.mul, macs_shape, 1) * 2
|
||||
config = device_mesh.config
|
||||
cluster_info = device_mesh.cluster_info
|
||||
dtype = self.dtype
|
||||
# For fp16 matmul ops with use_fp32_acc=True.
|
||||
# They are mistaken for fp32 ops since all of their IO tensors use fp32 dtype.
|
||||
if self.is_fp16() and self.dtype == "float32":
|
||||
dtype = "float16"
|
||||
math_throughput_tflops = getattr(cluster_info.math_throughput, dtype)
|
||||
assert math_throughput_tflops != 0, \
|
||||
"Undefined {} math throughput of cluster {}".format(dtype, config.cluster_key)
|
||||
math_time = macs / math_throughput_tflops * 1e-6 * cluster_info.math_efficiency
|
||||
return math_time
|
||||
|
||||
def _update_memory_cost(self, strategies):
|
||||
super()._update_memory_cost(strategies)
|
||||
# For fp16 matmul ops with use_fp32_acc=True.
|
||||
# Their memory footprints are calculated based on fp32 IO tensors.
|
||||
# Actually they will use fp16 IO tensors after fusion.
|
||||
# So we divide all the memory footprints by 2.
|
||||
if self.is_fp16() and self.dtype == "float32":
|
||||
for strategy in strategies:
|
||||
strategy.inout_memory_footprint /= 2
|
||||
strategy.peak_memory_footprint /= 2
|
||||
strategy.comm_buff_memory_footprint /= 2
|
||||
@ -1,376 +0,0 @@
|
||||
from abc import ABC
|
||||
|
||||
from ..config import CostModel
|
||||
from ..device_mesh import LogicalDeviceMesh
|
||||
from .comm_spec import CommSpec
|
||||
from .sharding_spec import ShardingSpec
|
||||
from .sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
|
||||
|
||||
class Node(ABC):
|
||||
|
||||
def __init__(self, layer):
|
||||
self._layer = layer
|
||||
self.is_shape_io = self._layer.is_shape_io
|
||||
self._inputs = []
|
||||
self._outputs = []
|
||||
self.predecessor_nodes = []
|
||||
self.predecessor_nodes_out_index = {}
|
||||
self.successor_nodes = []
|
||||
self.op_data = {}
|
||||
self.global_to_local_op_name = {}
|
||||
self.num_inputs = 0
|
||||
self.is_replicated = layer.attrs.get("is_replicated", False)
|
||||
self.same_spec_id = layer.attrs.get("same_spec_id", -1)
|
||||
self.is_fake = self.same_spec_id != -1
|
||||
self.building_block_id = layer.attrs.get("building_block_id", -1)
|
||||
self.cost_level = -1
|
||||
self.stage_type = layer.attrs.get("stage_type", None)
|
||||
self.in_start_block = layer.attrs.get("in_start_block", False)
|
||||
self.in_end_block = layer.attrs.get("in_end_block", False)
|
||||
self.in_slowest_block = layer.attrs.get("in_slowest_block", False)
|
||||
for i, input in enumerate(layer.inputs):
|
||||
if input is None:
|
||||
self._inputs.append(None)
|
||||
self.op_data[f'input{i}'] = None
|
||||
continue
|
||||
input = input.copy()
|
||||
input.attrs["broadcast_dims"] = []
|
||||
self._inputs.append(input)
|
||||
self.op_data[f'input{i}'] = input
|
||||
self.global_to_local_op_name[input.name] = f'input{i}'
|
||||
|
||||
for i, output in enumerate(layer.outputs):
|
||||
output = output.copy()
|
||||
output.attrs["broadcast_dims"] = []
|
||||
self._outputs.append(output)
|
||||
self.op_data[f'output{i}'] = output
|
||||
self.global_to_local_op_name[output.name] = f'output{i}'
|
||||
|
||||
self.sharding_weight = 1.0
|
||||
self.resharding_weight = 1.0
|
||||
self.pipeline_weight = 0
|
||||
self.node_name = layer.name
|
||||
self.node_type = 'normal_node'
|
||||
self.num_inputs = layer.num_inputs
|
||||
self.num_outputs = layer.num_outputs
|
||||
self.dtype = layer.as_trt().precision
|
||||
self.strategies_vector = []
|
||||
self.node_runtime_profiler = None
|
||||
|
||||
def post_init(self, graph):
|
||||
for input in self.inputs:
|
||||
if input is None:
|
||||
self.predecessor_nodes.append(None)
|
||||
continue
|
||||
if input.producer is None:
|
||||
predecessor_node = graph.get_node(input.name)
|
||||
self.predecessor_nodes.append(predecessor_node)
|
||||
self.predecessor_nodes_out_index[predecessor_node] = 0
|
||||
predecessor_node.successor_nodes.append(self)
|
||||
else:
|
||||
predecessor_node = graph.get_node(input.producer.name)
|
||||
self.predecessor_nodes.append(predecessor_node)
|
||||
self.predecessor_nodes_out_index[
|
||||
predecessor_node] = input.output_index
|
||||
predecessor_node.successor_nodes.append(self)
|
||||
|
||||
@property
|
||||
def layer(self):
|
||||
return self._layer
|
||||
|
||||
def get_input(self, index):
|
||||
return self._inputs[index]
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return self._inputs
|
||||
|
||||
def get_output(self, index):
|
||||
return self._outputs[index]
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return self._outputs
|
||||
|
||||
def collect_strategies(self, device_mesh):
|
||||
strategies_vector = self._collect_strategies(device_mesh)
|
||||
strategies_vector = self._post_process(strategies_vector)
|
||||
self._update_sharding_cost(strategies_vector, device_mesh)
|
||||
self.strategies_vector = strategies_vector
|
||||
return self.strategies_vector
|
||||
|
||||
def _set_strategy(self, strategy, device_mesh):
|
||||
strategies_vector = StrategiesVector(self)
|
||||
if strategy is None:
|
||||
dim_partition_dict_mapping = {}
|
||||
for i in range(self.num_inputs):
|
||||
dim_partition_dict_mapping[f'input{i}'] = {}
|
||||
for i in range(self.num_outputs):
|
||||
dim_partition_dict_mapping[f'output{i}'] = {}
|
||||
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
assert 0 != len(
|
||||
sharding_spec_mapping
|
||||
), f'failed to set default(all Replicate) strategy for node {self.node_name}'
|
||||
name = 'RRs'
|
||||
sharding_strategy = self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
strategies_vector.append(sharding_strategy)
|
||||
|
||||
else:
|
||||
sharding_specs_map = strategy.sharding_specs
|
||||
comm_specs_map = strategy.communication_actions
|
||||
dim_partition_dict_mapping = {}
|
||||
for op_name, sharding_spec in sharding_specs_map.items():
|
||||
dim_partition_dict_mapping[
|
||||
op_name] = sharding_spec.dim_partition_dict
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
assert 0 != len(
|
||||
sharding_spec_mapping
|
||||
), f'failed to set strategy for node {self.node_name}'
|
||||
comm_specs_mapping = {}
|
||||
if len(comm_specs_map) > 0:
|
||||
for op_name, comm_spec in comm_specs_map.items():
|
||||
comm_specs_mapping[op_name] = CommSpec(
|
||||
comm_pattern=comm_spec.comm_pattern,
|
||||
sharding_spec=sharding_spec_mapping[op_name],
|
||||
logical_process_axis=comm_spec.logical_process_axis,
|
||||
)
|
||||
strategies_vector.append(
|
||||
self._get_sharding_strategy(
|
||||
name=strategy.name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=comm_specs_mapping))
|
||||
return strategies_vector
|
||||
|
||||
def set_strategy(self, strategy, device_mesh):
|
||||
strategies_vector = self._set_strategy(strategy, device_mesh)
|
||||
strategies_vector = self._post_process(strategies_vector)
|
||||
self._update_sharding_cost(strategies_vector, device_mesh)
|
||||
self.strategies_vector = strategies_vector
|
||||
return self.strategies_vector
|
||||
|
||||
def update_resharding_cost(self):
|
||||
self._update_resharding_cost(self.strategies_vector)
|
||||
return self.strategies_vector
|
||||
|
||||
def _to_sharding_spec_mapping(self, dim_partition_dict_mapping,
|
||||
device_mesh):
|
||||
results = {}
|
||||
for op_data_name, dim_partition_dict in dim_partition_dict_mapping.items(
|
||||
):
|
||||
if op_data_name in self.op_data:
|
||||
op_data = self.op_data[op_data_name]
|
||||
|
||||
def _to_sharding_spec(op_data, dim_partition_dict):
|
||||
sharding_spec = ShardingSpec(
|
||||
device_mesh,
|
||||
op_data.dtype_str_size, [*op_data.shape],
|
||||
[*op_data.max_shape], [*op_data.raw_shape],
|
||||
dim_partition_dict=dim_partition_dict)
|
||||
if sharding_spec.sanity_check():
|
||||
return sharding_spec
|
||||
else:
|
||||
return None
|
||||
|
||||
sharding_spec = _to_sharding_spec(op_data, dim_partition_dict)
|
||||
if sharding_spec:
|
||||
results[op_data_name] = sharding_spec
|
||||
else:
|
||||
return {}
|
||||
return results
|
||||
|
||||
def _get_sharding_strategy(self, name, sharding_spec_mapping,
|
||||
communication_action_mapping):
|
||||
return ShardingStrategy(
|
||||
name=name,
|
||||
sharding_specs=sharding_spec_mapping,
|
||||
communication_actions=communication_action_mapping,
|
||||
)
|
||||
|
||||
def _remove_duplicated_strategy(self, strategies_vector):
|
||||
name_checklist = []
|
||||
remove_list = []
|
||||
for strategy in strategies_vector:
|
||||
if strategy.name not in name_checklist:
|
||||
name_checklist.append(strategy.name)
|
||||
else:
|
||||
remove_list.append(strategy)
|
||||
for strategy in remove_list:
|
||||
strategies_vector.remove(strategy)
|
||||
|
||||
def _post_process(self, strategies_vector):
|
||||
# TODO: deal with the transpose and dimension-1 problems from ColossalAI, which have already been processed before this point
|
||||
for i in range(len(strategies_vector) - 1, -1, -1):
|
||||
if strategies_vector[i] is None:
|
||||
strategies_vector.pop(i)
|
||||
|
||||
self._remove_duplicated_strategy(strategies_vector)
|
||||
return strategies_vector
|
||||
|
||||
def _profile_sharding_cost(self, strategy, device_mesh: LogicalDeviceMesh):
|
||||
elapsed_time = self.node_runtime_profiler.runtime_profile(
|
||||
self.layer, {}, {}, strategy, device_mesh)
|
||||
return elapsed_time
|
||||
|
||||
def _model_sharding_cost_from_s_curve(self, strategy,
|
||||
device_mesh: LogicalDeviceMesh):
|
||||
'''
|
||||
[ToDo] preprofile the s_curve
|
||||
'''
|
||||
sharding_cost = 0.0
|
||||
return sharding_cost
|
||||
|
||||
# this method might be overwritten by some Ops
|
||||
def _get_math_time(self, strategy, device_mesh: LogicalDeviceMesh):
|
||||
return 0.0
|
||||
|
||||
# this method might be overwritten by some Ops
|
||||
def _get_memory_time(self, strategy, device_mesh: LogicalDeviceMesh):
|
||||
memory_time = (strategy.inout_memory_footprint /
|
||||
device_mesh.cluster_info.memory_bw * 1e-3 *
|
||||
device_mesh.cluster_info.memory_efficiency)
|
||||
return memory_time
|
||||
|
||||
def _model_sharding_cost_from_alpha_beta(self, strategy,
|
||||
device_mesh: LogicalDeviceMesh):
|
||||
math_time = self._get_math_time(strategy, device_mesh)
|
||||
mem_time = self._get_memory_time(strategy, device_mesh)
|
||||
return max(math_time, mem_time)
|
||||
|
||||
def _get_communication_cost(self, strategy):
|
||||
total_comm_cost = 0.0
|
||||
for op_data_name, comm_spec in strategy.communication_actions.items():
|
||||
comm_cost = comm_spec.get_comm_cost()
|
||||
total_comm_cost = total_comm_cost + comm_cost
|
||||
return total_comm_cost
|
||||
|
||||
def _update_sharding_cost(self, strategies, device_mesh: LogicalDeviceMesh):
|
||||
self._update_memory_cost(strategies)
|
||||
|
||||
if device_mesh.config.sharding_cost_model == CostModel.ALPHA_BETA:
|
||||
for strategy in strategies:
|
||||
strategy.sharding_cost = self._model_sharding_cost_from_alpha_beta(
|
||||
strategy, device_mesh)
|
||||
elif device_mesh.config.sharding_cost_model == CostModel.S_CURVE:
|
||||
for strategy in strategies:
|
||||
strategy.sharding_cost = self._model_sharding_cost_from_s_curve(
|
||||
strategy, device_mesh)
|
||||
elif device_mesh.config.sharding_cost_model == CostModel.PROFILE:
|
||||
for strategy in strategies:
|
||||
strategy.alpha_beta_cost = self._model_sharding_cost_from_alpha_beta(
|
||||
strategy, device_mesh)
|
||||
if self.is_shape_io:
|
||||
strategy.sharding_cost = strategy.alpha_beta_cost
|
||||
else:
|
||||
strategy.sharding_cost = self._profile_sharding_cost(
|
||||
strategy, device_mesh)
|
||||
elif device_mesh.config.sharding_cost_model == CostModel.ZERO:
|
||||
for strategy in strategies:
|
||||
strategy.sharding_cost = 0.0
|
||||
else:
|
||||
assert False, 'unsupported sharding cost model option: {}'.format(
|
||||
device_mesh.config.sharding_cost_model)
|
||||
|
||||
for strategy in strategies:
|
||||
strategy.communication_cost = self._get_communication_cost(strategy)
|
||||
|
||||
def _compute_resharding_cost(self, pre_sharding_spec, cur_sharding_spec,
|
||||
op_data):
|
||||
transform_path, comm_action_sequence, resharding_cost = cur_sharding_spec.device_mesh.shape_consistency_manager.shape_consistency(
|
||||
pre_sharding_spec, cur_sharding_spec)
|
||||
return (transform_path, comm_action_sequence, resharding_cost)
|
||||
|
||||
def _update_resharding_cost(self, strategies):
|
||||
for strategy in strategies:
|
||||
resharding_costs = {}
|
||||
for pre_node, out_index in self.predecessor_nodes_out_index.items():
|
||||
if pre_node is None:
|
||||
continue
|
||||
pre_node_out_data_name = pre_node.get_output(out_index).name
|
||||
pre_node_out_data_lname = pre_node.global_to_local_op_name[
|
||||
pre_node_out_data_name]
|
||||
if pre_node_out_data_name not in self.global_to_local_op_name:
|
||||
print(f"pre_node_out_data_name = {pre_node_out_data_name}")
|
||||
continue
|
||||
cur_node_inp_data_lname = self.global_to_local_op_name[
|
||||
pre_node_out_data_name]
|
||||
cur_sharding_spec = strategy.sharding_specs[
|
||||
cur_node_inp_data_lname]
|
||||
|
||||
pre_node_out_sharding_specs = []
|
||||
for pre_strategy in pre_node.strategies_vector:
|
||||
pre_node_out_sharding_specs.append(
|
||||
pre_strategy.sharding_specs[pre_node_out_data_lname])
|
||||
|
||||
if pre_node not in resharding_costs:
|
||||
resharding_costs[pre_node.node_name] = []
|
||||
for prev_sharding_spec in pre_node_out_sharding_specs:
|
||||
resharding_cost = self._compute_resharding_cost(
|
||||
prev_sharding_spec, cur_sharding_spec,
|
||||
self.op_data[cur_node_inp_data_lname])
|
||||
resharding_costs[pre_node.node_name].append(resharding_cost)
|
||||
strategy.resharding_costs = resharding_costs
|
||||
|
||||
def _enumerate_all_possible_1d_sharding(self, mesh_dim, dim_size):
|
||||
dim_partition_list = []
|
||||
for i in range(dim_size):
|
||||
dim_partition_list.append({i: mesh_dim})
|
||||
return dim_partition_list
|
||||
|
||||
def _enumerate_all_possible_2d_sharding(self, mesh_dim0, mesh_dim1,
|
||||
dim_size):
|
||||
dim_partition_list = []
|
||||
for i in range(dim_size):
|
||||
for j in range(dim_size):
|
||||
if i != j:
|
||||
dim_partition_list.append({i: mesh_dim0, j: mesh_dim1})
|
||||
return dim_partition_list
|
||||
|
||||
def _update_memory_cost(self, strategies):
|
||||
for strategy in strategies:
|
||||
inout_memory_footprint, max_inout_memory_footprint = 0.0, 0.0
|
||||
for spec in strategy.sharding_specs.values():
|
||||
inout_memory_footprint += spec.get_sharded_size_per_device()
|
||||
max_inout_memory_footprint += spec.get_max_sharded_size_per_device(
|
||||
)
|
||||
|
||||
# the communication happens
|
||||
comm_buffer_footprint, max_comm_buffer_footprint = 0.0, 0.0
|
||||
for comm_spec in strategy.communication_actions.values():
|
||||
comm_buffer_footprint += comm_spec.get_mem_cost()
|
||||
max_comm_buffer_footprint += comm_spec.get_max_mem_cost()
|
||||
|
||||
# When the output0 comm action runs, the input buffer should already be released;
# this footprint is used to estimate the memory time rather than the memory usage.
|
||||
strategy.inout_memory_footprint = inout_memory_footprint
|
||||
|
||||
strategy.comm_buff_memory_footprint = comm_buffer_footprint
|
||||
strategy.peak_memory_footprint = max(max_inout_memory_footprint,
|
||||
max_comm_buffer_footprint)
|
||||
|
||||
# The const memory (weight) is recorded in constant layers and should be accumulated
|
||||
strategy.const_memory_footprint = 0.0
|
||||
|
||||
def _generate_bcast_dims(self, batch_dims, out_data_shape):
|
||||
for output in self.outputs:
|
||||
if output.broadcast_across_batch:
|
||||
for bs in batch_dims:
|
||||
if output.shape[bs] == 1 and output.shape[
|
||||
bs] != out_data_shape[bs]:
|
||||
output.attrs["broadcast_dims"].append(bs)
|
||||
|
||||
def _recover_bcast_partition_dict(self, partition_dict, op_data):
|
||||
ret = {}
|
||||
for data_dim, mesh_dim in partition_dict.items():
|
||||
if data_dim not in op_data.attrs[
|
||||
"broadcast_dims"] and data_dim + len(
|
||||
op_data.shape) not in op_data.attrs[
|
||||
"broadcast_dims"] and op_data.shape[data_dim] != 1:
|
||||
ret[data_dim] = mesh_dim
|
||||
return ret
|
||||
@@ -1,60 +0,0 @@
from .node import Node
from .sharding_strategy import StrategiesVector


class Normalization(Node):

    def __init__(self, layer):
        super().__init__(layer)
        layer.to_subclass()
        self.axes = layer.as_trt().axes
        self.weight_bias_dim_base = 0
        layer.to_base_class()

    def _collect_strategies(self, device_mesh):
        dim_partition_list = []
        dim_size = len(self.op_data['input0'].shape)
        dim_partition_list.append({})
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
        strategies_vector = StrategiesVector(self)
        for dim_partition_dict in dim_partition_list:
            shard_reduction_axes = False
            for dim in range(len(self.get_input(0).shape)):
                if (self.axes & (1 << dim)) and dim in dim_partition_dict:
                    shard_reduction_axes = True
                    break
            if shard_reduction_axes:
                continue
            dim_partition_dict_mapping = {
                "input0": dim_partition_dict,
                "output0": dim_partition_dict,
            }
            if self.num_inputs >= 2:
                dim_partition_dict_mapping['input1'] = {}
            if self.num_inputs >= 3:
                dim_partition_dict_mapping['input2'] = {}
            sharding_spec_mapping = self._to_sharding_spec_mapping(
                dim_partition_dict_mapping, device_mesh)
            if 0 == len(sharding_spec_mapping):
                continue
            name = '{} = {} <normalization op> scale {}, bias {}'.format(
                sharding_spec_mapping['output0'].sharding_sequence,
                sharding_spec_mapping['input0'].sharding_sequence,
                sharding_spec_mapping['input1'].sharding_sequence
                if self.num_inputs >= 2 else 'None',
                sharding_spec_mapping['input2'].sharding_sequence
                if self.num_inputs >= 3 else 'None',
            )
            sharding_strategy = self._get_sharding_strategy(
                name=name,
                sharding_spec_mapping=sharding_spec_mapping,
                communication_action_mapping={})
            strategies_vector.append(sharding_strategy)
        return strategies_vector
@@ -1,79 +0,0 @@
from .node import Node
from .sharding_strategy import StrategiesVector


class OuputNode(Node):

    def _update_memory_cost(self, strategies):
        for strategy in strategies:
            if not self.no_memory_footprint:
                strategy.const_memory_footprint = strategy.sharding_specs[
                    'input0'].get_max_sharded_size_per_device()

    def __init__(self, tensor):
        self._layer = None
        self.is_shape_io = False
        self._inputs = []
        self._outputs = []
        self.predecessor_nodes = []
        self.predecessor_nodes_out_index = {}
        self.successor_nodes = []
        self.op_data = {}
        self.global_to_local_op_name = {}
        self.is_replicated = tensor.attrs.get("is_replicated", False)
        self.same_spec_id = tensor.attrs.get("same_spec_id", -1)
        self.no_memory_footprint = tensor.attrs.get("no_memory_footprint",
                                                    False)
        self.building_block_id = -1
        self.cost_level = -1
        self.stage_type = None
        self.in_start_block = None
        self.in_end_block = None
        self.in_slowest_block = None
        input = tensor.copy()
        self._inputs.append(input)
        self.op_data['input0'] = input
        self.global_to_local_op_name[input.name] = 'input0'

        self.sharding_weight = 1.0
        self.resharding_weight = 1.0
        self.pipeline_weight = 0
        self.node_name = tensor.name
        self.node_type = 'output_node'
        self.num_inputs = 0
        self.num_outputs = 1
        self.dtype = tensor.dtype
        self.strategies_vector = []
        self.node_runtime_profiler = None

    def _collect_strategies(self, device_mesh):
        dim_partition_list = []
        dim_size = len(self.op_data['input0'].shape)
        dim_partition_list.append({})
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_2d_sharding([0], [1], dim_size))

        strategies_vector = StrategiesVector(self)
        for dim_partition_dict in dim_partition_list:
            dim_partition_dict_mapping = {'input0': dim_partition_dict}
            sharding_spec_mapping = self._to_sharding_spec_mapping(
                dim_partition_dict_mapping, device_mesh)
            if 0 == len(sharding_spec_mapping):
                continue
            sharding_seq = sharding_spec_mapping['input0'].sharding_sequence
            sharding_strategy = self._get_sharding_strategy(
                name=f'output-op {sharding_seq}',
                sharding_spec_mapping=sharding_spec_mapping,
                communication_action_mapping={})
            strategies_vector.append(sharding_strategy)

        return strategies_vector

    def _profile_sharding_cost(self, strategy, device_mesh):
        return 0.0
@ -1,67 +0,0 @@
import copy
from enum import Enum

from .comm_spec import CommSpec
from .identity_node import Identity
from .sharding_strategy import StrategiesVector


class P2PType(Enum):
    CROSS_DEVICE = 0
    CROSS_HOST = 1


class P2PNode(Identity):

    def __init__(self, layer):
        super().__init__(layer)
        self.p2p_type = layer.attrs["p2p_type"]
        self.is_fake = True

    def _collect_strategies(self, device_mesh):
        # a P2P node has exactly one input
        predecessor = self.predecessor_nodes[0]
        strategies_vector = StrategiesVector(self)
        for idx, strategy in enumerate(predecessor.strategies_vector):
            # current node's local name input0 -> global name xxx
            global_input_name = self.op_data['input0'].name
            # global name xxx -> predecessor node's local output name
            prenode_local_name = predecessor.global_to_local_op_name[
                global_input_name]
            dim_partition_dict = copy.deepcopy(
                strategy.sharding_specs[prenode_local_name].dim_partition_dict)
            in0_partition_dict = dim_partition_dict
            out_partition_dict = copy.deepcopy(dim_partition_dict)
            dim_partition_dict_mapping = {
                "input0": in0_partition_dict,
                "output0": out_partition_dict,
            }
            sharding_spec_mapping = self._to_sharding_spec_mapping(
                dim_partition_dict_mapping, device_mesh)
            if 0 == len(sharding_spec_mapping):
                continue

            logical_process_axis = [
                ['p2p_cross_device']
            ] if self.p2p_type == P2PType.CROSS_DEVICE else [['p2p_cross_host']]
            # get communication action mapping
            communication_action_mapping = {}
            output0_comm_action = CommSpec(
                comm_pattern='peer_to_peer',
                sharding_spec=sharding_spec_mapping['output0'],
                logical_process_axis=logical_process_axis,
            )
            communication_action_mapping['output0'] = output0_comm_action

            name = '{} = <P2P op> {}'.format(
                sharding_spec_mapping['output0'].sharding_sequence,
                sharding_spec_mapping['input0'].sharding_sequence)
            sharding_strategy = self._get_sharding_strategy(
                name=name,
                sharding_spec_mapping=sharding_spec_mapping,
                communication_action_mapping=communication_action_mapping)
            strategies_vector.append(sharding_strategy)
        return strategies_vector

    def _profile_sharding_cost(self, strategy, device_mesh):
        return 0.0
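
For reference, a minimal sketch of the logical-axis selection done in `_collect_strategies` above, isolated from the node classes. The axis names are taken from the code; everything else is illustrative.

from enum import Enum


class P2PType(Enum):
    CROSS_DEVICE = 0
    CROSS_HOST = 1


def pick_logical_axis(p2p_type):
    # Cross-device transfers use the 'p2p_cross_device' logical axis,
    # cross-host transfers use 'p2p_cross_host'.
    return [['p2p_cross_device']] if p2p_type == P2PType.CROSS_DEVICE else [['p2p_cross_host']]


print(pick_logical_axis(P2PType.CROSS_HOST))  # [['p2p_cross_host']]
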
@ -1,40 +0,0 @@
from tensorrt_llm.network import PluginInfo, get_plugin_info

from .node import Node
from .sharding_strategy import StrategiesVector


class PluginNode(Node):

    def __init__(self, layer):
        super().__init__(layer)
        layer.to_subclass()
        self.plugin = layer.as_trt().plugin
        self.plugin_type: str = self.plugin.plugin_type
        self.plugin_info: PluginInfo = get_plugin_info(layer.graph.as_trt(),
                                                       layer.name)
        layer.to_base_class()

    def _collect_strategies(self, device_mesh):
        raise NotImplementedError(
            f"Auto parallel does not support {self.plugin_type} plugin right now."
        )

    def _default_strategy(self, device_mesh):
        strategies_vector = StrategiesVector(self)
        dim_partition_dict_mapping = {}
        for idx in range(self.num_inputs):
            dim_partition_dict_mapping[f'input{idx}'] = {}
        for idx in range(self.num_outputs):
            dim_partition_dict_mapping[f'output{idx}'] = {}
        sharding_spec_mapping = self._to_sharding_spec_mapping(
            dim_partition_dict_mapping, device_mesh)
        if 0 == len(sharding_spec_mapping):
            return strategies_vector
        name = '{}_all_replicate'.format(self.plugin_type)
        sharding_strategy = self._get_sharding_strategy(
            name=name,
            sharding_spec_mapping=sharding_spec_mapping,
            communication_action_mapping={})
        strategies_vector.append(sharding_strategy)
        return strategies_vector
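
A small sketch of the all-replicate mapping that `_default_strategy` builds: an empty dim-partition dict per tensor means every dimension stays replicated on the device mesh. The input/output counts below are made-up example values.

num_inputs, num_outputs = 2, 1
dim_partition_dict_mapping = {f'input{i}': {} for i in range(num_inputs)}
dim_partition_dict_mapping.update({f'output{i}': {} for i in range(num_outputs)})
print(dim_partition_dict_mapping)  # {'input0': {}, 'input1': {}, 'output0': {}}
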
@ -1,27 +0,0 @@
import tensorrt as trt

from tensorrt_llm._utils import trt_dtype_to_str

from ..matmul_node import MatrixMultiply
from ..plugin_node import PluginNode


class GemmPlugin(MatrixMultiply, PluginNode):

    def __init__(self, layer):
        PluginNode.__init__(self, layer)
        batch_dims = [i for i in range(len(self.get_output(0).shape))][:-2]
        self._generate_bcast_dims(batch_dims, self.get_output(0).shape)
        pfc_as_list = self.plugin_info.pfc_as_list
        self.op0_transpose = (pfc_as_list['transa'][0] == 1)
        self.op1_transpose = (pfc_as_list['transb'][0] == 1)
        self.num_out_dims = len(self.get_output(0).shape)
        self.dtype = trt_dtype_to_str(trt.DataType(pfc_as_list['type_id'][0]))

    def _collect_strategies(self, device_mesh):
        strategies_vector = MatrixMultiply._collect_strategies(
            self, device_mesh)
        return strategies_vector

    def _get_math_time(self, strategy, device_mesh):
        return MatrixMultiply._get_math_time(self, strategy, device_mesh)
@ -1,437 +0,0 @@
|
||||
from enum import Enum, auto
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from tensorrt_llm.functional import AttentionMaskType, PositionEmbeddingType
|
||||
from tensorrt_llm.quantization import QuantMode
|
||||
|
||||
from ..plugin_node import PluginNode
|
||||
from ..sharding_strategy import StrategiesVector
|
||||
|
||||
|
||||
# WARNING: Must be kept in sync with IdxEntry in cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h
|
||||
class IdxEntry(Enum):
|
||||
QKV_TENSOR = auto()
|
||||
K_TENSOR = auto()
|
||||
V_TENSOR = auto()
|
||||
CONTEXT_FMHA_CUSTOM_MASK = auto()
|
||||
SEQUENCE_LENGTH = auto()
|
||||
HOST_PAST_KEY_VALUE_LENGTHS = auto()
|
||||
HOST_MAX_ATTENTION_WINDOW = auto()
|
||||
HOST_SINK_TOKEN_LENGTH = auto()
|
||||
CONTEXT_LENGTHS = auto()
|
||||
CACHE_INDIR = auto()
|
||||
REQUEST_TYPES = auto()
|
||||
KV_CACHE_BLOCK_OFFSETS = auto()
|
||||
HOST_KV_CACHE_BLOCK_OFFSETS = auto()
|
||||
HOST_KV_CACHE_POOL_POINTERS = auto()
|
||||
HOST_KV_CACHE_POOL_MAPPING = auto()
|
||||
PAST_KEY_VALUE = auto()
|
||||
KV_CACHE_QUANTIZATION_SCALE = auto()
|
||||
KV_CACHE_DEQUANTIZATION_SCALE = auto()
|
||||
ATTENTION_OUTPUT_QUANTIZATION_SCALE = auto()
|
||||
ATTENTION_OUTPUT_SF_SCALE = auto()
|
||||
ROTARY_INV_FREQ = auto()
|
||||
ROTARY_COS_SIN = auto()
|
||||
ALIBI_SLOPES = auto()
|
||||
RELATIVE_ATTENTION_BIAS = auto()
|
||||
CROSS_KV = auto()
|
||||
CROSS_KV_LENGTH = auto()
|
||||
ENCODER_INPUT_LENGTH = auto()
|
||||
HOST_CONTEXT_LENGTH = auto()
|
||||
QKV_BIAS_TENSOR = auto()
|
||||
SPEC_DECODING_GENERATION_LENGTHS = auto()
|
||||
SPEC_DECODING_PACKED_MASK = auto()
|
||||
SPEC_DECODING_POSITION_OFFSETS = auto()
|
||||
MROPE_ROTARY_COS_SIN = auto()
|
||||
MROPE_POSITION_DELTAS = auto()
|
||||
HOST_RUNTIME_PERF_KNOBS = auto()
|
||||
HOST_CONTEXT_PROGRESS = auto()
|
||||
MLA_FUSED_Q_PROJ_TENSOR = auto()
|
||||
MLA_Q_B_PROJ_TENSOR = auto()
|
||||
MLA_KV_B_PROJ_TENSOR = auto()
|
||||
LOGN_SCALING = auto()
|
||||
|
||||
|
||||
class IdxEntryParser:
|
||||
|
||||
def __init__(self, plugin_info):
|
||||
self.num_kv_heads = plugin_info.pfc_as_list['num_kv_heads'][0]
|
||||
self.unfuse_qkv_gemm = bool(
|
||||
plugin_info.pfc_as_list['unfuse_qkv_gemm'][0])
|
||||
self.use_fp8_context_fmha = bool(
|
||||
plugin_info.pfc_as_list['use_fp8_context_fmha'][0])
|
||||
self.fuse_fp4_quant = bool(plugin_info.pfc_as_list['fuse_fp4_quant'][0])
|
||||
self.mask_type = AttentionMaskType(
|
||||
plugin_info.pfc_as_list['mask_type'][0])
|
||||
self.use_cache = bool(plugin_info.pfc_as_list['use_cache'][0])
|
||||
self.paged_kv_cache = bool(plugin_info.pfc_as_list['paged_kv_cache'][0])
|
||||
self.do_cross_attention = bool(
|
||||
plugin_info.pfc_as_list['do_cross_attention'][0])
|
||||
self.remove_input_padding = bool(
|
||||
plugin_info.pfc_as_list['remove_input_padding'][0])
|
||||
self.qkv_bias_enabled = bool(
|
||||
plugin_info.pfc_as_list['qkv_bias_enabled'][0])
|
||||
self.kv_cache_quant_mode = QuantMode(
|
||||
plugin_info.pfc_as_list['kv_cache_quant_mode'][0])
|
||||
self.position_embedding_type = PositionEmbeddingType(
|
||||
plugin_info.pfc_as_list['position_embedding_type'][0])
|
||||
self.is_spec_decoding_enabled = bool(
|
||||
plugin_info.pfc_as_list['is_spec_decoding_enabled'][0])
|
||||
self.is_mla_enabled = bool(plugin_info.pfc_as_list['is_mla_enabled'][0])
|
||||
self.use_logn_scaling = bool(
|
||||
plugin_info.pfc_as_list['use_logn_scaling'][0])
|
||||
self.init_entry_to_index()
|
||||
|
||||
    # WARNING: Must be kept in sync with GPTAttentionPlugin::isEntryUsed in cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp
|
||||
def is_entry_used(self, entry: IdxEntry) -> bool:
|
||||
if entry == IdxEntry.QKV_TENSOR:
|
||||
return True
|
||||
elif entry == IdxEntry.K_TENSOR:
|
||||
return self.unfuse_qkv_gemm
|
||||
elif entry == IdxEntry.V_TENSOR:
|
||||
return self.unfuse_qkv_gemm
|
||||
elif entry == IdxEntry.CONTEXT_FMHA_CUSTOM_MASK:
|
||||
return self.mask_type == AttentionMaskType.custom_mask
|
||||
elif entry == IdxEntry.SEQUENCE_LENGTH:
|
||||
return self.use_cache
|
||||
elif entry == IdxEntry.HOST_PAST_KEY_VALUE_LENGTHS:
|
||||
return self.use_cache
|
||||
elif entry == IdxEntry.HOST_MAX_ATTENTION_WINDOW:
|
||||
return True
|
||||
elif entry == IdxEntry.HOST_SINK_TOKEN_LENGTH:
|
||||
return True
|
||||
elif entry == IdxEntry.CONTEXT_LENGTHS:
|
||||
return True
|
||||
elif entry == IdxEntry.CACHE_INDIR:
|
||||
return self.use_cache
|
||||
elif entry == IdxEntry.REQUEST_TYPES:
|
||||
return True
|
||||
elif entry == IdxEntry.KV_CACHE_BLOCK_OFFSETS:
|
||||
return self.use_cache and self.paged_kv_cache
|
||||
elif entry == IdxEntry.HOST_KV_CACHE_BLOCK_OFFSETS:
|
||||
return self.use_cache and self.paged_kv_cache
|
||||
elif entry == IdxEntry.HOST_KV_CACHE_POOL_POINTERS:
|
||||
return self.use_cache and self.paged_kv_cache
|
||||
elif entry == IdxEntry.HOST_KV_CACHE_POOL_MAPPING:
|
||||
return self.use_cache and self.paged_kv_cache
|
||||
elif entry == IdxEntry.PAST_KEY_VALUE:
|
||||
return self.use_cache and not self.paged_kv_cache
|
||||
elif entry == IdxEntry.KV_CACHE_QUANTIZATION_SCALE:
|
||||
return self.use_cache and self.kv_cache_quant_mode.has_kv_cache_quant(
|
||||
)
|
||||
elif entry == IdxEntry.KV_CACHE_DEQUANTIZATION_SCALE:
|
||||
return self.use_cache and self.kv_cache_quant_mode.has_kv_cache_quant(
|
||||
)
|
||||
elif entry == IdxEntry.ATTENTION_OUTPUT_QUANTIZATION_SCALE:
|
||||
return self.use_fp8_context_fmha and self.kv_cache_quant_mode.has_fp8_qdp(
|
||||
)
|
||||
elif entry == IdxEntry.ATTENTION_OUTPUT_SF_SCALE:
|
||||
return self.fuse_fp4_quant
|
||||
elif entry == IdxEntry.ROTARY_INV_FREQ:
|
||||
return self.position_embedding_type.is_rope()
|
||||
elif entry == IdxEntry.ROTARY_COS_SIN:
|
||||
return self.position_embedding_type.is_rope()
|
||||
elif entry == IdxEntry.ALIBI_SLOPES:
|
||||
return self.position_embedding_type.is_alibi()
|
||||
elif entry == IdxEntry.RELATIVE_ATTENTION_BIAS:
|
||||
return self.position_embedding_type == PositionEmbeddingType.relative
|
||||
elif entry == IdxEntry.CROSS_KV:
|
||||
return self.do_cross_attention
|
||||
elif entry == IdxEntry.CROSS_KV_LENGTH:
|
||||
return self.do_cross_attention
|
||||
elif entry == IdxEntry.ENCODER_INPUT_LENGTH:
|
||||
return self.do_cross_attention
|
||||
elif entry == IdxEntry.HOST_CONTEXT_LENGTH:
|
||||
return self.remove_input_padding
|
||||
elif entry == IdxEntry.QKV_BIAS_TENSOR:
|
||||
return self.qkv_bias_enabled
|
||||
elif entry == IdxEntry.SPEC_DECODING_PACKED_MASK:
|
||||
return self.is_spec_decoding_enabled
|
||||
elif entry == IdxEntry.SPEC_DECODING_POSITION_OFFSETS:
|
||||
return self.is_spec_decoding_enabled
|
||||
elif entry == IdxEntry.SPEC_DECODING_GENERATION_LENGTHS:
|
||||
return self.is_spec_decoding_enabled
|
||||
elif entry == IdxEntry.MROPE_ROTARY_COS_SIN:
|
||||
return self.position_embedding_type.is_mrope()
|
||||
elif entry == IdxEntry.MROPE_POSITION_DELTAS:
|
||||
return self.position_embedding_type.is_mrope()
|
||||
elif entry == IdxEntry.HOST_RUNTIME_PERF_KNOBS:
|
||||
return True
|
||||
elif entry == IdxEntry.HOST_CONTEXT_PROGRESS:
|
||||
return True
|
||||
elif entry == IdxEntry.MLA_FUSED_Q_PROJ_TENSOR:
|
||||
return self.is_mla_enabled
|
||||
elif entry == IdxEntry.MLA_Q_B_PROJ_TENSOR:
|
||||
return self.is_mla_enabled
|
||||
elif entry == IdxEntry.MLA_KV_B_PROJ_TENSOR:
|
||||
return self.is_mla_enabled
|
||||
elif entry == IdxEntry.LOGN_SCALING:
|
||||
return self.use_logn_scaling
|
||||
else:
|
||||
return False
|
||||
|
||||
def init_entry_to_index(self):
|
||||
self.entry_to_index = {}
|
||||
index = 0
|
||||
for entry in IdxEntry:
|
||||
if self.is_entry_used(entry):
|
||||
self.entry_to_index[entry] = index
|
||||
index += 1
|
||||
|
||||
def get_index(self, entry: IdxEntry) -> int:
|
||||
if entry not in self.entry_to_index:
|
||||
raise Exception(
|
||||
f"Entry {entry} is not existed in gpt attention plugin layer {self.layer.name}"
|
||||
)
|
||||
return self.entry_to_index[entry]
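
A stand-alone sketch of the compact-index pattern used by `init_entry_to_index`/`get_index` above: only the entries the plugin build actually uses receive a slot, so input tensor indices stay dense. The `Entry` enum and the `used` predicate below are simplified stand-ins, not the real `IdxEntry`/`is_entry_used`.

from enum import Enum, auto


class Entry(Enum):
    QKV = auto()
    K = auto()
    V = auto()
    SEQUENCE_LENGTH = auto()


def build_index(used) -> dict:
    # Assign consecutive indices only to the entries that are actually used.
    index, mapping = 0, {}
    for entry in Entry:
        if used(entry):
            mapping[entry] = index
            index += 1
    return mapping


# With fused QKV, K and V are skipped and SEQUENCE_LENGTH becomes input 1.
print(build_index(lambda e: e not in (Entry.K, Entry.V)))
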
|
||||
|
||||
|
||||
def get_partition(device_dim, device_ids):
|
||||
if device_dim == [0]:
|
||||
partition = device_ids.shape[0]
|
||||
elif device_dim == [1]:
|
||||
partition = device_ids.shape[1]
|
||||
else:
|
||||
assert device_dim == [0, 1] or device_dim == [1, 0]
|
||||
partition = device_ids.size
|
||||
return partition
|
||||
|
||||
|
||||
class GPTAttentionPlugin(PluginNode):
|
||||
|
||||
def __init__(self, layer):
|
||||
super().__init__(layer)
|
||||
self.parser = IdxEntryParser(self.plugin_info)
|
||||
assert self.num_inputs == len(
|
||||
self.parser.entry_to_index
|
||||
), f'the number of plugin inputs ({self.num_inputs}) is invalid'
|
||||
assert self.num_outputs == (
|
||||
2 if self.parser.is_entry_used(IdxEntry.PAST_KEY_VALUE) else 1
|
||||
), f'the number of plugin outputs ({self.num_outputs}) has been changed'
|
||||
|
||||
def _tp_strategy(self, device_mesh):
|
||||
strategies_vector = StrategiesVector(self)
|
||||
head_dim = 1 if self.parser.remove_input_padding else 2
|
||||
# TODO: allow mesh_dim = [0] or [1]
|
||||
# for mesh_dim in ([0], [1], [0, 1]):
|
||||
for mesh_dim in ([0, 1], ):
|
||||
if self.parser.num_kv_heads != 1:
|
||||
# MHA or GQA
|
||||
# TODO: allow to duplicate kv when #kv_head < #partition
|
||||
q_pdict = {
|
||||
head_dim: mesh_dim
|
||||
} # split in heads/hidden dimension
|
||||
k_pdict = {
|
||||
head_dim: mesh_dim
|
||||
} # split in heads/hidden dimension
|
||||
v_pdict = {
|
||||
head_dim: mesh_dim
|
||||
} # split in heads/hidden dimension
|
||||
pastkv_pdict = {2: mesh_dim} # split in heads dimension
|
||||
present_kv_pdict = {2: mesh_dim} # split in heads dimension
|
||||
else:
|
||||
# MQA
|
||||
q_pdict = {
|
||||
head_dim: mesh_dim
|
||||
} # split in heads/hidden dimension
|
||||
k_pdict = {} # RR
|
||||
v_pdict = {} # RR
|
||||
pastkv_pdict = {} # RR
|
||||
present_kv_pdict = {} # RR
|
||||
|
||||
out0_pdict = {head_dim: mesh_dim}
|
||||
|
||||
dim_partition_dict_mapping = {
|
||||
f'input{self.parser.get_index(IdxEntry.QKV_TENSOR)}': q_pdict,
|
||||
f'input{self.parser.get_index(IdxEntry.K_TENSOR)}': k_pdict,
|
||||
f'input{self.parser.get_index(IdxEntry.V_TENSOR)}': v_pdict,
|
||||
'output0': out0_pdict,
|
||||
}
|
||||
if self.parser.is_entry_used(IdxEntry.PAST_KEY_VALUE):
|
||||
dim_partition_dict_mapping[
|
||||
f'input{self.parser.get_index(IdxEntry.PAST_KEY_VALUE)}'] = pastkv_pdict
|
||||
dim_partition_dict_mapping['output1'] = present_kv_pdict
|
||||
for i in range(self.num_inputs):
|
||||
if f'input{i}' not in dim_partition_dict_mapping:
|
||||
dim_partition_dict_mapping[f'input{i}'] = {}
|
||||
for i in range(self.num_outputs):
|
||||
if f'output{i}' not in dim_partition_dict_mapping:
|
||||
dim_partition_dict_mapping[f'output{i}'] = {}
|
||||
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if 0 == len(sharding_spec_mapping):
|
||||
continue
|
||||
name = 'gptAttentionPlugin_tp_strategy'
|
||||
sharding_strategy = self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
strategies_vector.append(sharding_strategy)
|
||||
return strategies_vector
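
A sketch of the head-partition arithmetic that the TP strategy above feeds into `_profile_sharding_cost`; the mesh shape and head counts below are made-up example values.

import numpy as np

mesh_shape = (2, 4)
partition = int(np.prod(mesh_shape))   # mesh_dim=[0, 1] splits over the whole mesh
num_heads = np.array([32])
num_kv_heads = np.array([8])

local_heads = np.maximum(num_heads // partition, 1)        # 4 query heads per device
local_kv_heads = np.maximum(num_kv_heads // partition, 1)  # GQA: keep at least 1 kv head
print(local_heads, local_kv_heads)
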
|
||||
|
||||
def _dp_strategy(self, device_mesh):
|
||||
strategies_vector = StrategiesVector(self)
|
||||
for mesh_dim in ([0], [1], [0, 1]):
|
||||
dim_partition_dict_mapping = {}
|
||||
for i in range(self.num_inputs):
|
||||
dim_partition_dict_mapping[f'input{i}'] = {0: mesh_dim}
|
||||
for i in range(self.num_outputs):
|
||||
dim_partition_dict_mapping[f'output{i}'] = {0: mesh_dim}
|
||||
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if 0 == len(sharding_spec_mapping):
|
||||
continue
|
||||
name = 'gptAttentionPlugin_dp_strategy'
|
||||
sharding_strategy = self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
strategies_vector.append(sharding_strategy)
|
||||
return strategies_vector
|
||||
|
||||
def _collect_strategies(self, device_mesh):
|
||||
if device_mesh.size == 1:
|
||||
default_strategies = self._default_strategy(device_mesh)
|
||||
else:
|
||||
            # Avoid using the all-replicate strategy for mesh size > 1,
            # since the CPP runtime does not support it for the gpt attention plugin.
|
||||
default_strategies = StrategiesVector(self)
|
||||
for idx, strategy in enumerate(default_strategies):
|
||||
strategy.name = 'gptAttentionPlugin_' + strategy.name + f'{idx}'
|
||||
if self.parser.unfuse_qkv_gemm:
|
||||
tp_strategies = self._tp_strategy(device_mesh)
|
||||
default_strategies.extend(tp_strategies)
|
||||
        # If we don't split the batch dim, the default strategies should be used;
        # if we do split the batch dim, the dp_strategies should be used.
        # We can use this information to distinguish the two kinds of strategy.
|
||||
if not self.parser.remove_input_padding:
|
||||
dp_strategies = self._dp_strategy(device_mesh)
|
||||
default_strategies.extend(dp_strategies)
|
||||
return default_strategies
|
||||
|
||||
@staticmethod
|
||||
def parameter_generator(sharding_specs, plugin_info):
|
||||
|
||||
def get_shape(entry):
|
||||
return sharding_specs[
|
||||
f'input{parser.get_index(entry)}'].get_sharded_shape_per_device(
|
||||
)
|
||||
|
||||
parser = IdxEntryParser(plugin_info)
|
||||
updated_input_values = {}
|
||||
batch_size = get_shape(IdxEntry.CONTEXT_LENGTHS)[0]
|
||||
if parser.use_cache:
|
||||
beams_width = get_shape(IdxEntry.CACHE_INDIR)[1]
|
||||
max_seq_length = get_shape(IdxEntry.CACHE_INDIR)[2]
|
||||
elif not parser.remove_input_padding:
|
||||
max_seq_length = get_shape(IdxEntry.QKV_BIAS_TENSOR)[1]
|
||||
else:
|
||||
max_seq_length = 1
|
||||
host_request_types = torch.full(
|
||||
(batch_size, ),
|
||||
1,
|
||||
dtype=torch.int32,
|
||||
device='cpu',
|
||||
)
|
||||
updated_input_values[parser.get_index(
|
||||
IdxEntry.REQUEST_TYPES)] = host_request_types
|
||||
context_lengths = torch.full(
|
||||
(batch_size, ),
|
||||
max_seq_length - 1,
|
||||
dtype=torch.int32,
|
||||
device=torch.cuda.current_device(),
|
||||
)
|
||||
updated_input_values[parser.get_index(
|
||||
IdxEntry.CONTEXT_LENGTHS)] = context_lengths
|
||||
host_max_attention_window_sizes = torch.tensor(
|
||||
[max_seq_length],
|
||||
dtype=torch.int32,
|
||||
device='cpu',
|
||||
)
|
||||
updated_input_values[parser.get_index(
|
||||
IdxEntry.HOST_MAX_ATTENTION_WINDOW
|
||||
)] = host_max_attention_window_sizes
|
||||
host_sink_token_length = torch.tensor(
|
||||
[0],
|
||||
dtype=torch.int32,
|
||||
device='cpu',
|
||||
)
|
||||
updated_input_values[parser.get_index(
|
||||
IdxEntry.HOST_SINK_TOKEN_LENGTH)] = host_sink_token_length
|
||||
if parser.use_cache:
|
||||
sequence_length = torch.full((batch_size, ),
|
||||
max_seq_length,
|
||||
dtype=torch.int32,
|
||||
device=torch.cuda.current_device())
|
||||
updated_input_values[parser.get_index(
|
||||
IdxEntry.SEQUENCE_LENGTH)] = sequence_length
|
||||
host_past_key_value_length = torch.full((batch_size, ),
|
||||
max_seq_length - 1,
|
||||
dtype=torch.int32,
|
||||
device='cpu')
|
||||
updated_input_values[parser.get_index(
|
||||
IdxEntry.HOST_PAST_KEY_VALUE_LENGTHS
|
||||
)] = host_past_key_value_length
|
||||
cache_indirections = torch.full(
|
||||
(batch_size, beams_width, max_seq_length),
|
||||
0,
|
||||
dtype=torch.int32,
|
||||
device=torch.cuda.current_device())
|
||||
updated_input_values[parser.get_index(
|
||||
IdxEntry.CACHE_INDIR)] = cache_indirections
|
||||
if parser.remove_input_padding:
|
||||
host_context_lengths = torch.full(get_shape(
|
||||
IdxEntry.HOST_CONTEXT_LENGTH),
|
||||
max_seq_length - 1,
|
||||
dtype=torch.int32,
|
||||
device='cpu')
|
||||
updated_input_values[parser.get_index(
|
||||
IdxEntry.HOST_CONTEXT_LENGTH)] = host_context_lengths
|
||||
return updated_input_values
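
For illustration, the kind of host-side dummy tensors that `parameter_generator` above fabricates for runtime profiling, on made-up sizes; only a few of the real entries are shown here.

import torch

batch_size, max_seq_length = 4, 128
host_request_types = torch.full((batch_size, ), 1, dtype=torch.int32)               # generation phase
context_lengths = torch.full((batch_size, ), max_seq_length - 1, dtype=torch.int32)
host_max_attention_window = torch.tensor([max_seq_length], dtype=torch.int32)
print(host_request_types.shape, context_lengths[0].item(), host_max_attention_window)
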
|
||||
|
||||
def _profile_sharding_cost(self, strategy, device_mesh):
|
||||
sharding_spec = strategy.sharding_specs[
|
||||
f"input{self.parser.get_index(IdxEntry.QKV_TENSOR)}"]
|
||||
shard_dims = sharding_spec.dim_partition_dict
|
||||
device_ids = device_mesh.phy_ids
|
||||
if 2 in shard_dims:
|
||||
device_dim = shard_dims[2]
|
||||
partition = get_partition(device_dim, device_ids)
|
||||
else:
|
||||
partition = 1
|
||||
if self.parser.is_entry_used(IdxEntry.K_TENSOR):
|
||||
kv_sharding_spec = strategy.sharding_specs[
|
||||
f"input{self.parser.get_index(IdxEntry.K_TENSOR)}"]
|
||||
kv_shard_dims = kv_sharding_spec.dim_partition_dict
|
||||
if 2 in kv_shard_dims:
|
||||
kv_device_dim = kv_shard_dims[2]
|
||||
kv_partition = get_partition(kv_device_dim, device_ids)
|
||||
else:
|
||||
kv_partition = 1
|
||||
else:
|
||||
kv_partition = 1
|
||||
num_heads = self.plugin_info.pfc_as_ndarray["num_heads"].copy()
|
||||
num_kv_heads = self.plugin_info.pfc_as_ndarray["num_kv_heads"].copy()
|
||||
tp_size = self.plugin_info.pfc_as_ndarray["tp_size"].copy()
|
||||
tp_rank = self.plugin_info.pfc_as_ndarray["tp_rank"].copy()
|
||||
num_kv_heads = np.maximum(num_kv_heads // kv_partition, 1)
|
||||
num_heads = np.maximum(num_heads // partition, 1)
|
||||
tp_size[0] = partition
|
||||
tp_rank[0] = 0
|
||||
|
||||
updated_layer_attrs = {
|
||||
'tp_size': tp_size,
|
||||
'tp_rank': tp_rank,
|
||||
'num_heads': num_heads,
|
||||
'num_kv_heads': num_kv_heads
|
||||
}
|
||||
updated_input_values = self.parameter_generator(strategy.sharding_specs,
|
||||
self.plugin_info)
|
||||
elapsed_time = self.node_runtime_profiler.runtime_profile(
|
||||
self.layer, updated_layer_attrs, updated_input_values, strategy,
|
||||
device_mesh)
|
||||
return elapsed_time
|
||||
@ -1,11 +0,0 @@
from ..identity_node import Identity
from ..plugin_node import PluginNode


class IdentityPlugin(Identity, PluginNode):

    def __init__(self, layer):
        PluginNode.__init__(self, layer)

    def _collect_strategies(self, device_mesh):
        return Identity._collect_strategies(self, device_mesh)
@ -1,19 +0,0 @@
import tensorrt as trt

from ..gather_node import Gather
from ..plugin_node import PluginNode


class LookupPlugin(Gather, PluginNode):

    def __init__(self, layer):
        PluginNode.__init__(self, layer)
        self.mode = trt.GatherMode.DEFAULT
        self.axis = 0
        self.num_elementwise_dims = 0
        self.input_id = 1
        self.indice_id = 0
        self.support_vocab_tp = True

    def _collect_strategies(self, device_mesh):
        return Gather._collect_strategies(self, device_mesh)
@ -1,28 +0,0 @@
from ..normalization_node import Normalization
from ..plugin_node import PluginNode


class LayernormPlugin(Normalization, PluginNode):

    def __init__(self, layer):
        PluginNode.__init__(self, layer)
        # this only holds for LLM models, because layer norm only acts on the hidden dim
        hidden_dim = len(self.op_data['input0'].shape) - 1
        self.axes = 1 << hidden_dim
        self.weight_bias_dim_base = hidden_dim

    def _collect_strategies(self, device_mesh):
        return Normalization._collect_strategies(self, device_mesh)


class RMSnormPlugin(Normalization, PluginNode):

    def __init__(self, layer):
        PluginNode.__init__(self, layer)
        # this only holds for LLM models, because RMS norm only acts on the hidden dim
        hidden_dim = len(self.op_data['input0'].shape) - 1
        self.axes = 1 << hidden_dim
        self.weight_bias_dim_base = hidden_dim

    def _collect_strategies(self, device_mesh):
        return Normalization._collect_strategies(self, device_mesh)
@ -1,73 +0,0 @@
from tensorrt_llm._utils import trt_axes_to_dim

from .node import Node
from .sharding_strategy import StrategiesVector


class Reduce(Node):

    def __init__(self, layer):
        super().__init__(layer)
        layer.to_subclass()
        self.reduce_dims = trt_axes_to_dim(layer.as_trt().axes)
        self.sum_mapping_dict = {}
        num_input_dims = len(self.get_input(0).shape)
        if layer.as_trt().keep_dims:
            for i in range(num_input_dims):
                self.sum_mapping_dict[i] = i
        else:
            output_index = 0
            for i in range(num_input_dims):
                if i not in self.reduce_dims:
                    self.sum_mapping_dict[i] = output_index
                    output_index += 1
            assert output_index == len(self.get_output(0).shape)
        layer.to_base_class()

    def _collect_strategies(self, device_mesh):
        dim_partition_list = []
        dim_size = len(self.op_data['input0'].shape)
        dim_partition_list.append({})
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
        strategies_vector = StrategiesVector(self)
        for dim_partition_dict in dim_partition_list:
            recover_dims = []
            out_partition_dict = {}
            for dim in dim_partition_dict.keys():
                if dim in self.reduce_dims:
                    recover_dims.append(dim)
                elif dim in self.sum_mapping_dict:
                    out_partition_dict[
                        self.sum_mapping_dict[dim]] = dim_partition_dict[dim]
                else:
                    assert 0, f'dim {dim} is not in reduce_dims or sum_mapping_dict'

            for dim in recover_dims:
                dim_partition_dict.pop(dim)

            in0_partition_dict = dim_partition_dict
            dim_partition_dict_mapping = {
                "input0": in0_partition_dict,
                "output0": out_partition_dict,
            }
            sharding_spec_mapping = self._to_sharding_spec_mapping(
                dim_partition_dict_mapping, device_mesh)
            if 0 == len(sharding_spec_mapping):
                continue
            name = '{} = <reduce along dim {}> {}'.format(
                sharding_spec_mapping['output0'].sharding_sequence,
                self.reduce_dims,
                sharding_spec_mapping['input0'].sharding_sequence)
            sharding_strategy = self._get_sharding_strategy(
                name=name,
                sharding_spec_mapping=sharding_spec_mapping,
                communication_action_mapping={})
            strategies_vector.append(sharding_strategy)
        return strategies_vector
@ -1,56 +0,0 @@
from .node import Node
from .sharding_strategy import StrategiesVector


class Select(Node):

    def __init__(self, layer):
        super().__init__(layer)
        batch_dims = [i for i in range(len(self.get_output(0).shape))]
        self._generate_bcast_dims(batch_dims, self.get_output(0).shape)

    def _collect_strategies(self, device_mesh):
        dim_partition_list = []
        dim_size = len(self.op_data['output0'].shape)
        dim_partition_list.append({})
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
        dim_partition_list.extend(
            self._enumerate_all_possible_2d_sharding([0], [1], dim_size))

        strategies_vector = StrategiesVector(self)
        for dim_partition_dict in dim_partition_list:
            # the three inputs are condition, true tensor and false tensor
            in0_partition_dict = self._recover_bcast_partition_dict(
                dim_partition_dict, self.op_data['input0'])
            in1_partition_dict = self._recover_bcast_partition_dict(
                dim_partition_dict, self.op_data['input1'])
            in2_partition_dict = self._recover_bcast_partition_dict(
                dim_partition_dict, self.op_data['input2'])
            out_partition_dict = dim_partition_dict
            dim_partition_dict_mapping = {
                "input0": in0_partition_dict,
                "input1": in1_partition_dict,
                "input2": in2_partition_dict,
                "output0": out_partition_dict,
            }
            sharding_spec_mapping = self._to_sharding_spec_mapping(
                dim_partition_dict_mapping, device_mesh)
            if 0 == len(sharding_spec_mapping):
                continue
            name = '{} = <select op {}> {} {}'.format(
                sharding_spec_mapping['output0'].sharding_sequence,
                sharding_spec_mapping['input0'].sharding_sequence,
                sharding_spec_mapping['input1'].sharding_sequence,
                sharding_spec_mapping['input2'].sharding_sequence)

            sharding_strategy = self._get_sharding_strategy(
                name=name,
                sharding_spec_mapping=sharding_spec_mapping,
                communication_action_mapping={})
            strategies_vector.append(sharding_strategy)
        return strategies_vector
@ -1,832 +0,0 @@
|
||||
import math
|
||||
import operator
|
||||
from functools import reduce
|
||||
from typing import List, Tuple
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from .comm_spec import CommSpec
|
||||
from .sharding_spec import ShardingSpec
|
||||
|
||||
|
||||
class ShapeConsistencyManager(object):
|
||||
|
||||
def __init__(self):
|
||||
self.forward_only = True
|
||||
self.cached_spec_pairs_transform_path = {}
|
||||
self.cache_hit = 0
|
||||
self.cache_miss = 0
|
||||
|
||||
def all_gather_simulator(self, target_pair):
|
||||
_, shard_list = target_pair
|
||||
new_shard_list = []
|
||||
return new_shard_list
|
||||
|
||||
def all_to_all_simulator(self, f_target_pair, b_target_pair):
|
||||
        '''
        Simulate the all-to-all operation: analyze the communication cost
        and simulate the influence on the DimSpec.

        We BAN all representations whose shard_list is in decreasing order,
        such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed.
        If the back shard_list is non-empty, we extend the front shard_list with it, e.g.:
            all-to-all(S0, S1) -> [S01, R]
            all-to-all(R, S1) -> [S1, R]
        Otherwise, we move the front shard_list to the back, e.g.:
            all-to-all(S0, R) -> [R, S0]

        Argument:
            f_target_pair(Tuple[int, List[int]]): the first element is the dimension of the tensor to be sharded,
                and the second element describes which logical axes shard that dimension.
            b_target_pair(Tuple[int, List[int]]): same as f_target_pair, for the second dimension of the exchange.
        '''
|
||||
_, f_shard_list = f_target_pair
|
||||
_, b_shard_list = b_target_pair
|
||||
if not len(b_shard_list):
|
||||
b_shard_list.extend(f_shard_list)
|
||||
f_shard_list = []
|
||||
else:
|
||||
f_shard_list.extend(b_shard_list)
|
||||
b_shard_list = []
|
||||
|
||||
return f_shard_list, b_shard_list
|
||||
|
||||
def shard_simulator(self, target_pair, legal_sharding_dims):
|
||||
        '''
        Simulate the shard operation: analyze the communication cost (always ZERO)
        and simulate the influence on the DimSpec.

        We don't allow a non-contiguous layout, so shard(S0) -> S02 is NOT allowed.
        In addition, we BAN all representations whose shard_list is in decreasing order,
        such as S10, so shard(S0) -> S10 is NOT allowed.
        Therefore, for an R dimension, we can append any legal sharding dim to it, e.g.:
            shard(R) -> S0
        For an S dimension, the shard_list after sharding must remain in increasing order, e.g.:
            shard(S0) -> S01

        Argument:
            target_pair(Tuple[int, List[int]]): the first element is the dimension of the tensor to be sharded,
                and the second element describes which logical axes shard that dimension.
        '''
|
||||
_, shard_list = target_pair
|
||||
shard_list_list, logical_process_axis = [], []
|
||||
for dim in legal_sharding_dims:
|
||||
if len(shard_list) != 0 and dim <= shard_list[-1]:
|
||||
continue
|
||||
new_shard_list = shard_list + [dim]
|
||||
shard_list_list.append(new_shard_list)
|
||||
logical_process_axis.append([dim])
|
||||
|
||||
# we support sorted 2D mesh here
|
||||
if len(legal_sharding_dims) == 2 and len(shard_list) == 0:
|
||||
shard_list_list.append(legal_sharding_dims)
|
||||
logical_process_axis.append(legal_sharding_dims)
|
||||
return shard_list_list, logical_process_axis
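
A stand-alone copy of the enumeration performed by `shard_simulator` above so it can be run in isolation; the logic mirrors the method, and the printed examples are illustrative.

def shard_simulator(target_pair, legal_sharding_dims):
    _, shard_list = target_pair
    shard_list_list, logical_process_axis = [], []
    for dim in legal_sharding_dims:
        if shard_list and dim <= shard_list[-1]:
            continue  # keep shard lists in increasing order (no S10)
        shard_list_list.append(shard_list + [dim])
        logical_process_axis.append([dim])
    # sorted 2D mesh: a replicated dimension may also take both axes at once
    if len(legal_sharding_dims) == 2 and not shard_list:
        shard_list_list.append(legal_sharding_dims)
        logical_process_axis.append(legal_sharding_dims)
    return shard_list_list, logical_process_axis


print(shard_simulator((0, []), [0, 1]))   # ([[0], [1], [0, 1]], [[0], [1], [0, 1]])
print(shard_simulator((0, [0]), [0, 1]))  # ([[0, 1]], [[1]])
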
|
||||
|
||||
def mix_gather_simulator(self, f_target_pair, b_target_pair):
|
||||
'''
|
||||
Assume index of f and b target pairs are 'f' and 'b'
|
||||
S0S1 => Input: (f, [0]), (b, [1]) Output: [f, b], [[0], [1]]
|
||||
S1S0 => Input: (f, [1]), (b, [0]) Output: [f, b], [[1], [0]]
|
||||
S01R => Input: (f, [0, 1]), (b, []) Output: [f], [[0, 1]]
|
||||
RS01 => Input: (f, []), (b, [0, 1]) Output: [b], [[0, 1]]
|
||||
'''
|
||||
if f_target_pair[1] and b_target_pair[1]:
|
||||
return [f_target_pair[0],
|
||||
b_target_pair[0]], [f_target_pair[1], b_target_pair[1]]
|
||||
if f_target_pair[1]:
|
||||
return [f_target_pair[0]], [f_target_pair[1]]
|
||||
if b_target_pair[1]:
|
||||
return [b_target_pair[0]], [b_target_pair[1]]
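
The four cases listed in the `mix_gather_simulator` docstring above, written out as plain data for quick reference; the tuple layout below is an assumption made only for this illustration.

# (tensor_dim, mesh_axes) input pairs -> (gather_dims, logical_process_axes)
cases = {
    'S0S1': (((0, [0]), (1, [1])), ([0, 1], [[0], [1]])),
    'S1S0': (((0, [1]), (1, [0])), ([0, 1], [[1], [0]])),
    'S01R': (((0, [0, 1]), (1, [])), ([0], [[0, 1]])),
    'RS01': (((0, []), (1, [0, 1])), ([1], [[0, 1]])),
}
for name, (inputs, expected) in cases.items():
    print(name, inputs, '->', expected)
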
|
||||
|
||||
def get_all_all_gather_spec(self, source_spec, orig_cost):
|
||||
'''
|
||||
Get all valid sharding specs from source_spec with single all-gather operation, and
|
||||
accumulate communication cost on origin cost which will finally be used in auto sharding solver.
|
||||
For the all-gather operation, we just care about the S dimension.
|
||||
|
||||
Argument:
|
||||
source_spec(ShardingSpec): the ShardingSpec of the source_spec.
|
||||
orig_cost(Dict[str, float]): the original communication cost before this operation.
|
||||
|
||||
Return:
|
||||
valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-gather operation.
|
||||
|
||||
Example:
|
||||
dim_partition_dict = {0: [0], 1: [1]}
|
||||
# DistSpec:
|
||||
# shard_sequence: S0,S1,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
|
||||
print(rst_dict)
|
||||
|
||||
Output:
|
||||
{DistSpec:
|
||||
shard_sequence: R,S1,R
|
||||
device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
shard_sequence: S0,R,R
|
||||
device_mesh_shape: (4, 4): 0}
|
||||
'''
|
||||
valid_spec_dict = {}
|
||||
comm_pattern = 'all_gather'
|
||||
for target_pair in source_spec.dim_partition_dict.items():
|
||||
shard_list = self.all_gather_simulator(target_pair)
|
||||
index = target_pair[0]
|
||||
new_dim_partition_dict = source_spec.dim_partition_dict.copy()
|
||||
|
||||
# We won't add empty list into dim_partition_dict
|
||||
# The key will be popped if the related shard_list is empty
|
||||
if shard_list:
|
||||
new_dim_partition_dict[index] = shard_list
|
||||
else:
|
||||
new_dim_partition_dict.pop(index)
|
||||
|
||||
# generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
|
||||
gather_dim = index
|
||||
logical_process_axis = target_pair[1]
|
||||
comm_spec = CommSpec(comm_pattern,
|
||||
sharding_spec=source_spec,
|
||||
gather_dim=[gather_dim],
|
||||
logical_process_axis=[logical_process_axis],
|
||||
forward_only=self.forward_only)
|
||||
|
||||
# compute the communication cost with CommSpec
|
||||
|
||||
# generate new sharding spec
|
||||
new_sharding_spec = ShardingSpec(
|
||||
source_spec.device_mesh,
|
||||
source_spec.data_type_size,
|
||||
source_spec.entire_shape,
|
||||
source_spec.max_entire_shape,
|
||||
source_spec.raw_shape,
|
||||
dim_partition_dict=new_dim_partition_dict)
|
||||
|
||||
if not new_sharding_spec.sanity_check():
|
||||
continue
|
||||
cost = comm_spec.get_comm_cost()
|
||||
valid_spec_dict[new_sharding_spec] = (comm_spec, orig_cost + cost)
|
||||
return valid_spec_dict
|
||||
|
||||
def get_all_all_to_all_spec(self, source_spec, orig_cost):
|
||||
'''
|
||||
Get all valid sharding specs from source_spec with single all-to-all operation, and
|
||||
accumulate communication cost on origin cost which will finally be used in auto sharding solver.
|
||||
For the all-to-all operation, we just care about the pairs containing S dimension.
|
||||
|
||||
Argument:
|
||||
source_spec(ShardingSpec): the ShardingSpec of the source_spec.
|
||||
orig_cost(Dict[str, float]): the original communication cost before this operation.
|
||||
|
||||
Return:
|
||||
valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation.
|
||||
|
||||
Example:
|
||||
dim_partition_dict = {0: [0], 1: [1]}
|
||||
# DistSpec:
|
||||
# shard_sequence: S0,S1,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
|
||||
print(rst_dict)
|
||||
|
||||
Output:
|
||||
{DistSpec:
|
||||
shard_sequence: S01,R,R
|
||||
device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
shard_sequence: R,S1,S0
|
||||
device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
shard_sequence: S0,R,S1
|
||||
device_mesh_shape: (4, 4): 0}
|
||||
'''
|
||||
valid_spec_dict = {}
|
||||
comm_pattern = 'all_to_all'
|
||||
tensor_dims = len(source_spec.entire_shape)
|
||||
for f_index in range(tensor_dims - 1):
|
||||
for b_index in range(f_index + 1, tensor_dims):
|
||||
# skip (R, R) cases
|
||||
if f_index not in source_spec.dim_partition_dict and b_index not in source_spec.dim_partition_dict:
|
||||
continue
|
||||
else:
|
||||
if f_index in source_spec.dim_partition_dict:
|
||||
'''
|
||||
# skip (S01, R) -> (R, S01) is NOT allowed
|
||||
if len(source_spec.dim_partition_dict[f_index]) >= 2:
|
||||
continue
|
||||
'''
|
||||
f_target_pair = (f_index, [
|
||||
*source_spec.dim_partition_dict[f_index]
|
||||
])
|
||||
else:
|
||||
f_target_pair = (f_index, [])
|
||||
if b_index in source_spec.dim_partition_dict:
|
||||
'''
|
||||
# skip (R, S01) -> (S01, R) is NOT allowed
|
||||
if len(source_spec.dim_partition_dict[b_index]) >= 2:
|
||||
continue
|
||||
'''
|
||||
b_target_pair = (b_index, [
|
||||
*source_spec.dim_partition_dict[b_index]
|
||||
])
|
||||
else:
|
||||
b_target_pair = (b_index, [])
|
||||
|
||||
# skip (S1, S0) -> S10
|
||||
if f_target_pair[1] and b_target_pair[1] and f_target_pair[1][
|
||||
0] >= b_target_pair[1][0]:
|
||||
continue
|
||||
f_shard_list, b_shard_list = self.all_to_all_simulator(
|
||||
f_target_pair, b_target_pair)
|
||||
f_index = f_target_pair[0]
|
||||
b_index = b_target_pair[0]
|
||||
|
||||
# generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
|
||||
if len(f_shard_list) < len(f_target_pair[1]):
|
||||
gather_dim = f_index
|
||||
shard_dim = b_index
|
||||
logical_process_axis = f_target_pair[1]
|
||||
else:
|
||||
gather_dim = b_index
|
||||
shard_dim = f_index
|
||||
logical_process_axis = b_target_pair[1]
|
||||
comm_spec = CommSpec(
|
||||
comm_pattern,
|
||||
sharding_spec=source_spec,
|
||||
gather_dim=[gather_dim],
|
||||
shard_dim=[shard_dim],
|
||||
logical_process_axis=[logical_process_axis],
|
||||
forward_only=self.forward_only)
|
||||
|
||||
# compute the communication cost with CommSpec
|
||||
|
||||
new_dim_partition_dict = source_spec.dim_partition_dict.copy()
|
||||
|
||||
# We won't add empty list into dim_partition_dict
|
||||
# The key will be popped if the related shard_list is empty
|
||||
if f_shard_list:
|
||||
new_dim_partition_dict[f_index] = f_shard_list
|
||||
else:
|
||||
new_dim_partition_dict.pop(f_index)
|
||||
if b_shard_list:
|
||||
new_dim_partition_dict[b_index] = b_shard_list
|
||||
else:
|
||||
new_dim_partition_dict.pop(b_index)
|
||||
|
||||
# generate new sharding spec
|
||||
|
||||
new_sharding_spec = ShardingSpec(
|
||||
source_spec.device_mesh,
|
||||
source_spec.data_type_size,
|
||||
source_spec.entire_shape,
|
||||
source_spec.max_entire_shape,
|
||||
source_spec.raw_shape,
|
||||
dim_partition_dict=new_dim_partition_dict)
|
||||
if not new_sharding_spec.sanity_check():
|
||||
continue
|
||||
cost = comm_spec.get_comm_cost()
|
||||
valid_spec_dict[new_sharding_spec] = (comm_spec,
|
||||
cost + orig_cost)
|
||||
|
||||
return valid_spec_dict
|
||||
|
||||
def get_all_shard_spec(self, source_spec, orig_cost):
|
||||
'''
|
||||
Get all valid sharding specs from source_spec with single shard operation, and
|
||||
        accumulate the communication cost on the original cost, which will finally be used in the auto sharding solver.
|
||||
For the sharding operation, we just care about legal sharding dimensions.
|
||||
|
||||
Argument:
|
||||
source_spec(ShardingSpec): the ShardingSpec of the source_spec.
|
||||
orig_cost(float): the original communication cost before this operation.
|
||||
|
||||
Return:
|
||||
            valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with a single shard operation.
|
||||
|
||||
Example:
|
||||
dim_partition_dict = {0: [0]}
|
||||
# DistSpec:
|
||||
# shard_sequence: S0,R,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
|
||||
print(rst_dict)
|
||||
|
||||
Output:
|
||||
{DistSpec:
|
||||
shard_sequence: S01,R,R
|
||||
device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
shard_sequence: S0,S1,R
|
||||
device_mesh_shape: (4, 4): 0, DistSpec:
|
||||
shard_sequence: S0,R,S1
|
||||
device_mesh_shape: (4, 4): 0}
|
||||
'''
|
||||
valid_spec_dict = {}
|
||||
comm_pattern = 'split'
|
||||
|
||||
# legal sharding dims means the mesh_id is still available to use.
|
||||
legal_sharding_dims = [
|
||||
i for i in range(len(source_spec.device_mesh.mesh_shape))
|
||||
]
|
||||
for dim, shard_list in source_spec.dim_partition_dict.items():
|
||||
for element in shard_list:
|
||||
legal_sharding_dims.remove(element)
|
||||
if len(legal_sharding_dims) == 0:
|
||||
return valid_spec_dict
|
||||
|
||||
tensor_dims = len(source_spec.entire_shape)
|
||||
|
||||
for index in range(tensor_dims):
|
||||
if index not in source_spec.dim_partition_dict:
|
||||
shard_list_list, logical_process_axes = self.shard_simulator(
|
||||
(index, []), legal_sharding_dims)
|
||||
else:
|
||||
shard_list_list, logical_process_axes = self.shard_simulator(
|
||||
(index, source_spec.dim_partition_dict[index]),
|
||||
legal_sharding_dims)
|
||||
if not shard_list_list:
|
||||
continue
|
||||
for shard_list, logical_process_axis in zip(shard_list_list,
|
||||
logical_process_axes):
|
||||
new_dim_partition_dict = source_spec.dim_partition_dict.copy()
|
||||
new_dim_partition_dict[index] = shard_list
|
||||
|
||||
# generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
|
||||
comm_spec = CommSpec(
|
||||
comm_pattern,
|
||||
sharding_spec=source_spec,
|
||||
shard_dim=[index],
|
||||
logical_process_axis=[logical_process_axis],
|
||||
forward_only=self.forward_only)
|
||||
|
||||
# generate new sharding spec
|
||||
new_sharding_spec = ShardingSpec(
|
||||
source_spec.device_mesh,
|
||||
source_spec.data_type_size,
|
||||
source_spec.entire_shape,
|
||||
source_spec.max_entire_shape,
|
||||
source_spec.raw_shape,
|
||||
dim_partition_dict=new_dim_partition_dict)
|
||||
if not new_sharding_spec.sanity_check():
|
||||
continue
|
||||
# compute the communication cost with CommSpec
|
||||
cost = comm_spec.get_comm_cost()
|
||||
valid_spec_dict[new_sharding_spec] = (comm_spec,
|
||||
cost + orig_cost)
|
||||
|
||||
return valid_spec_dict
|
||||
|
||||
def get_all_mixed_shard_spec(self, source_spec, orig_cost):
|
||||
        '''
        Get all valid sharding specs from source_spec with a single mixed (2D) shard operation, and
        accumulate the communication cost on the original cost, which will finally be used in the auto sharding solver.
        For the sharding operation, we only care about legal sharding dimensions.
        '''
|
||||
valid_spec_dict = {}
|
||||
comm_pattern = 'split'
|
||||
|
||||
# legal sharding dims means the mesh_id is still available to use.
|
||||
legal_sharding_dims = [
|
||||
i for i in range(len(source_spec.device_mesh.mesh_shape))
|
||||
]
|
||||
for dim, shard_list in source_spec.dim_partition_dict.items():
|
||||
for element in shard_list:
|
||||
legal_sharding_dims.remove(element)
|
||||
if len(legal_sharding_dims) != 2:
|
||||
return valid_spec_dict
|
||||
|
||||
tensor_dims = len(source_spec.entire_shape)
|
||||
for f_index in range(tensor_dims):
|
||||
for b_index in range(tensor_dims):
|
||||
if f_index != b_index:
|
||||
shard_dims = [f_index, b_index]
|
||||
logical_process_axes = [[legal_sharding_dims[0]],
|
||||
[legal_sharding_dims[1]]]
|
||||
new_dim_partition_dict = source_spec.dim_partition_dict.copy(
|
||||
)
|
||||
new_dim_partition_dict[f_index] = [legal_sharding_dims[0]]
|
||||
new_dim_partition_dict[b_index] = [legal_sharding_dims[1]]
|
||||
comm_spec = CommSpec(
|
||||
comm_pattern,
|
||||
sharding_spec=source_spec,
|
||||
shard_dim=shard_dims,
|
||||
logical_process_axis=logical_process_axes,
|
||||
forward_only=self.forward_only)
|
||||
|
||||
# generate new sharding spec
|
||||
new_sharding_spec = ShardingSpec(
|
||||
source_spec.device_mesh,
|
||||
source_spec.data_type_size,
|
||||
source_spec.entire_shape,
|
||||
source_spec.max_entire_shape,
|
||||
source_spec.raw_shape,
|
||||
dim_partition_dict=new_dim_partition_dict)
|
||||
if not new_sharding_spec.sanity_check():
|
||||
continue
|
||||
cost = comm_spec.get_comm_cost()
|
||||
valid_spec_dict[new_sharding_spec] = (comm_spec,
|
||||
cost + orig_cost)
|
||||
return valid_spec_dict
|
||||
|
||||
def get_all_mix_gather_spec(self, source_spec, orig_cost):
|
||||
'''
|
||||
S0S1 -> RR
|
||||
S1S0 -> RR
|
||||
S01R -> RR
|
||||
RS01 -> RR
|
||||
'''
|
||||
valid_spec_dict = {}
|
||||
        comm_pattern = 'all_gather'
|
||||
tensor_dims = len(source_spec.entire_shape)
|
||||
for f_index in range(tensor_dims - 1):
|
||||
for b_index in range(f_index + 1, tensor_dims):
|
||||
if (f_index not in source_spec.dim_partition_dict) and (
|
||||
b_index not in source_spec.dim_partition_dict):
|
||||
continue
|
||||
else:
|
||||
if f_index in source_spec.dim_partition_dict:
|
||||
# skip (S10, R) -> (R, R)
|
||||
'''
|
||||
if len(
|
||||
f_target_pair[1]
|
||||
) == 2 and f_target_pair[1][0] >= f_target_pair[1][1]:
|
||||
continue
|
||||
'''
|
||||
f_target_pair = (f_index, [
|
||||
*source_spec.dim_partition_dict[f_index]
|
||||
])
|
||||
else:
|
||||
f_target_pair = (f_index, [])
|
||||
if b_index in source_spec.dim_partition_dict:
|
||||
# skip (R, S10) -> (R, R)
|
||||
'''
|
||||
if len(
|
||||
b_target_pair[1]
|
||||
) == 2 and b_target_pair[1][0] >= b_target_pair[1][1]:
|
||||
continue
|
||||
'''
|
||||
b_target_pair = (b_index, [
|
||||
*source_spec.dim_partition_dict[b_index]
|
||||
])
|
||||
else:
|
||||
b_target_pair = (b_index, [])
|
||||
if len(f_target_pair[1]) + len(b_target_pair[1]) != 2:
|
||||
continue
|
||||
gather_dim, logical_process_axes = self.mix_gather_simulator(
|
||||
f_target_pair, b_target_pair)
|
||||
                comm_spec = CommSpec(comm_pattern,
|
||||
sharding_spec=source_spec,
|
||||
gather_dim=gather_dim,
|
||||
logical_process_axis=logical_process_axes,
|
||||
forward_only=self.forward_only,
|
||||
mix_gather=True)
|
||||
|
||||
new_dim_partition_dict = {}
|
||||
# generate new sharding spec
|
||||
new_sharding_spec = ShardingSpec(
|
||||
source_spec.device_mesh,
|
||||
source_spec.data_type_size,
|
||||
source_spec.entire_shape,
|
||||
source_spec.max_entire_shape,
|
||||
source_spec.raw_shape,
|
||||
dim_partition_dict=new_dim_partition_dict)
|
||||
if not new_sharding_spec.sanity_check():
|
||||
continue
|
||||
cost = comm_spec.get_comm_cost()
|
||||
valid_spec_dict[new_sharding_spec] = (comm_spec,
|
||||
cost + orig_cost)
|
||||
|
||||
return valid_spec_dict
|
||||
|
||||
def get_all_one_step_transform_spec(self, source_spec, orig_cost):
|
||||
        '''
        Get all valid sharding specs from source_spec with a one-step transform, and
        accumulate the communication cost on the original cost, which will finally be used in the auto sharding solver.
        Note:
            all-gather eliminates a sharding dimension, all-to-all keeps the number of sharding dimensions
            the same, and shard adds a sharding dimension. Therefore, the results of the above operations
            are mutually exclusive, and we can safely put them together.

        Argument:
            source_spec(ShardingSpec): the ShardingSpec of the source_spec.
            orig_cost(float): the original communication cost before this operation.

        Return:
            valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs reachable from source_spec
                with a single one-step transform.
        '''
|
||||
valid_spec_dict = {}
|
||||
valid_spec_dict.update(
|
||||
self.get_all_all_gather_spec(source_spec, orig_cost))
|
||||
valid_spec_dict.update(
|
||||
self.get_all_all_to_all_spec(source_spec, orig_cost))
|
||||
valid_spec_dict.update(
|
||||
self.get_all_mix_gather_spec(source_spec, orig_cost))
|
||||
valid_spec_dict.update(
|
||||
self.get_all_mixed_shard_spec(source_spec, orig_cost))
|
||||
valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost))
|
||||
return valid_spec_dict
|
||||
|
||||
def mem_cost(self, comm_action_sequence: List[CommSpec], mem_pattern='opt'):
|
||||
"""memory cost of the communication action sequence
|
||||
|
||||
Args:
|
||||
comm_action_sequence (List[CommSpec]): list of communication actions
|
||||
|
||||
Returns:
|
||||
            the peak memory footprint (in bytes) of the comm_action_sequence
|
||||
"""
|
||||
|
||||
def compute_shape(sharding_spec: ShardingSpec):
|
||||
if 'opt' == mem_pattern:
|
||||
return sharding_spec.get_sharded_shape_per_device()
|
||||
elif 'max' == mem_pattern:
|
||||
return sharding_spec.get_max_sharded_shape_per_device()
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
def gather_analysis(comm_spec, peak_mem):
|
||||
"""analyze all_gather memory footprint
|
||||
all_gather will allocate memory for the output tensor, and there will be temp memory for
|
||||
all_gather operation, which is twice the size of output tensor
|
||||
|
||||
Args:
|
||||
comm_spec (CommSpec): input CommSpec
|
||||
"""
|
||||
input_shape = compute_shape(comm_spec.sharding_spec)
|
||||
input_numel = reduce(operator.mul, input_shape, 1)
|
||||
for axes in comm_spec.logical_process_axis:
|
||||
for axis in axes:
|
||||
output_numel = input_numel * comm_spec.device_mesh.mesh_shape[
|
||||
axis]
|
||||
alloc_mem = (input_numel +
|
||||
output_numel * 2) * comm_spec.sharding_spec.dtype_size
|
||||
peak_mem = max(peak_mem, alloc_mem)
|
||||
return peak_mem
|
||||
|
||||
def reduce_scatter_analysis(comm_spec, peak_mem):
|
||||
|
||||
input_shape = compute_shape(comm_spec.sharding_spec)
|
||||
input_numel = reduce(operator.mul, input_shape, 1)
|
||||
output_numel = input_numel
|
||||
for axes in comm_spec.logical_process_axis:
|
||||
for axis in axes:
|
||||
output_numel = output_numel / comm_spec.device_mesh.mesh_shape[
|
||||
axis]
|
||||
alloc_mem = (input_numel +
|
||||
output_numel * 2) * comm_spec.sharding_spec.dtype_size
|
||||
peak_mem = max(peak_mem, alloc_mem)
|
||||
|
||||
return peak_mem
|
||||
|
||||
def split_analysis(comm_spec: CommSpec, peak_mem: int):
|
||||
"""analyze split memory footprint
|
||||
split will allocate memory for the output tensor if we don't apply shard on the first dimension of
|
||||
the input tensor. If we apply shard on the first dimension, the `torch.tensor.contiguous()` will not
|
||||
generate new tensor in this case, so no memory will be allocated.
|
||||
|
||||
Args:
|
||||
comm_spec (CommSpec): input CommSpec
|
||||
discard_input (bool): whether to discard the input tensor
|
||||
alloc_numel (int): current allocated numel
|
||||
peak_numel (int): current peak numel
|
||||
"""
|
||||
shard_dim = comm_spec.shard_dim
|
||||
if shard_dim != 0:
|
||||
# if we don't shard the tensor on the first dimension, the split action will
|
||||
# generate a new tensor
|
||||
input_shape = compute_shape(comm_spec.sharding_spec)
|
||||
input_numel = reduce(operator.mul, input_shape, 1)
|
||||
output_numel = input_numel
|
||||
for axes in comm_spec.logical_process_axis:
|
||||
for axis in axes:
|
||||
output_numel = output_numel / comm_spec.device_mesh.mesh_shape[
|
||||
axis]
|
||||
alloc_mem = (input_numel +
|
||||
output_numel) * comm_spec.sharding_spec.dtype_size
|
||||
peak_mem = max(peak_mem, alloc_mem)
|
||||
else:
|
||||
# if we shard the tensor on the first dimension, the split action will not generate
|
||||
# a new tensor, and as it will preserve a reference to the input tensor, we could
|
||||
# override the discard_input option here
|
||||
# NOTE: this special case might fail in some weird cases, e.g. if we have three split
|
||||
# actions in the comm actions sequence, the first split action operate on the second dimension,
|
||||
# the second split action operate on the first dimension, and the third split action operate, again,
|
||||
# on the second dimension. Therefore, after the first two actions in the sequence, we will allocate
|
||||
# memory the same size as the output of first split action. However, the third split action will discard
|
||||
# the input tensor, and it actually should discard the tensor generated by the first split action, so in
|
||||
# the current memory estimation framework, we will overestimate the memory usage. But the above case is
|
||||
# kind of weird, and I think we could ignore it for now.
|
||||
pass
|
||||
return peak_mem
|
||||
|
||||
def reduce_analysis(comm_spec: CommSpec, peak_mem: int):
|
||||
input_shape = compute_shape(comm_spec.sharding_spec)
|
||||
input_numel = reduce(operator.mul, input_shape, 1)
|
||||
output_numel = input_numel
|
||||
alloc_mem = (input_numel +
|
||||
output_numel) * comm_spec.sharding_spec.dtype_size
|
||||
peak_mem = max(peak_mem, alloc_mem)
|
||||
return peak_mem
|
||||
|
||||
def all2all_analysis(comm_spec: CommSpec, peak_mem: int):
|
||||
input_shape = compute_shape(comm_spec.sharding_spec)
|
||||
input_numel = reduce(operator.mul, input_shape, 1)
|
||||
output_numel = input_numel
|
||||
comm_spec.shard_dim
|
||||
alloc_mem = (input_numel +
|
||||
output_numel * 3) * comm_spec.sharding_spec.dtype_size
|
||||
peak_mem = max(peak_mem, alloc_mem)
|
||||
return peak_mem
|
||||
|
||||
def peer_to_peer_analysis(comm_spec: CommSpec, peak_mem: int):
|
||||
input_shape = compute_shape(comm_spec.sharding_spec)
|
||||
input_numel = reduce(operator.mul, input_shape, 1)
|
||||
alloc_mem = (input_numel) * comm_spec.sharding_spec.dtype_size
|
||||
peak_mem = max(peak_mem, alloc_mem)
|
||||
return peak_mem
|
||||
|
||||
pattern_to_func_dict = {
|
||||
'all_gather': gather_analysis,
|
||||
'all_to_all': all2all_analysis,
|
||||
'split': split_analysis,
|
||||
'all_reduce': reduce_analysis,
|
||||
'reduce_scatter': reduce_scatter_analysis,
|
||||
'peer_to_peer': peer_to_peer_analysis
|
||||
}
|
||||
|
||||
fwd_actions = []
|
||||
# construct forward and backward comm actions sequence
|
||||
for comm_spec in comm_action_sequence:
|
||||
fwd_action = pattern_to_func_dict[comm_spec.comm_pattern]
|
||||
fwd_actions.append(fwd_action)
|
||||
|
||||
# analyze memory footprint of forward comm actions sequence
|
||||
fwd_peak_numel = 0
|
||||
for idx, action_spec_pair in enumerate(
|
||||
zip(fwd_actions, comm_action_sequence)):
|
||||
# the first forward comm action will not discard input
|
||||
fwd_action, comm_spec = action_spec_pair
|
||||
fwd_peak_numel = fwd_action(comm_spec, fwd_peak_numel)
|
||||
|
||||
return fwd_peak_numel
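
A worked example of the all-gather peak-memory formula used by `gather_analysis` above, on made-up tensor sizes: the input shard plus roughly twice the gathered output stay alive at the peak.

input_numel = 1024 * 256   # [1024, 256] shard per device
axis_size = 4              # size of the mesh axis being gathered over
dtype_size = 2             # fp16 bytes

output_numel = input_numel * axis_size
peak_bytes = (input_numel + 2 * output_numel) * dtype_size
print(peak_bytes)  # 4718592
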
|
||||
|
||||
def print_shape_consistency_result(self,
|
||||
transform_path,
|
||||
comm_action_sequence,
|
||||
resharding_cost,
|
||||
file=None):
|
||||
for idx, tpath in enumerate(transform_path):
|
||||
print(
|
||||
f'sharding_info = [op_shape:{tpath.entire_shape}, sharding_spec:{tpath.sharding_sequence}, sharded_shape:{tpath.get_sharded_shape_per_device()}]',
|
||||
end=" ",
|
||||
file=file)
|
||||
print('->', end=" ", file=file)
|
||||
try:
|
||||
commspec = comm_action_sequence[idx]
|
||||
comm = [
|
||||
commspec.comm_pattern, commspec.gather_dim,
|
||||
commspec.shard_dim, commspec.logical_process_axis
|
||||
]
|
||||
except Exception:
comm = ''
|
||||
print(f'comm_info = {comm}', end=" ", file=file)
|
||||
print('->', end=" ", file=file)
|
||||
print(f'total_cost = {resharding_cost}', file=file)
|
||||
|
||||
def construct_transform_path_from_cache(self, src_spec, target_spec,
|
||||
old_transform_path,
|
||||
old_comm_action_sequence,
|
||||
orig_cost):
|
||||
new_transform_path = [src_spec]
|
||||
new_comm_action_sequence = []
|
||||
new_cost = orig_cost
|
||||
new_src_spec = src_spec
|
||||
for idx, old_comm_spec in enumerate(old_comm_action_sequence):
|
||||
new_comm_spec = CommSpec(
|
||||
old_comm_spec.comm_pattern,
|
||||
sharding_spec=new_src_spec,
|
||||
gather_dim=old_comm_spec.gather_dim,
|
||||
shard_dim=old_comm_spec.shard_dim,
|
||||
logical_process_axis=old_comm_spec.logical_process_axis,
|
||||
forward_only=old_comm_spec.forward_only,
|
||||
mix_gather=old_comm_spec.mix_gather)
|
||||
new_comm_action_sequence.append(new_comm_spec)
|
||||
new_cost += new_comm_spec.get_comm_cost()
|
||||
old_target_spec = old_transform_path[idx + 1]
|
||||
new_target_spec = ShardingSpec(new_src_spec.device_mesh,
|
||||
new_src_spec.data_type_size,
|
||||
new_src_spec.entire_shape,
|
||||
new_src_spec.max_entire_shape,
|
||||
new_src_spec.raw_shape,
|
||||
old_target_spec.dim_partition_dict)
|
||||
new_transform_path.append(new_target_spec)
|
||||
new_src_spec = new_target_spec
|
||||
assert new_transform_path[-1].get_sharded_shape_per_device(
|
||||
) == target_spec.get_sharded_shape_per_device(
|
||||
), 'failed to insert the cache transform path'
|
||||
return new_transform_path, new_comm_action_sequence, new_cost
|
||||
|
||||
def shape_consistency(
|
||||
self, source_spec: ShardingSpec, target_spec: ShardingSpec
|
||||
) -> Tuple[List[ShardingSpec], List[CommSpec], float]:
|
||||
'''
|
||||
This method will find a path to transform source_spec to target_spec with
|
||||
a greedy algorithm.
|
||||
The basic idea is:
|
||||
Step1:
|
||||
Generate all one-step transform sequences from source_spec.
|
||||
Step2:
|
||||
Pick the 'best' sharding spec following the heuristic function.
|
||||
Step3:
|
||||
Repeat the above steps until the source spec has been transformed into the target spec.
|
||||
'''
|
||||
MAX_TRANSFORM_STEPS = 20
|
||||
total_cost = 0.0
|
||||
total_steps = 0
|
||||
transform_path = []
|
||||
comm_action_sequence = []
|
||||
# Nothing to do if the two sharding specs are already identical.
|
||||
if source_spec.sharding_sequence_difference(target_spec) == 0:
|
||||
return (transform_path, comm_action_sequence, total_cost)
|
||||
|
||||
spec_pairs = (str(source_spec.sharding_sequence),
|
||||
str(target_spec.sharding_sequence))
|
||||
|
||||
if spec_pairs in self.cached_spec_pairs_transform_path:
|
||||
transform_path, comm_action_sequence = self.cached_spec_pairs_transform_path[
|
||||
spec_pairs]
|
||||
new_transform_path, new_comm_action_sequence, new_total_cost = self.construct_transform_path_from_cache(
|
||||
source_spec, target_spec, transform_path, comm_action_sequence,
|
||||
total_cost)
|
||||
self.cache_hit += 1
|
||||
return (new_transform_path, new_comm_action_sequence,
|
||||
new_total_cost)
|
||||
|
||||
else:
|
||||
self.cache_miss += 1
|
||||
|
||||
temp_sharding_spec = source_spec
|
||||
transform_path.append(temp_sharding_spec)
|
||||
# To avoid an infinite loop, break after MAX_TRANSFORM_STEPS transforms
|
||||
while total_steps <= MAX_TRANSFORM_STEPS:
|
||||
valid_transform_spec_dict = self.get_all_one_step_transform_spec(
|
||||
temp_sharding_spec, total_cost)
|
||||
best_difference_score = math.inf
|
||||
|
||||
for sharding_spec, info_pairs in valid_transform_spec_dict.items():
|
||||
comm_spec, cost = info_pairs
|
||||
spec_difference = sharding_spec.sharding_sequence_difference(
|
||||
target_spec)
|
||||
|
||||
if spec_difference == 0:
|
||||
total_cost = cost
|
||||
transform_path.append(sharding_spec)
|
||||
comm_action_sequence.append(comm_spec)
|
||||
self.cached_spec_pairs_transform_path[spec_pairs] = (
|
||||
transform_path, comm_action_sequence)
|
||||
return (transform_path, comm_action_sequence, total_cost)
|
||||
|
||||
if spec_difference < best_difference_score:
|
||||
temp_sharding_spec = sharding_spec
|
||||
temp_cost = cost
|
||||
temp_comm_spec = comm_spec
|
||||
best_difference_score = spec_difference
|
||||
|
||||
transform_path.append(temp_sharding_spec)
|
||||
comm_action_sequence.append(temp_comm_spec)
|
||||
total_cost = temp_cost
|
||||
total_steps += 1
|
||||
|
||||
raise RuntimeError(
|
||||
f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps."
|
||||
)
|
||||
|
||||
def dum_transform_path_from_cache(self):
|
||||
src_specs, tgt_specs, path_strs = [], [], []
|
||||
for spec_pairs, trans_comm_path in self.cached_spec_pairs_transform_path.items(
|
||||
):
|
||||
src_specs.append(spec_pairs[0])
|
||||
tgt_specs.append(spec_pairs[1])
|
||||
trans_paths, comm_specs = trans_comm_path[0], trans_comm_path[1]
|
||||
path_str = f'{spec_pairs[0]}->'
|
||||
for idx in range(1, len(trans_paths)):
|
||||
comm_spec = comm_specs[idx - 1]
|
||||
comm_str = f'{comm_spec.comm_pattern}: gather_dim{comm_spec.gather_dim}, shard_dim{comm_spec.shard_dim}, mesh_axis{comm_spec.logical_process_axis}->'
|
||||
path_str += comm_str
|
||||
path_str += f'{trans_paths[idx].sharding_sequence}->'
|
||||
path_strs.append(path_str)
|
||||
ret_dict = {
|
||||
'src_spec': src_specs,
|
||||
'dst_specs': tgt_specs,
|
||||
'trans_path': path_strs
|
||||
}
|
||||
ret_df = pd.DataFrame.from_dict(ret_dict)
|
||||
return ret_df
|
||||
@ -1,41 +0,0 @@
|
||||
import copy
|
||||
|
||||
from .node import Node
|
||||
from .sharding_strategy import StrategiesVector
|
||||
|
||||
|
||||
class Shape(Node):
|
||||
|
||||
def _update_memory_cost(self, strategies):
|
||||
pass
|
||||
|
||||
def _collect_strategies(self, device_mesh):
|
||||
# one input for the shape node
|
||||
predecessor = self.predecessor_nodes[0]
|
||||
strategies_vector = StrategiesVector(self)
|
||||
for idx, strategy in enumerate(predecessor.strategies_vector):
|
||||
# current node's local name input0 -> global name xxx
|
||||
global_input_name = self.op_data['input0'].name
|
||||
# global name xxx -> pre node local output name
|
||||
prenode_local_name = predecessor.global_to_local_op_name[
|
||||
global_input_name]
|
||||
dim_partition_dict = copy.deepcopy(
|
||||
strategy.sharding_specs[prenode_local_name].dim_partition_dict)
|
||||
in0_partition_dict = dim_partition_dict
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": in0_partition_dict,
|
||||
"output0": {},
|
||||
}
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if 0 == len(sharding_spec_mapping):
|
||||
return strategies_vector
|
||||
name = '{} = <shape> {}'.format(
|
||||
sharding_spec_mapping['output0'].sharding_sequence,
|
||||
sharding_spec_mapping['input0'].sharding_sequence)
|
||||
sharding_strategy = self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
strategies_vector.append(sharding_strategy)
|
||||
return strategies_vector
|
||||
@ -1,418 +0,0 @@
|
||||
import operator
|
||||
from functools import reduce
|
||||
|
||||
import tensorrt as trt
|
||||
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
ALLGATHER_COST = 20
|
||||
SHARD_COST = 5
|
||||
STEP_PENALTY = 6
|
||||
NAN = 'nan'
|
||||
|
||||
|
||||
def _convert_str_to_shard_list(str_spec):
|
||||
'''
|
||||
Convert str_spec into shard_list.
|
||||
|
||||
Argument:
|
||||
str_spec(str): dim spec in str type.
|
||||
'''
|
||||
|
||||
if str_spec == 'R':
|
||||
return []
|
||||
if str_spec == 'S0':
|
||||
return [0]
|
||||
if str_spec == 'S1':
|
||||
return [1]
|
||||
if str_spec == 'S01':
|
||||
return [0, 1]
|
||||
|
||||
|
||||
def _build_difference_2d_dict():
|
||||
'''
|
||||
Build a difference mapping for 2D device mesh case. It will be used to
|
||||
compute the difference between DimSpec pairs.
|
||||
'''
|
||||
|
||||
source_spec_list = ['R', 'S0', 'S1', 'S01']
|
||||
target_spec_list = ['R', 'S0', 'S1', 'S01']
|
||||
difference_dict = {}
|
||||
for source_spec in source_spec_list:
|
||||
for target_spec in target_spec_list:
|
||||
spec_pair = (source_spec, target_spec)
|
||||
source_shard_list = _convert_str_to_shard_list(source_spec)
|
||||
target_shard_list = _convert_str_to_shard_list(target_spec)
|
||||
|
||||
# source same as target
|
||||
if source_shard_list == target_shard_list:
|
||||
difference = 0
|
||||
|
||||
# all_gather(source) -> target
|
||||
elif len(source_shard_list) == len(
|
||||
target_shard_list
|
||||
) + 1 and source_shard_list[:-1] == target_shard_list:
|
||||
difference = ALLGATHER_COST
|
||||
|
||||
# shard(source) -> target
|
||||
elif len(source_shard_list) == len(
|
||||
target_shard_list
|
||||
) - 1 and source_shard_list == target_shard_list[:-1] and target_shard_list[
|
||||
-1] not in source_shard_list:
|
||||
difference = SHARD_COST
|
||||
|
||||
# S1 -> S0 or S0 -> S1
|
||||
elif len(source_shard_list) == len(target_shard_list):
|
||||
# source -> R -> target
|
||||
difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST
|
||||
|
||||
# R -> S01
|
||||
elif len(source_shard_list) == len(target_shard_list) - 2:
|
||||
difference = SHARD_COST + STEP_PENALTY + SHARD_COST
|
||||
|
||||
# S01 -> R
|
||||
elif len(source_shard_list) == len(target_shard_list) + 2:
|
||||
difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST
|
||||
|
||||
# S1 -> S01
|
||||
elif len(source_shard_list) == len(target_shard_list) - 1:
|
||||
difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST + STEP_PENALTY + SHARD_COST
|
||||
|
||||
# S01 -> S1
|
||||
elif len(source_shard_list) == len(target_shard_list) + 1:
|
||||
difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + STEP_PENALTY + SHARD_COST
|
||||
|
||||
else:
|
||||
difference = NAN
|
||||
difference_dict[spec_pair] = difference
|
||||
|
||||
return difference_dict
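A quick standalone check of a few entries in the table above, using the constants defined earlier in this file; the arithmetic simply mirrors the elif branches and introduces no new API:

ALLGATHER_COST, SHARD_COST, STEP_PENALTY = 20, 5, 6

# R -> S0: a single shard step
assert SHARD_COST == 5
# S01 -> R: all-gather, step penalty, all-gather
assert ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST == 46
# S1 -> S0: gather back to R, then re-shard
assert ALLGATHER_COST + STEP_PENALTY + SHARD_COST == 31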
|
||||
|
||||
|
||||
_difference_dict = _build_difference_2d_dict()
|
||||
|
||||
|
||||
class DimSpec:
|
||||
'''
|
||||
Sharding spec for a single dimension of a sharded tensor. It describes which logical
device mesh axes the dimension is sharded over and provides a method to compute the
difference between two specs.
|
||||
|
||||
Argument:
|
||||
shard_list(List[int]): if shard_list is empty, the dim spec is of the replicated 'R' type.
Otherwise, each element in shard_list is a logical mesh axis over which the data is sharded.
|
||||
'''
|
||||
|
||||
def __init__(self, shard_list):
|
||||
self.is_replica = len(shard_list) == 0
|
||||
self.shard_list = shard_list
|
||||
|
||||
def __eq__(self, other):
|
||||
return str(self) == str(other)
|
||||
|
||||
def __repr__(self):
|
||||
if self.is_replica:
|
||||
return 'R'
|
||||
target = 'S'
|
||||
for dim in self.shard_list:
|
||||
target += str(dim)
|
||||
return target
|
||||
|
||||
def difference(self, other):
|
||||
'''
|
||||
The difference between two DimSpec.
|
||||
|
||||
Argument:
|
||||
other(DimSpec): the dim spec to compare with.
|
||||
|
||||
Return:
|
||||
difference(int): the difference between two DimSpec.
|
||||
|
||||
Example:
|
||||
dim_spec = DimSpec([0])
|
||||
other_dim_spec = DimSpec([0, 1])
|
||||
print(dim_spec.difference(other_dim_spec))
|
||||
|
||||
Output:
|
||||
5
|
||||
'''
|
||||
difference = _difference_dict[(str(self), str(other))]
|
||||
return difference
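A small usage sketch of the class above; the expected values follow directly from the cost table, nothing here is new API:

r = DimSpec([])        # replicated -> repr() is 'R'
s0 = DimSpec([0])      # sharded over mesh axis 0 -> 'S0'
s01 = DimSpec([0, 1])  # sharded over both axes -> 'S01'
print(s0.difference(s01))  # 5, the plain SHARD_COST case from the table
print(r.difference(s01))   # 16 == SHARD_COST + STEP_PENALTY + SHARD_COST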
|
||||
|
||||
|
||||
def get_sharding_sequence(num_dims, dims, device_dims):
|
||||
sharding_sequence = [DimSpec([])] * num_dims
|
||||
for dim, shard_list in zip(dims, device_dims):
|
||||
sharding_sequence[dim] = DimSpec(shard_list)
|
||||
return sharding_sequence
|
||||
|
||||
|
||||
class ShardingSpec:
|
||||
'''
|
||||
Sharding spec for a tensor. It contains the logical device mesh this tensor belongs to,
the entire shape of the tensor before sharding, and a sharding sequence that looks like
[R, R, S0, S1].
|
||||
|
||||
Argument:
|
||||
device_mesh: A logical view of a physical mesh.
|
||||
entire_shape: The entire shape of tensor before sharded.
|
||||
dim_partition_dict: The key is the dimension of tensor to be sharded,
|
||||
and the value of the key describe which logical axis will be sharded in that dimension.
|
||||
sharding_sequence(List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
|
||||
'''
|
||||
|
||||
def __init__(self,
|
||||
device_mesh,
|
||||
data_type_size,
|
||||
data_shape,
|
||||
max_data_shape,
|
||||
raw_data_shape,
|
||||
dim_partition_dict=None,
|
||||
sharding_sequence=None):
|
||||
self.device_mesh = device_mesh
|
||||
self.data_type_size = data_type_size
|
||||
self.dtype = data_type_size[0]
|
||||
self.dtype_size = data_type_size[1]
|
||||
self.entire_shape = data_shape
|
||||
self.max_entire_shape = max_data_shape
|
||||
self.raw_shape = raw_data_shape
|
||||
self.dim_partition_dict = dim_partition_dict
|
||||
self.sharding_sequence = sharding_sequence
|
||||
self.enable_shard_unbalanced_shape = device_mesh.config.enable_shard_unbalanced_shape
|
||||
self.enable_shard_dynamic_shape = device_mesh.config.enable_shard_dynamic_shape
|
||||
if self.sharding_sequence is None:
|
||||
self.dim_partition_dict = self._merge_same_dim_mesh_list(
|
||||
len(self.entire_shape), self.dim_partition_dict)
|
||||
self.dim_partition_dict = self._convert_dim_partition_dict(
|
||||
len(self.entire_shape), self.dim_partition_dict)
|
||||
assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.'
|
||||
self.convert_dict_to_shard_sequence()
|
||||
elif self.dim_partition_dict is None:
|
||||
assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.'
|
||||
self.convert_shard_sequence_to_dict()
|
||||
self.dim_partition_dict = self._merge_same_dim_mesh_list(
|
||||
len(self.entire_shape), self.dim_partition_dict)
|
||||
self.dim_partition_dict = self._convert_dim_partition_dict(
|
||||
len(self.entire_shape), self.dim_partition_dict)
|
||||
|
||||
self.sharded_shape, self.max_sharded_shape = [*self.entire_shape], [
|
||||
*self.max_entire_shape
|
||||
]
|
||||
for dim, shard_list in self.dim_partition_dict.items():
|
||||
mesh_list = [
|
||||
self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list
|
||||
]
|
||||
shard_partitions = reduce(operator.mul, mesh_list, 1)
|
||||
self.sharded_shape[dim] = (self.sharded_shape[dim] +
|
||||
shard_partitions - 1) // shard_partitions
|
||||
self.max_sharded_shape[dim] = (self.max_sharded_shape[dim] +
|
||||
shard_partitions -
|
||||
1) // shard_partitions
|
||||
|
||||
def print_spec(self, file=None):
|
||||
print(
|
||||
f"sharding_sequence = {self.sharding_sequence}, shape = {self.get_sharded_shape_per_device()}",
|
||||
file=file,
|
||||
)
|
||||
|
||||
def _merge_same_dim_mesh_list(self, dim_size, dim_partition_dict):
|
||||
'''
|
||||
Merge different keys that point to the same physical tensor dimension.

For example:
For a 2d tensor, dim_partition_dict {1: [0], -1: [1]} uses dims 1 and -1 to refer to the same physical dimension,
so the above dim_partition_dict will be converted to {1: [0, 1]}
|
||||
'''
|
||||
converted_dim_partition_dict = {}
|
||||
for dim, mesh_list in dim_partition_dict.items():
|
||||
if dim < 0:
|
||||
dim = dim_size + dim
|
||||
if dim not in converted_dim_partition_dict:
|
||||
converted_dim_partition_dict[dim] = mesh_list
|
||||
else:
|
||||
converted_dim_partition_dict[dim].extend(mesh_list)
|
||||
converted_dim_partition_dict[dim].sort()
|
||||
return converted_dim_partition_dict
|
||||
|
||||
def _convert_dim_partition_dict(self, dim_size, dim_partition_dict):
|
||||
dims_to_convert = []
|
||||
for dim, mesh_list in dim_partition_dict.items():
|
||||
if dim < 0:
|
||||
dims_to_convert.append(dim)
|
||||
for dim in dims_to_convert:
|
||||
dim_partition_dict.pop(dim)
|
||||
dim_partition_dict[dim_size + dim] = mesh_list
|
||||
return dim_partition_dict
|
||||
|
||||
def _remove_mesh_dim_one(self, dim_partition_dict):
|
||||
dims_to_remove = []
|
||||
for dim, mesh_list in dim_partition_dict.items():
|
||||
new_mesh_list = []
|
||||
for mesh_dim in mesh_list:
|
||||
if self.device_mesh.mesh_shape[mesh_dim] != 1:
|
||||
new_mesh_list.append(mesh_dim)
|
||||
if 0 != len(new_mesh_list):
|
||||
dim_partition_dict[dim] = new_mesh_list
|
||||
else:
|
||||
dims_to_remove.append(dim)
|
||||
for dim in dims_to_remove:
|
||||
dim_partition_dict.pop(dim)
|
||||
return dim_partition_dict
|
||||
|
||||
def __repr__(self):
|
||||
res = "DistSpec("
|
||||
res += f"shard_sequence={self.sharding_sequence},"
|
||||
res += f"shape={self.device_mesh.mesh_shape}"
|
||||
res += ")"
|
||||
return res
|
||||
|
||||
def sanity_check(self):
|
||||
# make sure each axis in the logical device mesh is used at most once
|
||||
dim_check_list = [*range(len(self.device_mesh.mesh_shape))]
|
||||
for dim, shard_list in self.dim_partition_dict.items():
|
||||
for element in shard_list:
|
||||
if element in dim_check_list:
|
||||
dim_check_list.remove(element)
|
||||
else:
|
||||
logger.warning(
|
||||
f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}. dim_partition_dict={self.dim_partition_dict}"
|
||||
)
|
||||
return False
|
||||
|
||||
# make sure the dimension index is not out of range
for dim in self.dim_partition_dict.keys():
# negative dims have already been converted to positive; anything still negative or >= the number of dims is out of range
|
||||
if dim >= len(self.entire_shape) or dim < 0:
|
||||
print(
|
||||
f"The dim_partition_dict specifies to shard dimension {dim} but the entire_shape only has {len(self.entire_shape)} dimensions"
|
||||
)
|
||||
return False
|
||||
|
||||
if not self.enable_shard_dynamic_shape:
|
||||
# make sure not to shard on a dynamic shape
|
||||
for dim, shard_list in self.dim_partition_dict.items():
|
||||
if len(shard_list) == 0:
|
||||
continue
|
||||
if len(self.raw_shape) == 0:
|
||||
continue
|
||||
if -1 == self.raw_shape[dim]:
|
||||
return False
|
||||
|
||||
# make sure each sharded dimension is divisible by the number of devices it is sharded over
|
||||
for dim, shard_list in self.dim_partition_dict.items():
|
||||
if len(shard_list) == 0:
|
||||
continue
|
||||
tensor_dim_size = self.entire_shape[dim]
|
||||
num_devices = 1
|
||||
|
||||
for element in shard_list:
|
||||
num_devices *= self.device_mesh.mesh_shape[element]
|
||||
if num_devices == 1:
|
||||
# only the fully replicated spec is supported when there is a single device
|
||||
return False
|
||||
|
||||
if not self.enable_shard_unbalanced_shape:
|
||||
if tensor_dim_size % num_devices != 0 or tensor_dim_size == 1:
|
||||
'''
|
||||
print(
|
||||
f'The size of static dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.'
|
||||
)
|
||||
'''
|
||||
return False
|
||||
else:
|
||||
if tensor_dim_size == 1:
|
||||
return False
|
||||
'''
|
||||
if self.get_sharded_size_per_device() > (2**31 - 1):
|
||||
print(
|
||||
f'memory footprint per device {self.get_sharded_size_per_device()} is larger than 2**31 - 1'
|
||||
)
|
||||
return False
|
||||
'''
|
||||
return True
|
||||
|
||||
def convert_dict_to_shard_sequence(self):
|
||||
'''
|
||||
Convert dim_partition_dict into list of DimSpec, and assign it to sharding_sequence.
|
||||
'''
|
||||
sharding_sequence = [DimSpec([])] * len(self.entire_shape)
|
||||
for dim, shard_list in self.dim_partition_dict.items():
|
||||
sharding_sequence[dim] = DimSpec(shard_list)
|
||||
self.sharding_sequence = sharding_sequence
|
||||
|
||||
def convert_shard_sequence_to_dict(self):
|
||||
'''
|
||||
Convert sharding_sequence into dim_partition_dict.
|
||||
'''
|
||||
new_dim_partition_dict = {}
|
||||
for index, dim_spec in enumerate(self.sharding_sequence):
|
||||
if not dim_spec.is_replica:
|
||||
if index not in new_dim_partition_dict:
|
||||
new_dim_partition_dict[index] = []
|
||||
new_dim_partition_dict[index].extend(dim_spec.shard_list)
|
||||
self.dim_partition_dict = new_dim_partition_dict
|
||||
|
||||
def sharding_sequence_difference(self, other):
|
||||
'''
|
||||
This function is a naive version of difference computation. It simply accumulates the
per-dimension differences between the pair of sharding sequences.
|
||||
|
||||
Example:
|
||||
dim_partition_dict = {0: [0, 1]}
|
||||
# DistSpec:
|
||||
# shard_sequence: S01,R,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
|
||||
dim_partition_dict_to_compare = {0: [0], 1: [1]}
|
||||
# DistSpec:
|
||||
# shard_sequence: S0,S1,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare)
|
||||
print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare))
|
||||
|
||||
Output:
|
||||
25
|
||||
|
||||
Argument:
|
||||
other(ShardingSpec): The ShardingSpec to compared with.
|
||||
|
||||
Return:
|
||||
difference(int): Difference between two ShardingSpec.
|
||||
'''
|
||||
assert len(self.sharding_sequence) == len(
|
||||
other.sharding_sequence
|
||||
), f'Cannot compare difference for two sharding specs with different length.'
|
||||
difference = 0
|
||||
for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence,
|
||||
other.sharding_sequence):
|
||||
difference += orig_dim_spec.difference(other_dim_spec)
|
||||
return difference
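The docstring's result of 25 can be verified per dimension from the 2D cost table: S01 -> S0 is a single all-gather (20), R -> S1 is a single shard (5), and R -> R is free. A sanity sketch in plain arithmetic, no extra API:

per_dim_costs = [20, 5, 0]   # S01 -> S0, R -> S1, R -> R
assert sum(per_dim_costs) == 25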
|
||||
|
||||
def get_sharded_shape_per_device(self, ):
|
||||
return self.sharded_shape
|
||||
|
||||
def get_sharded_element_per_device(self, ):
|
||||
sharded_shape = self.get_sharded_shape_per_device()
|
||||
if len(sharded_shape) == 0:
|
||||
num_elements = 1
|
||||
else:
|
||||
num_elements = trt.volume(sharded_shape)
|
||||
return num_elements
|
||||
|
||||
def get_sharded_size_per_device(self, ):
|
||||
num_elements = self.get_sharded_element_per_device()
|
||||
return num_elements * self.dtype_size
|
||||
|
||||
def get_max_sharded_shape_per_device(self, ):
|
||||
return self.max_sharded_shape
|
||||
|
||||
def get_max_sharded_element_per_device(self, ):
|
||||
max_sharded_shape = self.get_max_sharded_shape_per_device()
|
||||
if len(max_sharded_shape) == 0:
|
||||
num_elements = 1
|
||||
else:
|
||||
num_elements = trt.volume(max_sharded_shape)
|
||||
return num_elements
|
||||
|
||||
def get_max_sharded_size_per_device(self, ):
|
||||
num_elements = self.get_max_sharded_element_per_device()
|
||||
return num_elements * self.dtype_size
|
||||
@ -1,77 +0,0 @@
|
||||
class ShardingStrategy(object):
|
||||
|
||||
def __init__(self,
|
||||
name=None,
|
||||
sharding_specs=None,
|
||||
communication_actions=None):
|
||||
self.name = name or ""
|
||||
self.sharding_specs = sharding_specs or {}
|
||||
self.communication_actions = communication_actions
|
||||
self.sharding_cost = 0
|
||||
self.communication_cost = 0
|
||||
self.resharding_costs = {}
|
||||
self.best_resharding_cost = {}
|
||||
self.node_names = {}
|
||||
|
||||
self.comm_buff_memory_footprint = 0
|
||||
self.inout_memory_footprint = 0
|
||||
self.const_memory_footprint = 0
|
||||
self.peak_memory_footprint = 0
|
||||
self.computation_macs = 0
|
||||
self.alpha_beta_cost = 0
|
||||
|
||||
def print_strategy(self, best_resharding_cost_only=False, file=None):
|
||||
|
||||
def print_resharding_costs(resharding_cost):
|
||||
for prenode_node_name, rcosts in resharding_cost.items():
|
||||
if isinstance(prenode_node_name, int):
|
||||
idx = prenode_node_name
|
||||
prenode_node_name = self.node_names[idx]
|
||||
print(f' pre_node = {idx} {prenode_node_name}',
|
||||
file=file)
|
||||
else:
|
||||
print(f' pre_node = {prenode_node_name}', file=file)
|
||||
for idx, rcost in enumerate(rcosts):
|
||||
transpaths, commspecs, cost = rcost
|
||||
print(f' {idx}: ', end=' ', file=file)
|
||||
device_mesh.shape_consistency_manager.print_shape_consistency_result(
|
||||
transpaths, commspecs, cost, file)
|
||||
|
||||
print(f'name = {self.name}', file=file)
|
||||
print(f'sharding_cost = {self.sharding_cost}', file=file)
|
||||
print(
|
||||
f'communication_buffer_memory_footprint = {self.comm_buff_memory_footprint}, communication_cost = {self.communication_cost}',
|
||||
file=file)
|
||||
print(f'inout_memory_footprint = {self.inout_memory_footprint}',
|
||||
file=file)
|
||||
print(f'peak_memory_footprint = {self.peak_memory_footprint}',
|
||||
file=file)
|
||||
print(f'const_memory_footprint = {self.const_memory_footprint}',
|
||||
file=file)
|
||||
print('sharding_specs:', file=file)
|
||||
device_mesh = None
|
||||
for specname, spec in self.sharding_specs.items():
|
||||
print(specname + ', ', end=' ', file=file)
|
||||
spec.print_spec(file)
|
||||
device_mesh = spec.device_mesh
|
||||
|
||||
if best_resharding_cost_only and self.best_resharding_cost:
|
||||
print('best_resharding_costs:', file=file)
|
||||
print_resharding_costs(self.best_resharding_cost)
|
||||
else:
|
||||
print('resharding costs:', file=file)
|
||||
print_resharding_costs(self.resharding_costs)
|
||||
|
||||
|
||||
class StrategiesVector(list):
|
||||
'''
|
||||
Each node in fx graph will have a corresponding StrategiesVector, to store all the possible
|
||||
strategies of the node.
|
||||
|
||||
Argument:
|
||||
node (Node): node for which the list of sharding strategies are generated.
|
||||
'''
|
||||
|
||||
def __init__(self, node):
|
||||
super().__init__()
|
||||
self.node = node
|
||||
@ -1,238 +0,0 @@
|
||||
import copy
|
||||
|
||||
from .node import Node
|
||||
from .sharding_strategy import StrategiesVector
|
||||
|
||||
|
||||
class Shuffle(Node):
|
||||
|
||||
def __init__(self, layer):
|
||||
super().__init__(layer)
|
||||
layer.to_subclass()
|
||||
self.first_transpose_dims = layer.as_trt().first_transpose
|
||||
self.second_transpose_dims = layer.as_trt().second_transpose
|
||||
self.zero_is_placeholder = layer.as_trt().zero_is_placeholder
|
||||
self.is_first_transepose_identity = (sorted(
|
||||
self.first_transpose_dims) == list(self.first_transpose_dims))
|
||||
self.input_shape = self.get_input(0).shape
|
||||
self.is_second_transepose_identity = (sorted(
|
||||
self.second_transpose_dims) == list(self.second_transpose_dims))
|
||||
|
||||
output_shape = list(self.get_output(0).shape)
|
||||
self.reshape_dims = copy.deepcopy(output_shape)
|
||||
if not self.is_second_transepose_identity:
|
||||
for i in self.second_transpose_dims:
|
||||
if self.second_transpose_dims[i] != i:
|
||||
self.reshape_dims[
|
||||
self.second_transpose_dims[i]] = output_shape[i]
|
||||
self.is_reshape_identity = (list(self.reshape_dims) == list(
|
||||
self.input_shape))
|
||||
layer.to_base_class()
|
||||
|
||||
def _collect_transpose_strategies(self, device_mesh, transpose_dims):
|
||||
dim_partition_list = []
|
||||
dim_size = len(self.op_data['input0'].shape)
|
||||
dim_partition_list.append({})
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
|
||||
strategies_vector = StrategiesVector(self)
|
||||
# dim_partition_dict can be the same as previous node if solver's time is a problem
|
||||
for dim_partition_dict in dim_partition_list:
|
||||
in0_partition_dict = dim_partition_dict
|
||||
out_partition_dict = {}
|
||||
for split_dim, mesh_dim in in0_partition_dict.items():
|
||||
trans_dim = transpose_dims[split_dim]
|
||||
out_partition_dict[trans_dim] = mesh_dim
|
||||
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": in0_partition_dict,
|
||||
"output0": out_partition_dict,
|
||||
}
|
||||
if self.num_inputs == 2:
|
||||
dim_partition_dict_mapping["input1"] = {}
|
||||
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if 0 == len(sharding_spec_mapping):
|
||||
continue
|
||||
name = '{} = <shuffle_transpose_only op> {}'.format(
|
||||
sharding_spec_mapping['output0'].sharding_sequence,
|
||||
sharding_spec_mapping['input0'].sharding_sequence)
|
||||
sharding_strategy = self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
strategies_vector.append(sharding_strategy)
|
||||
return strategies_vector
|
||||
|
||||
def _find_reshape_partitions(self, input_shape, output_shape,
|
||||
input_partition_dict):
|
||||
len_input_shape, len_output_shape = len(input_shape), len(output_shape)
|
||||
output_partition_dict = {}
|
||||
i, j = 0, 0
|
||||
while i < len_input_shape or j < len_output_shape:
|
||||
if i < len_input_shape and input_shape[i] == 1:
|
||||
i = i + 1
|
||||
continue
|
||||
if j < len_output_shape and output_shape[j] == 1:
|
||||
j = j + 1
|
||||
continue
|
||||
|
||||
if input_shape[i] == output_shape[j]:
|
||||
if i in input_partition_dict:
|
||||
output_partition_dict[j] = input_partition_dict[i]
|
||||
# the dimension is kept, so its partition dims are kept as well
|
||||
i, j = i + 1, j + 1
|
||||
|
||||
elif input_shape[i] < output_shape[j]:
|
||||
# we detect if the input dims are merged in the reshape dim
|
||||
value = input_shape[i]
|
||||
for ii in range(i + 1, len_input_shape):
|
||||
value = value * input_shape[ii]
|
||||
if value == output_shape[j]:
|
||||
# dims are merged: the output's merged dim takes the union of the merged input dims' partitions
|
||||
mesh_dim = []
|
||||
for in_dim in range(i, ii + 1):
|
||||
if in_dim in input_partition_dict:
|
||||
mesh_dim = mesh_dim + input_partition_dict[
|
||||
in_dim]
|
||||
if len(mesh_dim) > 0:
|
||||
output_partition_dict[j] = sorted(mesh_dim)
|
||||
i, j = ii + 1, j + 1
|
||||
break
|
||||
else:
|
||||
# no merged dimensions found; the mismatch may come from an arbitrary reshape, which is not supported yet
|
||||
return {}, {}
|
||||
else:
|
||||
# we detect if the input dim is split into reshape dims
|
||||
value = output_shape[j]
|
||||
for jj in range(j + 1, len_output_shape):
|
||||
value = value * output_shape[jj]
|
||||
if value == input_shape[i]:
|
||||
# it is split pattern
|
||||
if i in input_partition_dict:
|
||||
output_partition_dict[j] = input_partition_dict[i]
|
||||
i, j = i + 1, jj + 1
|
||||
break
|
||||
else:
|
||||
# no split dimensions found; the mismatch may come from an arbitrary reshape
|
||||
return {}, {}
|
||||
return input_partition_dict, output_partition_dict
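A toy illustration of the merge case handled above (standalone values, not part of the module): reshaping (2, 3, 4) to (6, 4) merges input dims 0 and 1 into output dim 0, so a mesh partition on either merged input dim moves to output dim 0.

input_shape, output_shape = (2, 3, 4), (6, 4)
input_partition = {0: [0]}                           # input dim 0 sharded on mesh axis 0
assert input_shape[0] * input_shape[1] == output_shape[0]
output_partition = {0: sorted(input_partition[0])}   # carried over to the merged output dim
print(output_partition)                              # {0: [0]}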
|
||||
|
||||
def _collect_reshape_strategies(self, device_mesh):
|
||||
dim_partition_list = []
|
||||
dim_size = len(self.op_data['input0'].shape)
|
||||
dim_partition_list.append({})
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
|
||||
strategies_vector = StrategiesVector(self)
|
||||
# dim_partition_dict can be the same as previous node if solver's time is a problem
|
||||
for dim_partition_dict in dim_partition_list:
|
||||
in0_partition_dict = dim_partition_dict
|
||||
in0_partition_dict, out_partition_dict = self._find_reshape_partitions(
|
||||
self.input_shape, self.reshape_dims, in0_partition_dict)
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": in0_partition_dict,
|
||||
"output0": out_partition_dict,
|
||||
}
|
||||
if self.num_inputs == 2:
|
||||
dim_partition_dict_mapping["input1"] = {}
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if 0 == len(sharding_spec_mapping):
|
||||
continue
|
||||
name = '{} = <shuffle_reshape op> {}'.format(
|
||||
sharding_spec_mapping['output0'].sharding_sequence,
|
||||
sharding_spec_mapping['input0'].sharding_sequence)
|
||||
sharding_strategy = self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
strategies_vector.append(sharding_strategy)
|
||||
return strategies_vector
|
||||
|
||||
def _collect_identity_strategies(self, device_mesh):
|
||||
dim_partition_list = []
|
||||
dim_size = len(self.op_data['input0'].shape)
|
||||
dim_partition_list.append({})
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
|
||||
strategies_vector = StrategiesVector(self)
|
||||
# dim_partition_dict can be the same as previous node if solver's time is a problem
|
||||
for dim_partition_dict in dim_partition_list:
|
||||
in0_partition_dict = dim_partition_dict
|
||||
out_partition_dict = copy.deepcopy(dim_partition_dict)
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": in0_partition_dict,
|
||||
"output0": out_partition_dict,
|
||||
}
|
||||
if self.num_inputs == 2:
|
||||
dim_partition_dict_mapping["input1"] = {}
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if 0 == len(sharding_spec_mapping):
|
||||
continue
|
||||
name = '{} = <shuffle_identity op> {}'.format(
|
||||
sharding_spec_mapping['output0'].sharding_sequence,
|
||||
sharding_spec_mapping['input0'].sharding_sequence)
|
||||
sharding_strategy = self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
strategies_vector.append(sharding_strategy)
|
||||
return strategies_vector
|
||||
|
||||
def _collect_strategies(self, device_mesh):
|
||||
is_identify_list = (self.is_first_transepose_identity,
|
||||
self.is_reshape_identity,
|
||||
self.is_second_transepose_identity)
|
||||
if is_identify_list == (True, True, True):
|
||||
return self._collect_identity_strategies(device_mesh)
|
||||
elif is_identify_list == (True, True, False):
|
||||
return self._collect_transpose_strategies(
|
||||
device_mesh, self.second_transpose_dims)
|
||||
elif is_identify_list == (False, True, True):
|
||||
return self._collect_transpose_strategies(device_mesh,
|
||||
self.first_transpose_dims)
|
||||
elif is_identify_list == (True, False, True):
|
||||
return self._collect_reshape_strategies(device_mesh)
|
||||
else:
|
||||
assert False, f"Unsupported shuffle pattern now {is_identify_list}"
|
||||
|
||||
def _profile_sharding_cost(self, strategy, device_mesh):
|
||||
updated_layer_attrs = {}
|
||||
updated_input_values = {}
|
||||
output_shape = strategy.sharding_specs[
|
||||
'output0'].get_sharded_shape_per_device()
|
||||
self.layer.to_subclass()
|
||||
second_transpose = self.layer.as_trt().second_transpose
|
||||
self.layer.to_base_class()
|
||||
reshape_dims = [*output_shape]
|
||||
for i in range(len(output_shape)):
|
||||
reshape_dims[second_transpose[i]] = output_shape[i]
|
||||
if self.layer.num_inputs >= 2:
|
||||
updated_input_values[1] = reshape_dims
|
||||
else:
|
||||
updated_layer_attrs['reshape_dims'] = reshape_dims
|
||||
elapsed_time = self.node_runtime_profiler.runtime_profile(
|
||||
self.layer, updated_layer_attrs, updated_input_values, strategy,
|
||||
device_mesh)
|
||||
return elapsed_time
|
||||
@ -1,100 +0,0 @@
|
||||
import copy
|
||||
|
||||
from .node import Node
|
||||
from .sharding_strategy import StrategiesVector
|
||||
|
||||
|
||||
class Slice(Node):
|
||||
|
||||
def __init__(self, layer):
|
||||
super().__init__(layer)
|
||||
layer.to_subclass()
|
||||
input_shape = self.get_input(0).shape
|
||||
output_shape = self.get_output(0).shape
|
||||
assert len(input_shape) == len(
|
||||
output_shape
|
||||
), f'dims of input shape {input_shape} != dims of output shape {output_shape}'
|
||||
if layer.num_inputs >= 2 and layer.get_input(1) is not None:
|
||||
start = layer.get_input(1).value
|
||||
else:
|
||||
start = layer.as_trt().start
|
||||
if layer.num_inputs >= 4 and layer.get_input(3) is not None:
|
||||
stride = layer.get_input(3).value
|
||||
else:
|
||||
stride = layer.as_trt().stride
|
||||
self.keep_partition_dims = [(input_shape[i] == output_shape[i]
|
||||
and start[i] == 0 and stride[i] == 1)
|
||||
for i in range(len(input_shape))]
|
||||
layer.to_base_class()
|
||||
|
||||
def _update_memory_cost(self, strategies):
|
||||
for strategy in strategies:
|
||||
# for a slice node, input0's read volume equals output0's write volume
|
||||
inout_memory_footprint = strategy.sharding_specs[
|
||||
'output0'].get_sharded_size_per_device() * 2
|
||||
strategy.inout_memory_footprint = inout_memory_footprint
|
||||
strategy.peak_memory_footprint = (
|
||||
strategy.sharding_specs['input0'].
|
||||
get_max_sharded_size_per_device() + strategy.
|
||||
sharding_specs['output0'].get_max_sharded_size_per_device())
|
||||
|
||||
def _collect_strategies(self, device_mesh):
|
||||
dim_partition_list = []
|
||||
dim_size = len(self.op_data['input0'].shape)
|
||||
dim_partition_list.append({})
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
|
||||
strategies_vector = StrategiesVector(self)
|
||||
# dim_partition_dict can be the same as previous node if solver's time is a problem
|
||||
for dim_partition_dict in dim_partition_list:
|
||||
for dim in range(len(self.keep_partition_dims)):
|
||||
if (not self.keep_partition_dims[dim]
|
||||
) and dim in dim_partition_dict:
|
||||
dim_partition_dict.pop(dim)
|
||||
|
||||
in0_partition_dict = dim_partition_dict
|
||||
out_partition_dict = copy.deepcopy(dim_partition_dict)
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": in0_partition_dict,
|
||||
"output0": out_partition_dict,
|
||||
}
|
||||
for i in range(1, self.num_inputs):
|
||||
if self.predecessor_nodes[i]:
|
||||
dim_partition_dict_mapping[f"input{i}"] = {}
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if 0 == len(sharding_spec_mapping):
|
||||
continue
|
||||
name = '{} = {} <slice op> '.format(
|
||||
sharding_spec_mapping['output0'].sharding_sequence,
|
||||
sharding_spec_mapping['input0'].sharding_sequence)
|
||||
for i in range(1, self.num_inputs):
|
||||
if self.predecessor_nodes[i]:
|
||||
name = name + str(
|
||||
sharding_spec_mapping[f'input{i}'].sharding_sequence)
|
||||
sharding_strategy = self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
strategies_vector.append(sharding_strategy)
|
||||
return strategies_vector
|
||||
|
||||
def _profile_sharding_cost(self, strategy, device_mesh):
|
||||
updated_layer_attrs = {}
|
||||
updated_input_values = {}
|
||||
shape = strategy.sharding_specs['output0'].get_sharded_shape_per_device(
|
||||
)
|
||||
if self.layer.num_inputs >= 3 and self.layer.get_input(2) is not None:
|
||||
updated_input_values[2] = shape
|
||||
else:
|
||||
updated_layer_attrs['shape'] = shape
|
||||
elapsed_time = self.node_runtime_profiler.runtime_profile(
|
||||
self.layer, updated_layer_attrs, updated_input_values, strategy,
|
||||
device_mesh)
|
||||
return elapsed_time
|
||||
@ -1,54 +0,0 @@
|
||||
import copy
|
||||
|
||||
from tensorrt_llm._utils import trt_axes_to_dim
|
||||
|
||||
from .node import Node
|
||||
from .sharding_strategy import StrategiesVector
|
||||
|
||||
|
||||
class SoftMax(Node):
|
||||
|
||||
def __init__(self, layer):
|
||||
super().__init__(layer)
|
||||
layer.to_subclass()
|
||||
self.softmax_dim = trt_axes_to_dim(layer.as_trt().axes)[0]
|
||||
layer.to_base_class()
|
||||
|
||||
def _collect_strategies(self, device_mesh):
|
||||
dim_partition_list = []
|
||||
dim_size = len(self.op_data['input0'].shape)
|
||||
dim_partition_list.append({})
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
|
||||
strategies_vector = StrategiesVector(self)
|
||||
# dim_partition_dict can be the same as previous node if solver's time is a problem
|
||||
for dim_partition_dict in dim_partition_list:
|
||||
if self.softmax_dim in dim_partition_dict:
|
||||
dim_partition_dict.pop(self.softmax_dim)
|
||||
|
||||
in0_partition_dict = dim_partition_dict
|
||||
out_partition_dict = copy.deepcopy(dim_partition_dict)
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": in0_partition_dict,
|
||||
"output0": out_partition_dict,
|
||||
}
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if 0 == len(sharding_spec_mapping):
|
||||
continue
|
||||
name = '{} = <softmax along dim {}> {}'.format(
|
||||
sharding_spec_mapping['output0'].sharding_sequence,
|
||||
self.softmax_dim,
|
||||
sharding_spec_mapping['input0'].sharding_sequence)
|
||||
sharding_strategy = self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
strategies_vector.append(sharding_strategy)
|
||||
return strategies_vector
|
||||
@ -1,42 +0,0 @@
|
||||
import copy
|
||||
|
||||
from .node import Node
|
||||
from .sharding_strategy import StrategiesVector
|
||||
|
||||
|
||||
class Unary(Node):
|
||||
|
||||
def _collect_strategies(self, device_mesh):
|
||||
dim_partition_list = []
|
||||
dim_size = len(self.op_data['input0'].shape)
|
||||
dim_partition_list.append({})
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
|
||||
dim_partition_list.extend(
|
||||
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
|
||||
strategies_vector = StrategiesVector(self)
|
||||
# dim_partition_dict can be the same as previous node if solver's time is a problem
|
||||
for dim_partition_dict in dim_partition_list:
|
||||
in0_partition_dict = dim_partition_dict
|
||||
out_partition_dict = copy.deepcopy(dim_partition_dict)
|
||||
dim_partition_dict_mapping = {
|
||||
"input0": in0_partition_dict,
|
||||
"output0": out_partition_dict,
|
||||
}
|
||||
sharding_spec_mapping = self._to_sharding_spec_mapping(
|
||||
dim_partition_dict_mapping, device_mesh)
|
||||
if 0 == len(sharding_spec_mapping):
|
||||
continue
|
||||
name = '{} = <unary op> {}'.format(
|
||||
sharding_spec_mapping['output0'].sharding_sequence,
|
||||
sharding_spec_mapping['input0'].sharding_sequence)
|
||||
sharding_strategy = self._get_sharding_strategy(
|
||||
name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
strategies_vector.append(sharding_strategy)
|
||||
return strategies_vector
|
||||
@ -1,308 +0,0 @@
|
||||
import contextlib
|
||||
import threading
|
||||
|
||||
try:
|
||||
from types import NoneType
|
||||
except ImportError:
|
||||
NoneType = type(None)
|
||||
from typing import ByteString, Iterable, MutableMapping
|
||||
|
||||
import tensorrt as trt
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._utils import get_extra_attr, np_dtype_to_trt, set_extra_attr
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.network import PluginInfo, get_plugin_info
|
||||
|
||||
LAYER_TYPE_2_CLASS = {
|
||||
trt.LayerType.ACTIVATION: trt.IActivationLayer,
|
||||
trt.LayerType.CONCATENATION: trt.IConcatenationLayer,
|
||||
trt.LayerType.CONSTANT: trt.IConstantLayer,
|
||||
trt.LayerType.ELEMENTWISE: trt.IElementWiseLayer,
|
||||
trt.LayerType.FILL: trt.IFillLayer,
|
||||
trt.LayerType.GATHER: trt.IGatherLayer,
|
||||
trt.LayerType.MATRIX_MULTIPLY: trt.IMatrixMultiplyLayer,
|
||||
trt.LayerType.REDUCE: trt.IReduceLayer,
|
||||
trt.LayerType.SELECT: trt.ISelectLayer,
|
||||
trt.LayerType.SHUFFLE: trt.IShuffleLayer,
|
||||
trt.LayerType.SLICE: trt.ISliceLayer,
|
||||
trt.LayerType.SOFTMAX: trt.ISoftMaxLayer,
|
||||
trt.LayerType.UNARY: trt.IUnaryLayer,
|
||||
trt.LayerType.SHAPE: trt.IShapeLayer,
|
||||
trt.LayerType.ASSERTION: trt.IAssertionLayer,
|
||||
trt.LayerType.CAST: trt.ICastLayer,
|
||||
trt.LayerType.NORMALIZATION: trt.INormalizationLayer,
|
||||
trt.LayerType.IDENTITY: trt.IIdentityLayer,
|
||||
trt.LayerType.PLUGIN_V2: trt.IPluginV2Layer,
|
||||
}
|
||||
|
||||
|
||||
def to_subclass_layer(trt_layer):
|
||||
trt_layer.__class__ = LAYER_TYPE_2_CLASS[trt_layer.type]
|
||||
|
||||
|
||||
def to_base_class_layer(trt_layer):
|
||||
trt_layer.__class__ = trt.ILayer
|
||||
|
||||
|
||||
def to_trt_weights(ndarray):
|
||||
weight = trt.Weights(
|
||||
np_dtype_to_trt(ndarray.dtype),
|
||||
ndarray.ctypes.data,
|
||||
ndarray.size,
|
||||
)
|
||||
# Keep the numpy array alive for the lifetime of the weights object
|
||||
set_extra_attr(weight, "numpy", ndarray)
|
||||
return weight
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def silent_trt_logger():
|
||||
min_severity = logger.trt_logger.min_severity
|
||||
logger.trt_logger.min_severity = trt.Logger.ERROR
|
||||
yield
|
||||
logger.trt_logger.min_severity = min_severity
|
||||
|
||||
|
||||
def compare_tensor(trt_tensor, new_trt_tensor):
|
||||
assert trt_tensor.name == new_trt_tensor.name
|
||||
assert trt_tensor.dtype == new_trt_tensor.dtype
|
||||
assert tuple(trt_tensor.shape) == tuple(new_trt_tensor.shape)
|
||||
assert trt_tensor.broadcast_across_batch == new_trt_tensor.broadcast_across_batch
|
||||
assert trt_tensor.location == new_trt_tensor.location
|
||||
assert trt_tensor.is_network_input == new_trt_tensor.is_network_input
|
||||
assert trt_tensor.is_network_output == new_trt_tensor.is_network_output
|
||||
assert trt_tensor.dynamic_range == new_trt_tensor.dynamic_range
|
||||
assert trt_tensor.is_shape_tensor == new_trt_tensor.is_shape_tensor
|
||||
assert trt_tensor.is_execution_tensor == new_trt_tensor.is_execution_tensor
|
||||
assert trt_tensor.allowed_formats == new_trt_tensor.allowed_formats
|
||||
|
||||
|
||||
def compare_network(trt_network, new_trt_network):
|
||||
assert trt_network.num_inputs == new_trt_network.num_inputs
|
||||
for i in range(trt_network.num_inputs):
|
||||
input = trt_network.get_input(i)
|
||||
new_input = new_trt_network.get_input(i)
|
||||
compare_tensor(input, new_input)
|
||||
assert trt_network.num_outputs == new_trt_network.num_outputs
|
||||
for i in range(trt_network.num_outputs):
|
||||
output = trt_network.get_output(i)
|
||||
new_output = new_trt_network.get_output(i)
|
||||
compare_tensor(output, new_output)
|
||||
assert trt_network.num_layers == new_trt_network.num_layers
|
||||
for index, new_index in zip(get_sorted_layer_ids(trt_network),
|
||||
get_sorted_layer_ids(new_trt_network)):
|
||||
layer = trt_network.get_layer(index)
|
||||
new_layer = new_trt_network.get_layer(new_index)
|
||||
assert layer.name == new_layer.name
|
||||
assert layer.type == new_layer.type
|
||||
assert layer.precision_is_set == new_layer.precision_is_set
|
||||
assert layer.precision == new_layer.precision
|
||||
assert layer.num_inputs == new_layer.num_inputs
|
||||
for j in range(layer.num_inputs):
|
||||
input = layer.get_input(j)
|
||||
new_input = new_layer.get_input(j)
|
||||
if input is None:
|
||||
assert new_input is None
|
||||
else:
|
||||
assert new_input is not None
|
||||
compare_tensor(input, new_input)
|
||||
assert layer.num_outputs == new_layer.num_outputs
|
||||
for j in range(layer.num_outputs):
|
||||
output = layer.get_output(j)
|
||||
new_output = new_layer.get_output(j)
|
||||
compare_tensor(output, new_output)
|
||||
assert layer.output_type_is_set(j) == new_layer.output_type_is_set(
|
||||
j)
|
||||
if layer.output_type_is_set(j):
|
||||
assert layer.get_output_type(j) == new_layer.get_output_type(j)
|
||||
|
||||
|
||||
def get_sorted_layer_ids(trt_network):
|
||||
inputs = set()
|
||||
for i in range(trt_network.num_inputs):
|
||||
inputs.add(trt_network.get_input(i).name)
|
||||
layer_ids = [*range(trt_network.num_layers)]
|
||||
sorted_layer_ids = []
|
||||
walked_tensors = set(inputs)
|
||||
while len(layer_ids) > 0:
|
||||
layer_id = layer_ids.pop(0)
|
||||
layer = trt_network.get_layer(layer_id)
|
||||
no_dependencies = True
|
||||
for j in range(layer.num_inputs):
|
||||
input = layer.get_input(j)
|
||||
if input is None:
|
||||
continue
|
||||
if input.name in walked_tensors:
|
||||
continue
|
||||
else:
|
||||
no_dependencies = False
|
||||
break
|
||||
if no_dependencies:
|
||||
sorted_layer_ids.append(layer_id)
|
||||
for j in range(layer.num_outputs):
|
||||
output = layer.get_output(j)
|
||||
if output is None:
|
||||
continue
|
||||
walked_tensors.add(output.name)
|
||||
else:
|
||||
layer_ids.append(layer_id)
|
||||
assert len(sorted_layer_ids) == trt_network.num_layers
|
||||
return sorted_layer_ids
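The loop above is a simple dependency-driven topological order: a layer is emitted once all of its inputs have been produced. The same idea on a toy graph (hypothetical names, independent of TensorRT):

deps = {"a": [], "b": ["a"], "c": ["a", "b"]}
pending, walked, order = list(deps), set(), []
while pending:
    node = pending.pop(0)
    if all(d in walked for d in deps[node]):
        order.append(node)    # all inputs already walked, emit the node
        walked.add(node)
    else:
        pending.append(node)  # try again after its inputs are emitted
print(order)                  # ['a', 'b', 'c']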
|
||||
|
||||
|
||||
def to_tuple(values):
|
||||
if isinstance(values, (int, float, str, bool, NoneType, ByteString)):
|
||||
return values
|
||||
elif isinstance(values, (trt.Dims, trt.Permutation)):
|
||||
if values.__len__() < 0:
|
||||
return None
|
||||
else:
|
||||
return tuple(values)
|
||||
elif isinstance(values, Iterable):
|
||||
return tuple(to_tuple(v) for v in values)
|
||||
elif isinstance(values, MutableMapping):
|
||||
return tuple((k, to_tuple(v)) for k, v in values.items())
|
||||
else:
|
||||
return values
|
||||
|
||||
|
||||
_base_layer_attr_names = set(dir(trt.ILayer))
|
||||
|
||||
|
||||
def get_cache_key(layer, shapes, values, dtypes=None, updated_attrs=None):
|
||||
updated_attrs = updated_attrs or {}
|
||||
layer_type = layer.type
|
||||
to_subclass_layer(layer)
|
||||
attr_names = set(dir(layer)) - _base_layer_attr_names
|
||||
if layer_type == trt.LayerType.CONSTANT:
|
||||
attr_names.remove("weights")
|
||||
elif layer_type == trt.LayerType.SHUFFLE:
|
||||
if layer.num_inputs >= 2:
|
||||
attr_names.remove("reshape_dims")
|
||||
elif layer_type == trt.LayerType.SLICE:
|
||||
if layer.num_inputs >= 2 and layer.get_input(1) is not None:
|
||||
attr_names.remove("start")
|
||||
if layer.num_inputs >= 3 and layer.get_input(2) is not None:
|
||||
attr_names.remove("shape")
|
||||
if layer.num_inputs >= 4 and layer.get_input(3) is not None:
|
||||
attr_names.remove("stride")
|
||||
elif layer_type == trt.LayerType.FILL:
|
||||
attr_names.remove("is_alpha_beta_int64")
|
||||
if layer.num_inputs >= 1 and layer.get_input(0) is not None:
|
||||
attr_names.remove("shape")
|
||||
if layer.num_inputs >= 2 and layer.get_input(1) is not None:
|
||||
attr_names.remove("alpha")
|
||||
if layer.num_inputs >= 3 and layer.get_input(2) is not None:
|
||||
attr_names.remove("beta")
|
||||
if layer_type != trt.LayerType.PLUGIN_V2:
|
||||
attr_key = tuple(
|
||||
(name, to_tuple(updated_attrs.get(name) or getattr(layer, name)))
|
||||
for name in sorted(attr_names))
|
||||
else:
|
||||
network = get_trt_network(layer)
|
||||
plugin_info = get_plugin_info(network, layer.name)
|
||||
assert plugin_info is not None, f"layer {layer.name} does not register plugin info"
|
||||
attr_key = tuple(
|
||||
(name, tuple(updated_attrs.get(name) or data))
|
||||
for name, data in sorted(plugin_info.pfc_as_list.items()))
|
||||
to_base_class_layer(layer)
|
||||
shape_key = ()
|
||||
value_key = ()
|
||||
dtype_key = ()
|
||||
for i in range(layer.num_inputs):
|
||||
input = layer.get_input(i)
|
||||
if input is not None:
|
||||
shape_key += (tuple(shapes[input.name]), )
|
||||
if input.name in values:
|
||||
value = values[input.name]
|
||||
# All torch tensors are derived from input shapes and pfc,
|
||||
# thus we ignore them in cache key
|
||||
if isinstance(value, torch.Tensor):
|
||||
value = None
|
||||
else:
|
||||
value = tuple(value)
|
||||
value_key += (value, )
|
||||
else:
|
||||
value_key += (None, )
|
||||
if dtypes is not None:
|
||||
dtype_key += (dtypes[input.name], )
|
||||
else:
|
||||
shape_key += (None, )
|
||||
value_key += (None, )
|
||||
dtype_key += (None, )
|
||||
if dtypes is not None:
|
||||
for i in range(layer.num_outputs):
|
||||
output = layer.get_output(i)
|
||||
dtype_key += (dtypes[output.name], )
|
||||
cache_key = (layer.type, attr_key, shape_key, value_key)
|
||||
if dtypes is not None:
|
||||
cache_key += (dtype_key, )
|
||||
return cache_key
|
||||
|
||||
|
||||
def get_trt_network(layer: trt.ILayer):
|
||||
network = get_extra_attr(layer, "network")
|
||||
assert network is not None
|
||||
return network
|
||||
|
||||
|
||||
def set_trt_network(layer: trt.ILayer, network: trt.INetworkDefinition):
|
||||
set_extra_attr(layer, "network", network)
|
||||
|
||||
|
||||
def get_updated_plugin(plugin_info: PluginInfo, updated_attrs):
|
||||
fields = []
|
||||
for field in plugin_info.pfc:
|
||||
name = field.name
|
||||
if name in updated_attrs:
|
||||
field = trt.PluginField(name, updated_attrs[name], field.type)
|
||||
else:
|
||||
field = trt.PluginField(name, plugin_info.pfc_as_ndarray[name],
|
||||
field.type)
|
||||
fields.append(field)
|
||||
pfc = trt.PluginFieldCollection(fields)
|
||||
plugin = plugin_info.plugin_creator.create_plugin(plugin_info.plugin_name,
|
||||
pfc)
|
||||
new_plugin_info = PluginInfo(plugin_info.plugin_creator,
|
||||
plugin_info.plugin_name, pfc)
|
||||
return plugin, new_plugin_info
|
||||
|
||||
|
||||
_builder_flags = threading.local()
|
||||
_strongly_typed = threading.local()
|
||||
|
||||
|
||||
def get_builder_flags():
|
||||
return getattr(_builder_flags, 'value', 0)
|
||||
|
||||
|
||||
def get_strongly_typed():
|
||||
return getattr(_strongly_typed, 'value', False)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def current_flags(builder_flags, strongly_typed):
|
||||
previous_builder_flags = get_builder_flags()
|
||||
_builder_flags.value = builder_flags
|
||||
previous_strongly_typed = get_strongly_typed()
|
||||
_strongly_typed.value = strongly_typed
|
||||
yield
|
||||
_builder_flags.value = previous_builder_flags
|
||||
_strongly_typed.value = previous_strongly_typed
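A short usage sketch of the thread-local flag context defined above; note that the previous values are restored on exit (assuming no exception, since there is no try/finally):

with current_flags(builder_flags=0, strongly_typed=True):
    assert get_strongly_typed() is True
assert get_strongly_typed() is False   # default restored after the block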
|
||||
|
||||
|
||||
def get_engine_information(engine_file) -> str:
|
||||
with open(engine_file, "rb") as f:
|
||||
engine_buffer = f.read()
|
||||
runtime = trt.Runtime(logger.trt_logger)
|
||||
engine = runtime.deserialize_cuda_engine(engine_buffer)
|
||||
inspector = engine.create_engine_inspector()
|
||||
return inspector.get_engine_information(trt.LayerInformationFormat.JSON)
|
||||
|
||||
|
||||
def print_engine_info(engine_file) -> dict:
|
||||
with open(engine_file, "rb") as f:
|
||||
engine_buffer = f.read()
|
||||
from tensorrt_llm.runtime.session import Session
|
||||
Session.from_serialized_engine(engine_buffer)._print_engine_info()
|
||||
@ -44,16 +44,8 @@ def get_settings_from_engine(
    with open(config_path, "r") as config_json:
        config = json.load(config_json)

    engine_world_map = config["pretrained_config"]["mapping"]
    mapping = config["pretrained_config"]["mapping"]
    engine_build_cfg = config["build_config"]
    engine_parallel_map = engine_build_cfg["auto_parallel_config"]

    world_config = {
        "pp_size": engine_world_map["pp_size"],
        "tp_size": engine_world_map["tp_size"],
        "world_size": engine_world_map["world_size"],
        "gpus_per_node": engine_parallel_map["gpus_per_node"],
    }

    executor_settings = {
        "max_batch_size": engine_build_cfg["max_batch_size"],
@ -64,7 +56,7 @@ def get_settings_from_engine(
        "sw_version": config["version"],
        "engine_dir": str(engine_path.absolute()),
        "settings_config": executor_settings,
        "world_config": world_config,
        "mapping": mapping,
    })

    runtime_config["performance_options"] = {}
@ -104,12 +96,13 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
    enable_chunked_prefill = llm_args_dict.get("enable_chunked_prefill",
                                               enable_chunked_prefill)

    world_config = {
    mapping = {
        "pp_size": params.get("pp"),
        "tp_size": params.get("tp"),
        "world_size": params.get("pp") * params.get("tp"),
        "ep_size": params.get("ep"),
        "cluster_size": params.get("cluster_size"),
        "moe_ep_size": params.get("ep"),
        "moe_cluster_size": params.get("cluster_size"),
        "gpus_per_node": params.get("gpus_per_node"),
    }

    if params.get("max_batch_size") and params.get("max_num_tokens"):
@ -184,7 +177,7 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
            "max_num_tokens": int(max_num_tokens),
            "chunking": enable_chunked_prefill,
        },
        "world_config": world_config,
        "mapping": mapping,
        "backend": backend,
        "decoding_config": {},
        "performance_options": {
@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from importlib.util import find_spec
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
@ -29,7 +28,9 @@ class RuntimeConfig(BaseModel):
|
||||
engine_dir: Optional[Path] = None
|
||||
sw_version: str
|
||||
settings_config: ExecutorSettingsConfig
|
||||
world_config: ExecutorWorldConfig
|
||||
# TODO: this is a dict corresponding to the Mapping class, the type should be
|
||||
# changed to Mapping after the Mapping class is migrated to a Pydantic model.
|
||||
mapping: Dict[str, Any]
|
||||
decoding_config: Optional[DecodingConfig] = None
|
||||
performance_options: PerformanceOptions
|
||||
backend: Literal["pytorch", "_autodeploy", None] = None
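The TODO above anticipates migrating Mapping to a Pydantic model so the `mapping` field can be typed. A rough, purely illustrative sketch of what such a typed field could look like (class and field names are assumptions, not part of this change):

from typing import Optional
from pydantic import BaseModel, model_validator

class MappingModel(BaseModel):
    # field names mirror the keys currently stored in the plain dict
    tp_size: int = 1
    pp_size: int = 1
    cp_size: int = 1
    world_size: int = 1
    gpus_per_node: Optional[int] = None
    moe_ep_size: int = -1
    moe_cluster_size: int = -1

    @model_validator(mode="after")
    def _check_world_size(self):
        # same invariant the runtime enforces elsewhere
        if self.world_size != self.tp_size * self.pp_size * self.cp_size:
            raise ValueError("world_size must equal tp_size * pp_size * cp_size")
        return self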
|
||||
@ -47,15 +48,15 @@ class RuntimeConfig(BaseModel):
|
||||
"skip_tokenizer_init":
|
||||
True,
|
||||
"pipeline_parallel_size":
|
||||
self.world_config.pp_size,
|
||||
self.mapping["pp_size"],
|
||||
"tensor_parallel_size":
|
||||
self.world_config.tp_size,
|
||||
self.mapping["tp_size"],
|
||||
"gpus_per_node":
|
||||
self.world_config.gpus_per_node,
|
||||
self.mapping["gpus_per_node"],
|
||||
"moe_expert_parallel_size":
|
||||
self.world_config.ep_size,
|
||||
self.mapping["moe_ep_size"],
|
||||
"moe_cluster_parallel_size":
|
||||
self.world_config.cluster_size,
|
||||
self.mapping["moe_cluster_size"],
|
||||
"trust_remote_code":
|
||||
True,
|
||||
"enable_chunked_prefill":
|
||||
@ -164,57 +165,6 @@ class DecodingConfig(BaseModel):
|
||||
return trtllm.DecodingConfig(**kwargs)
|
||||
|
||||
|
||||
class ExecutorWorldConfig(BaseModel):
|
||||
pp_size: int = 1
|
||||
tp_size: int = 1
|
||||
# None to make LLM-API deduce it with a rule.
|
||||
gpus_per_node: Optional[int] = None
|
||||
leader_mode: bool = False
|
||||
ep_size: Optional[int] = None
|
||||
cluster_size: Optional[int] = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_world_size(self) -> ExecutorWorldConfig:
|
||||
if self.gpus_per_node is None:
|
||||
return self
|
||||
|
||||
parallel_world = self.pp_size * self.tp_size
|
||||
num_gpus = self.world_size * self.gpus_per_node
|
||||
valid_world = bool(num_gpus >= parallel_world)
|
||||
|
||||
if not valid_world:
|
||||
raise ValueError(
|
||||
f"World configuration is invalid, TP * PP ({parallel_world})"
|
||||
"does not equal the total number of available GPUs"
|
||||
f"({num_gpus}).")
|
||||
|
||||
return self
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return self.pp_size * self.tp_size
|
||||
|
||||
def _get_tensorrt_llm_executor_worker_path(self) -> Path:
|
||||
module_path = find_spec("tensorrt_llm").loader.get_filename()
|
||||
exec_path = Path(module_path).parent / 'bin' / 'executorWorker'
|
||||
return exec_path.absolute()
|
||||
|
||||
def get_parallel_config(self) -> trtllm.ParallelConfig:
|
||||
if self.leader_mode:
|
||||
comm_mode = trtllm.CommunicationMode.LEADER
|
||||
orchestrator_config = None
|
||||
else:
|
||||
comm_mode = trtllm.CommunicationMode.ORCHESTRATOR
|
||||
orchestrator_config = trtllm.OrchestratorConfig(
|
||||
True, str(self._get_tensorrt_llm_executor_worker_path()))
|
||||
|
||||
return trtllm.ParallelConfig(
|
||||
trtllm.CommunicationType.MPI,
|
||||
comm_mode,
|
||||
orchestrator_config=orchestrator_config,
|
||||
)
|
||||
|
||||
|
||||
class ExecutorSettingsConfig(BaseModel):
|
||||
chunking: bool = True
|
||||
scheduler_policy: CapacitySchedulerPolicy = Field(
|
||||
|
||||
@ -336,10 +336,10 @@ class ReportUtility:
|
||||
|
||||
# World and runtime info
|
||||
stats_dict["world_info"] = {
|
||||
"tp_size": self.rt_cfg.world_config.tp_size,
|
||||
"pp_size": self.rt_cfg.world_config.pp_size,
|
||||
"ep_size": self.rt_cfg.world_config.ep_size,
|
||||
"world_size": self.rt_cfg.world_config.world_size,
|
||||
"tp_size": self.rt_cfg.mapping["tp_size"],
|
||||
"pp_size": self.rt_cfg.mapping["pp_size"],
|
||||
"ep_size": self.rt_cfg.mapping["moe_ep_size"],
|
||||
"world_size": self.rt_cfg.mapping["world_size"],
|
||||
"max_batch_size": self.rt_cfg.settings_config.max_batch_size,
|
||||
"max_num_tokens": self.rt_cfg.settings_config.max_num_tokens,
|
||||
"scheduling_policy": self.rt_cfg.settings_config.scheduler_policy,
|
||||
@ -380,7 +380,7 @@ class ReportUtility:
|
||||
self.per_user_output_throughput_tok_s,
|
||||
# Output throughput per GPU (total throughput / world size)
|
||||
"output_throughput_per_gpu_tok_s":
|
||||
self.output_throughput_tok_s / self.rt_cfg.world_config.world_size,
|
||||
self.output_throughput_tok_s / self.rt_cfg.mapping["world_size"],
|
||||
# Request latency percentiles
|
||||
"request_latency_percentiles_ms":
|
||||
self.statistics.request_latency_percentiles.model_dump(
|
||||
|
||||
@ -30,8 +30,6 @@ import tensorrt as trt
|
||||
from ._common import _is_building, check_max_num_tokens, serialize_engine
|
||||
from ._utils import (get_sm_version, np_bfloat16, np_float8, str_dtype_to_trt,
|
||||
to_json_file, trt_gte)
|
||||
from .auto_parallel import auto_parallel
|
||||
from .auto_parallel.config import AutoParallelConfig
|
||||
from .bindings import KVCacheType
|
||||
from .functional import PositionEmbeddingType
|
||||
from .graph_rewriting import optimize
|
||||
@ -84,19 +82,12 @@ class BuilderConfig(object):
|
||||
"plugin_config": {
|
||||
# the network plugin_config (if any) attached to this BuilderConfig object
|
||||
# inside the Builder.build_engine
|
||||
},
|
||||
"auto_parallel_config": {
|
||||
# the network auto_parallel_config (if any) attached to this BuilderConfig object
|
||||
# inside the Builder.build_engine
|
||||
}
|
||||
}
|
||||
'''
|
||||
config = {'builder_config': {}}
|
||||
for k in self.__dict__.keys():
|
||||
if k not in [
|
||||
'_trt_builder_config', 'plugin_config',
|
||||
'auto_parallel_config'
|
||||
]:
|
||||
if k not in ['_trt_builder_config', 'plugin_config']:
|
||||
config['builder_config'][k] = self.__getattribute__(k)
|
||||
if hasattr(self, 'plugin_config'):
|
||||
assert isinstance(self.plugin_config, PluginConfig), \
|
||||
@ -270,17 +261,6 @@ class Builder():
|
||||
min_shape = [*shape_profile.min]
|
||||
opt_shape = [*shape_profile.opt]
|
||||
max_shape = [*shape_profile.max]
|
||||
if network._auto_parallel_config is not None:
|
||||
io_shards = network._auto_parallel_config["io_shards"]
|
||||
if input_name in io_shards:
|
||||
shards = io_shards[input_name]
|
||||
for dim, shard_num in shards.items():
|
||||
min_shape[dim] = int(
|
||||
math.floor(min_shape[dim] / shard_num))
|
||||
opt_shape[dim] = int(
|
||||
round(opt_shape[dim] / shard_num))
|
||||
max_shape[dim] = int(
|
||||
math.ceil(max_shape[dim] / shard_num))
|
||||
profile.set_shape(input_name, min_shape, opt_shape, max_shape)
|
||||
logger.debug(
|
||||
f'{input_name}, min: {min_shape}, opt: {opt_shape}, max: {max_shape}, dimension names: {shape_profile.dimension_names}'
|
||||
@ -389,7 +369,6 @@ class Builder():
|
||||
'''
|
||||
assert isinstance(network, Network)
|
||||
builder_config.plugin_config = network.plugin_config
|
||||
builder_config.auto_parallel_config = network.auto_parallel_config
|
||||
if builder_config.trt_builder_config.num_optimization_profiles == 0:
|
||||
self._add_optimization_profile(network, builder_config)
|
||||
logger.info(
|
||||
@ -506,7 +485,6 @@ class BuildConfig:
|
||||
input_timing_cache (str, optional): Path to input timing cache file. If None, no input cache used. Defaults to None.
|
||||
output_timing_cache (str): Path to output timing cache file. Defaults to 'model.cache'.
|
||||
lora_config (LoraConfig): Configuration for LoRA (Low-Rank Adaptation) fine-tuning. Defaults to default LoraConfig.
|
||||
auto_parallel_config (AutoParallelConfig): Configuration for automatic parallelization. Defaults to default AutoParallelConfig.
|
||||
weight_sparsity (bool): Whether to enable weight sparsity optimization. Defaults to False.
|
||||
weight_streaming (bool): Whether to enable weight streaming for large models. Defaults to False.
|
||||
plugin_config (PluginConfig): Configuration for TensorRT LLM plugins. Defaults to default PluginConfig.
|
||||
@ -538,8 +516,6 @@ class BuildConfig:
|
||||
input_timing_cache: str = None
|
||||
output_timing_cache: str = 'model.cache'
|
||||
lora_config: LoraConfig = field(default_factory=LoraConfig)
|
||||
auto_parallel_config: AutoParallelConfig = field(
|
||||
default_factory=AutoParallelConfig)
|
||||
weight_sparsity: bool = False
|
||||
weight_streaming: bool = False
|
||||
plugin_config: PluginConfig = field(default_factory=PluginConfig)
|
||||
@ -659,9 +635,7 @@ class BuildConfig:
|
||||
defaults.get('input_timing_cache'))
|
||||
output_timing_cache = config.pop('output_timing_cache',
|
||||
defaults.get('output_timing_cache'))
|
||||
lora_config = LoraConfig.from_dict(config.get('lora_config', {}))
|
||||
auto_parallel_config = AutoParallelConfig.from_dict(
|
||||
config.get('auto_parallel_config', {}))
|
||||
lora_config = LoraConfig(**config.get('lora_config', {}))
|
||||
max_encoder_input_len = config.pop(
|
||||
'max_encoder_input_len', defaults.get('max_encoder_input_len'))
|
||||
weight_streaming = config.pop('weight_streaming',
|
||||
@ -704,7 +678,6 @@ class BuildConfig:
|
||||
input_timing_cache=input_timing_cache,
|
||||
output_timing_cache=output_timing_cache,
|
||||
lora_config=lora_config,
|
||||
auto_parallel_config=auto_parallel_config,
|
||||
use_strip_plan=use_strip_plan,
|
||||
max_encoder_input_len=max_encoder_input_len,
|
||||
weight_sparsity=weight_sparsity,
|
||||
@ -727,9 +700,7 @@ class BuildConfig:
|
||||
if output.get('kv_cache_type', None) is not None:
|
||||
output['kv_cache_type'] = str(output['kv_cache_type'].name)
|
||||
output['plugin_config'] = output['plugin_config'].model_dump()
|
||||
output['lora_config'] = output['lora_config'].to_dict()
|
||||
output['auto_parallel_config'] = output['auto_parallel_config'].to_dict(
|
||||
)
|
||||
output['lora_config'] = output['lora_config'].model_dump()
|
||||
return output
|
||||
|
||||
def update_from_dict(self, config: dict):
|
||||
@ -892,7 +863,6 @@ def get_engine_version(engine_dir: str) -> Union[None, str]:
|
||||
|
||||
def optimize_model_with_config(model: PretrainedModel,
|
||||
build_config: BuildConfig):
|
||||
use_auto_parallel = build_config.auto_parallel_config.enabled
|
||||
gemm_swiglu_plugin = build_config.plugin_config.gemm_swiglu_plugin
|
||||
low_latency_gemm_swiglu_plugin = build_config.plugin_config.low_latency_gemm_swiglu_plugin
|
||||
if gemm_swiglu_plugin or low_latency_gemm_swiglu_plugin:
|
||||
@ -922,12 +892,11 @@ def optimize_model_with_config(model: PretrainedModel,
|
||||
use_ootb_moe=build_config.plugin_config.moe_plugin is None,
|
||||
use_fused_mlp=(build_config.plugin_config.use_fused_mlp
|
||||
and not is_enc_dec
|
||||
and not (is_recurrent_gemma and is_fp8)
|
||||
and not use_auto_parallel),
|
||||
and not (is_recurrent_gemma and is_fp8)),
|
||||
gemm_swiglu_plugin_dtype=gemm_swiglu_plugin,
|
||||
low_latency_gemm_swiglu_plugin_dtype=low_latency_gemm_swiglu_plugin,
|
||||
use_fused_rg_lru=is_recurrent_gemma,
|
||||
use_unfused_qkv_gemm=use_auto_parallel,
|
||||
use_unfused_qkv_gemm=False,
|
||||
use_prompt_tuning=(build_config.max_prompt_embedding_table_size > 0),
|
||||
use_lora=build_config.plugin_config.lora_plugin is not None,
|
||||
max_lora_rank=build_config.lora_config.max_lora_rank,
|
||||
@ -1281,7 +1250,6 @@ def build(model: PretrainedModel, build_config: BuildConfig) -> Engine:
|
||||
network = builder.create_network()
|
||||
network.plugin_config = build_config.plugin_config
|
||||
|
||||
use_auto_parallel = build_config.auto_parallel_config.enabled
|
||||
use_weight_only = model.config.quant_mode.is_weight_only()
|
||||
per_group = model.config.quant_mode.has_per_group_scaling()
|
||||
use_smooth_quant = model.config.quant_mode.has_act_and_weight_quant()
|
||||
@ -1391,15 +1359,6 @@ def build(model: PretrainedModel, build_config: BuildConfig) -> Engine:
|
||||
if model.config.architecture != "DecoderModel":
|
||||
optimize(network)
|
||||
|
||||
if use_auto_parallel:
|
||||
config = build_config.auto_parallel_config
|
||||
config.builder_flags = builder_config.trt_builder_config.flags
|
||||
sharded_networks = auto_parallel(network, config)
|
||||
network = sharded_networks[model.config.mapping.rank]
|
||||
if not build_config.auto_parallel_config.debug_mode:
|
||||
mapping = network.auto_parallel_config["mapping"]
|
||||
model.config.mapping = mapping
|
||||
|
||||
if build_config.visualize_network is not None:
|
||||
with net_guard(network):
|
||||
network.to_onnx(build_config.visualize_network)
|
||||
|
||||
@ -26,8 +26,6 @@ import torch
|
||||
|
||||
from tensorrt_llm._utils import (local_mpi_rank, local_mpi_size, mpi_barrier,
|
||||
mpi_comm, mpi_rank, mpi_world_size)
|
||||
from tensorrt_llm.auto_parallel import infer_cluster_config
|
||||
from tensorrt_llm.auto_parallel.cluster_info import cluster_infos
|
||||
from tensorrt_llm.bindings import KVCacheType
|
||||
from tensorrt_llm.builder import BuildConfig, Engine, build
|
||||
from tensorrt_llm.logger import logger, severity_map
|
||||
@ -305,29 +303,6 @@ def parse_arguments():
|
||||
"Maximum lengths of draft tokens for speculative decoding target model."
|
||||
)
|
||||
|
||||
autopp_parser = parser.add_argument_group("Auto parallel arguments")
|
||||
autopp_parser.add_argument('--auto_parallel',
|
||||
type=int,
|
||||
default=1,
|
||||
help="MPI world size for auto parallel.")
|
||||
autopp_parser.add_argument(
|
||||
'--gpus_per_node',
|
||||
type=int,
|
||||
default=8,
|
||||
help=
|
||||
"Number of GPUs each node has in a multi-node setup. This is a cluster spec and can be greater/smaller than world size. "
|
||||
"This option is only used for auto parallel specified with ``--auto_parallel``."
|
||||
)
|
||||
autopp_parser.add_argument(
|
||||
'--cluster_key',
|
||||
type=str,
|
||||
default=None,
|
||||
choices=cluster_infos.keys(),
|
||||
help=
|
||||
"Unique name for target GPU type. Inferred from current GPU type if not specified. "
|
||||
"This option is only used for auto parallel specified with ``--auto_parallel``."
|
||||
)
|
||||
|
||||
plugin_config_parser = parser.add_argument_group("Plugin config arguments")
|
||||
add_plugin_argument(plugin_config_parser)
|
||||
return parser
|
||||
@ -355,15 +330,6 @@ def build_model(
|
||||
assert not build_config.plugin_config.pp_reduce_scatter or architecture == "MixtralForCausalLM", \
|
||||
"PP reduce scatter is only supported in the mixtral model."
|
||||
|
||||
model_config.mapping.gpus_per_node = build_config.auto_parallel_config.gpus_per_node
|
||||
if build_config.auto_parallel_config.enabled:
|
||||
assert rank < build_config.auto_parallel_config.world_size
|
||||
assert model_config.mapping.pp_size == 1 and model_config.mapping.tp_size == 1, \
|
||||
"You must convert to full model with TP=1&&PP=1 to use auto parallel planner"
|
||||
model_config.mapping.auto_parallel = True
|
||||
model_config.mapping.world_size = build_config.auto_parallel_config.world_size
|
||||
model_config.mapping.rank = rank
|
||||
else:
|
||||
assert rank < model_config.mapping.world_size
|
||||
|
||||
rank_config = copy.deepcopy(model_config)
|
||||
@ -419,17 +385,7 @@ def parallel_build(model_config: PretrainedConfig,
|
||||
model_cls=None,
|
||||
**kwargs):
|
||||
|
||||
if build_config.auto_parallel_config.enabled:
|
||||
if model_config.mapping.world_size > 1:
|
||||
raise RuntimeError(
|
||||
"manually TP and PP are not supported in auto parallel mode.")
|
||||
if build_config.auto_parallel_config.debug_mode:
|
||||
world_size = 1
|
||||
else:
|
||||
world_size = build_config.auto_parallel_config.world_size
|
||||
else:
|
||||
world_size = model_config.mapping.world_size
|
||||
|
||||
use_mpi = mpi_world_size() > 1
|
||||
|
||||
if not use_mpi and workers == 1:
|
||||
@ -550,10 +506,6 @@ def main():
|
||||
raise RuntimeError(
|
||||
"multiple_profiles is enabled, while opt_num_tokens is set. "
|
||||
"They are not supposed to be working in the same time for now.")
|
||||
if args.cluster_key is not None:
|
||||
cluster_config = dict(cluster_key=args.cluster_key)
|
||||
else:
|
||||
cluster_config = infer_cluster_config()
|
||||
|
||||
# This should only be used for debugging.
|
||||
# The env var BUILDER_FORCE_NUM_PROFILES should override the number of
|
||||
@ -605,20 +557,6 @@ def main():
|
||||
args.input_timing_cache,
|
||||
'output_timing_cache':
|
||||
args.output_timing_cache,
|
||||
'auto_parallel_config': {
|
||||
'world_size':
|
||||
args.auto_parallel,
|
||||
'gpus_per_node':
|
||||
args.gpus_per_node,
|
||||
'sharded_io_allowlist': [
|
||||
'past_key_value_\\d+',
|
||||
'present_key_value_\\d*',
|
||||
],
|
||||
'same_buffer_io': {
|
||||
'past_key_value_(\\d+)': 'present_key_value_\\1',
|
||||
},
|
||||
**cluster_config,
|
||||
},
|
||||
'dry_run':
|
||||
args.dry_run,
|
||||
'visualize_network':
|
||||
|
||||
@ -140,8 +140,7 @@ class BuildCache:
|
||||
|
||||
@staticmethod
|
||||
def prune_build_config_for_cache_key(build_config: dict) -> dict:
|
||||
# The BuildCache will be disabled once auto_pp is enabled, so 'auto_parallel_config' should be removed
|
||||
black_list = ['auto_parallel_config', 'dry_run']
|
||||
black_list = ['dry_run']
|
||||
dic = build_config.copy()
|
||||
for key in black_list:
|
||||
if key in dic:
|
||||
|
||||
@ -6,7 +6,7 @@ import math
import os
import types
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from dataclasses import dataclass
from enum import Enum, EnumMeta
from pathlib import Path
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional,
@ -25,7 +25,6 @@ from tensorrt_llm.lora_helper import (LoraConfig,
                                      get_default_trtllm_modules_to_hf_modules)

from .._utils import mpi_rank
from ..auto_parallel import AutoParallelConfig, infer_cluster_config

if TYPE_CHECKING:
    from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
@ -277,23 +276,21 @@ class AttentionDpConfig(StrictBaseModel):
        return cls(**data)


@dataclass
class _ParallelConfig:
    ''' The model distribution configs for LLM. '''
class _ParallelConfig(StrictBaseModel):
    """The model distribution configs for LLM."""
    tp_size: int = 1
    pp_size: int = 1
    cp_size: int = 1
    gpus_per_node: int = 8
    moe_cluster_size: int = 1
    moe_tp_size: int = 1
    moe_ep_size: int = 1
    cp_config: dict = field(default_factory=dict)
    # Set default for MoE fields to -1 to trigger auto-calculation in Mapping
    moe_cluster_size: int = -1
    moe_tp_size: int = -1
    moe_ep_size: int = -1
    cp_config: dict = Field(default_factory=dict)
    enable_attention_dp: bool = False
    enable_lm_head_tp_in_adp: bool = False
    auto_parallel: bool = False

    _world_size: int = field(default=1, init=False)
    _devices: Optional[List[int]] = field(default=None, init=False)
    _devices: Optional[List[int]] = PrivateAttr(default=None)

    @property
    def devices(self) -> List[int]:
@ -310,18 +307,7 @@ class _ParallelConfig:
        self._devices = devices

    @property
    def world_size(self) -> bool:

        if self.auto_parallel:
            if self.tp_size > 1 or self.pp_size > 1 or self.cp_size > 1:
                raise RuntimeError(
                    "manually TP and PP are not supported in auto parallel mode."
                )
            return self._world_size

        if self._world_size > 1:
            raise RuntimeError(
                "world_size > 1 is only supported in auto parallel mode.")
    def world_size(self) -> int:
        return self.tp_size * self.pp_size * self.cp_size

    @property
@ -332,12 +318,9 @@ class _ParallelConfig:

    @world_size.setter
    def world_size(self, world_size: int):
        if self.auto_parallel:
            self._world_size = world_size
        elif (not self.auto_parallel
              ) and world_size != self.tp_size * self.pp_size * self.cp_size:
        if world_size != self.tp_size * self.pp_size * self.cp_size:
            raise ValueError(
                f"world_size {world_size} should be equal to tp_size * pp_size {self.tp_size * self.pp_size * self.cp_size} "
                f"world_size {world_size} should be equal to tp_size * pp_size * cp_size {self.tp_size * self.pp_size * self.cp_size} "
            )

    @property
@ -356,8 +339,7 @@ class _ParallelConfig:
            enable_lm_head_tp_in_adp=self.enable_lm_head_tp_in_adp,
            moe_cluster_size=self.moe_cluster_size,
            moe_tp_size=self.moe_tp_size,
            moe_ep_size=self.moe_ep_size,
            auto_parallel=self.auto_parallel)
            moe_ep_size=self.moe_ep_size)
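With the auto_parallel flag gone, world_size is purely derived from the parallel sizes and the setter only validates consistency. A short behavior sketch (illustrative values, assuming _ParallelConfig is imported from this module):

# Behavior sketch for the reworked _ParallelConfig above (values are illustrative).
cfg = _ParallelConfig(tp_size=2, pp_size=2, cp_size=1, gpus_per_node=8)
assert cfg.world_size == 4      # always derived as tp_size * pp_size * cp_size

cfg.world_size = 4              # accepted: matches the derived value
try:
    cfg.world_size = 8          # rejected: world_size is no longer free-standing
except ValueError:
    pass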
class CalibConfig(StrictBaseModel):
|
||||
@ -1704,7 +1686,7 @@ class BaseLlmArgs(StrictBaseModel):
|
||||
status="prototype",
|
||||
)
|
||||
|
||||
_parallel_config: Optional[object] = PrivateAttr(default=None)
|
||||
_parallel_config: Optional[_ParallelConfig] = PrivateAttr(default=None)
|
||||
_model_format: Optional[_ModelFormatKind] = PrivateAttr(default=None)
|
||||
_speculative_model: Optional[str] = PrivateAttr(default=None)
|
||||
_speculative_model_format: Optional[_ModelFormatKind] = PrivateAttr(
|
||||
@ -1979,7 +1961,7 @@ class BaseLlmArgs(StrictBaseModel):
|
||||
is QuantAlgo.FP8):
|
||||
self._update_plugin_config("manage_weights", True)
|
||||
|
||||
if self.parallel_config._world_size == 1 and self.build_config:
|
||||
if self.parallel_config.world_size == 1 and self.build_config:
|
||||
self.build_config.plugin_config.nccl_plugin = None
|
||||
|
||||
if self.enable_lora and self.backend != 'pytorch':
|
||||
@ -2182,7 +2164,6 @@ class BaseLlmArgs(StrictBaseModel):
|
||||
moe_cluster_size = pretrained_config.mapping.moe_cluster_size
|
||||
moe_tp_size = pretrained_config.mapping.moe_tp_size
|
||||
moe_ep_size = pretrained_config.mapping.moe_ep_size
|
||||
world_size = pretrained_config.mapping.world_size
|
||||
gpus_per_node = pretrained_config.mapping.gpus_per_node
|
||||
# load parallel_config
|
||||
if self.parallel_config.tp_size != 1 and self.parallel_config.tp_size != tp_size:
|
||||
@ -2197,12 +2178,6 @@ class BaseLlmArgs(StrictBaseModel):
|
||||
raise ValueError(
|
||||
f"cp_size {self.parallel_config.cp_size} is not consistent with the checkpoint's cp_size {cp_size}"
|
||||
)
|
||||
if (self.parallel_config.auto_parallel
|
||||
and self.parallel_config.world_size != 1 and world_size != 1):
|
||||
raise ValueError(
|
||||
f"auto parallel with world_size {self.parallel_config.world_size} does not support checkpoint with "
|
||||
"world_size {world_size} > 1")
|
||||
if not self.parallel_config.auto_parallel:
|
||||
self._parallel_config = _ParallelConfig(
|
||||
tp_size=tp_size,
|
||||
pp_size=pp_size,
|
||||
@ -2222,21 +2197,6 @@ class BaseLlmArgs(StrictBaseModel):
|
||||
|
||||
|
||||
class TrtLlmArgs(BaseLlmArgs):
|
||||
|
||||
auto_parallel: bool = Field(
|
||||
default=False,
|
||||
description="Enable auto parallel mode.",
|
||||
deprecated=
|
||||
"Use tensor_parallel_size/pipeline_parallel_size/xxx_parallel_size instead.",
|
||||
)
|
||||
|
||||
auto_parallel_world_size: Optional[int] = Field(
|
||||
default=None,
|
||||
description="The world size for auto parallel mode.",
|
||||
deprecated=
|
||||
"Use tensor_parallel_size/pipeline_parallel_size/xxx_parallel_size instead.",
|
||||
)
|
||||
|
||||
enable_tqdm: bool = Field(default=False,
|
||||
description="Enable tqdm for progress bar.")
|
||||
|
||||
@ -2288,16 +2248,10 @@ class TrtLlmArgs(BaseLlmArgs):
|
||||
default=False, description="Normalize log probabilities.")
|
||||
|
||||
# Private attributes
|
||||
_auto_parallel_config: Optional[AutoParallelConfig] = PrivateAttr(
|
||||
default=None)
|
||||
# This is used to hold the options for convert_checkpoint
|
||||
_convert_checkpoint_options: Dict[str,
|
||||
Any] = PrivateAttr(default_factory=dict)
|
||||
|
||||
@property
|
||||
def auto_parallel_config(self) -> AutoParallelConfig:
|
||||
return self._auto_parallel_config
|
||||
|
||||
@field_validator('calib_config', mode='before')
|
||||
@classmethod
|
||||
def init_calib_config(cls, v):
|
||||
@ -2325,26 +2279,6 @@ class TrtLlmArgs(BaseLlmArgs):
|
||||
# No else clause needed since validation already happened
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_auto_parallel(self):
|
||||
self._auto_parallel_config = AutoParallelConfig(
|
||||
sharded_io_allowlist=[
|
||||
"past_key_value_\\d+",
|
||||
"present_key_value_\\d*",
|
||||
],
|
||||
same_buffer_io={
|
||||
"past_key_value_(\\d+)": "present_key_value_\\1",
|
||||
},
|
||||
**infer_cluster_config(),
|
||||
)
|
||||
|
||||
self.parallel_config.auto_parallel = self.auto_parallel
|
||||
|
||||
if self.parallel_config.auto_parallel:
|
||||
self.parallel_config.world_size = self.auto_parallel_world_size
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_enable_build_cache(self):
|
||||
if not self.enable_build_cache:
|
||||
|
||||
@ -5,17 +5,17 @@ import shutil
|
||||
import tempfile
|
||||
import time
|
||||
import weakref
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from dataclasses import asdict, dataclass, field, is_dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from pydantic import BaseModel
|
||||
from tqdm import tqdm
|
||||
|
||||
from .._utils import (global_mpi_rank, local_mpi_rank, mpi_barrier,
|
||||
mpi_broadcast, mpi_rank, release_gc)
|
||||
from ..auto_parallel import AutoParallelConfig
|
||||
# yapf: disable
|
||||
from ..bindings.executor import (BatchingType, CapacitySchedulerPolicy,
|
||||
ContextChunkingPolicy, ExecutorConfig,
|
||||
@ -134,18 +134,6 @@ class ModelLoader:
|
||||
assert self.llm_args.build_config
|
||||
self.build_config = self.llm_args.build_config
|
||||
|
||||
self.auto_parallel_config = AutoParallelConfig(
|
||||
world_size=llm_args.parallel_config.world_size if llm_args.
|
||||
parallel_config.auto_parallel else 1)
|
||||
|
||||
default_config = self.llm_args.auto_parallel_config
|
||||
self.auto_parallel_config.set_defaults(
|
||||
cluster_key=default_config.cluster_key,
|
||||
cluster_info=default_config.cluster_info,
|
||||
same_buffer_io=default_config.same_buffer_io,
|
||||
sharded_io_allowlist=default_config.sharded_io_allowlist,
|
||||
)
|
||||
|
||||
self._gather_build_steps()
|
||||
|
||||
def _gather_build_steps(self):
|
||||
@ -545,11 +533,7 @@ class ModelLoader:
|
||||
# avoid the original build_config is modified, avoid the side effect
|
||||
copied_build_config = copy.deepcopy(self.build_config)
|
||||
|
||||
copied_build_config.update(
|
||||
auto_parallel_config=self.auto_parallel_config)
|
||||
copied_build_config.update_kv_cache_type(self._model_info.architecture)
|
||||
if self.auto_parallel_config.enabled:
|
||||
self.model.config.mapping.rank = self.rank
|
||||
assert self.model is not None, "model is loaded yet."
|
||||
|
||||
self._engine = build(self.model, copied_build_config)
|
||||
@ -732,9 +716,9 @@ class CachedModelLoader:
|
||||
def build_cache_enabled(self) -> bool:
|
||||
_enable_build_cache, _ = get_build_cache_config_from_env()
|
||||
|
||||
return (self.llm_args.enable_build_cache or _enable_build_cache) and (
|
||||
self.llm_args.model_format is _ModelFormatKind.HF
|
||||
) and not self.llm_args.parallel_config.auto_parallel
|
||||
return (self.llm_args.enable_build_cache
|
||||
or _enable_build_cache) and (self.llm_args.model_format
|
||||
is _ModelFormatKind.HF)
|
||||
|
||||
def _get_engine_cache_stage(self) -> CachedStage:
|
||||
''' Get the cache stage for engine building. '''
|
||||
@ -743,15 +727,19 @@ class CachedModelLoader:
|
||||
assert self._hf_model_dir is not None, "HF model dir is required for cache key."
|
||||
|
||||
def serialize(d) -> str:
|
||||
dic = asdict(d) if not isinstance(
|
||||
d, PretrainedConfig) else d.to_dict()
|
||||
if hasattr(d, "to_dict"):
|
||||
dic = d.to_dict()
|
||||
elif is_dataclass(d):
|
||||
dic = asdict(d)
|
||||
elif isinstance(d, BaseModel):
|
||||
dic = d.model_dump(mode="json")
|
||||
else:
|
||||
raise ValueError(f"Could not serialize type: {type(d)}")
|
||||
return json.dumps(dic, sort_keys=True)
|
||||
|
||||
parallel_config = self.llm_args.parallel_config
|
||||
|
||||
force_rebuild = False
|
||||
if parallel_config.auto_parallel:
|
||||
force_rebuild = True
|
||||
if self.llm_args.model_format is not _ModelFormatKind.HF:
|
||||
force_rebuild = True
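The new serialize() helper above dispatches on how each config object can be turned into a dict before hashing it into the build-cache key. The same logic as a standalone sketch (the function name here is ours, not part of the codebase):

import json
from dataclasses import asdict, is_dataclass
from pydantic import BaseModel

def serialize_for_cache_key(d) -> str:
    if hasattr(d, "to_dict"):          # e.g. PretrainedConfig-style objects
        dic = d.to_dict()
    elif is_dataclass(d):              # plain dataclasses such as BuildConfig
        dic = asdict(d)
    elif isinstance(d, BaseModel):     # Pydantic models such as the new LoraConfig
        dic = d.model_dump(mode="json")
    else:
        raise ValueError(f"Could not serialize type: {type(d)}")
    # sort_keys keeps the JSON stable so equal configs map to the same cache key
    return json.dumps(dic, sort_keys=True)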
|
||||
|
||||
|
||||
@ -13,10 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Dict, List, Optional
from typing import Dict, List, Literal, Optional

from ._utils import DictConversion
from pydantic import BaseModel, Field


def get_missing_qkv_modules_from_lora_modules(
@ -80,23 +79,16 @@ def use_lora(
            f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}")


@dataclass
class LoraConfig(DictConversion):
    lora_dir: List[str] = field(default_factory=list)
    lora_ckpt_source: str = "hf"
class LoraConfig(BaseModel):
    lora_dir: List[str] = Field(default_factory=list)
    lora_ckpt_source: Literal["hf", "nemo"] = "hf"
    max_lora_rank: int = 64
    lora_target_modules: List[str] = field(default_factory=list)
    trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
    lora_target_modules: List[str] = Field(default_factory=list)
    trtllm_modules_to_hf_modules: Dict[str, str] = Field(default_factory=dict)
    max_loras: Optional[int] = None
    max_cpu_loras: Optional[int] = None
    swap_gate_up_proj_lora_b_weight: bool = True

    def __post_init__(self):
        assert self.lora_ckpt_source in [
            "hf", "nemo"
        ], (f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}"
            )

    @property
    def missing_qkv_modules(self) -> List[str]:
        return get_missing_qkv_modules_from_lora_modules(
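Because LoraConfig is now a Pydantic model, the Literal annotation enforces the checkpoint source at construction time, replacing the removed __post_init__ assert. For example (paths are placeholders):

from pydantic import ValidationError
from tensorrt_llm.lora_helper import LoraConfig

cfg = LoraConfig(lora_dir=["/path/to/lora"], lora_ckpt_source="hf")  # valid
try:
    LoraConfig(lora_ckpt_source="pytorch")   # not in Literal["hf", "nemo"]
except ValidationError as e:
    print(e)   # reports that "pytorch" is not a permitted value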
@ -54,7 +54,6 @@ class MappingBase:
|
||||
moe_ep_size=-1, # -1 means no moe
|
||||
attn_tp_size=-1,
|
||||
attn_cp_size=-1,
|
||||
auto_parallel=False,
|
||||
enable_attention_dp=False,
|
||||
enable_lm_head_tp_in_adp=False):
|
||||
# set default values for non-moe cases
|
||||
@ -97,17 +96,10 @@ class MappingBase:
|
||||
f"attn_cp_size must be 1 for now for ulysses, but got {attn_tp_size}, {attn_cp_size}."
|
||||
)
|
||||
|
||||
if auto_parallel:
|
||||
if tp_size != 1 or pp_size != 1 or cp_size != 1:
|
||||
raise ValueError(
|
||||
"When auto parallel is enabled, tp_size, pp_size, cp_size must be 1, "
|
||||
f"but got {tp_size}, {pp_size}, {cp_size}.")
|
||||
else:
|
||||
if tp_size * pp_size * cp_size != world_size:
|
||||
raise ValueError(
|
||||
"world_size must equal to tp_size * pp_size * cp_size, "
|
||||
f"but got {world_size} != {tp_size} * {pp_size} * {cp_size}."
|
||||
)
|
||||
f"but got {world_size} != {tp_size} * {pp_size} * {cp_size}.")
|
||||
|
||||
moe_tp_ep_size = moe_tp_size * moe_ep_size
|
||||
self.moe_tp_cluster_ep_size = moe_tp_ep_size * moe_cluster_size
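Since the auto_parallel escape hatch is removed, the world-size consistency check above is now unconditional. An illustrative construction (rank and sizes chosen for the example, not taken from the diff):

from tensorrt_llm.mapping import Mapping

mapping = Mapping(world_size=8, rank=0, tp_size=4, pp_size=2, cp_size=1)
assert mapping.world_size == mapping.tp_size * mapping.pp_size * mapping.cp_size
# Mapping(world_size=6, tp_size=4, pp_size=2) would now raise ValueError instead.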
|
||||
@ -139,7 +131,6 @@ class MappingBase:
|
||||
self.moe_cluster_size = moe_cluster_size
|
||||
self.attn_tp_size = attn_tp_size
|
||||
self.attn_cp_size = attn_cp_size
|
||||
self.auto_parallel = auto_parallel
|
||||
self.world_size = world_size
|
||||
self.enable_attention_dp = enable_attention_dp
|
||||
if enable_lm_head_tp_in_adp:
|
||||
@ -169,8 +160,7 @@ class MappingBase:
|
||||
and self.moe_ep_size == other.moe_ep_size
|
||||
and self.attn_tp_size == other.attn_tp_size
|
||||
and self.attn_cp_size == other.attn_cp_size
|
||||
and self.cp_config == other.cp_config
|
||||
and self.auto_parallel == other.auto_parallel)
|
||||
and self.cp_config == other.cp_config)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((
|
||||
@ -187,7 +177,6 @@ class MappingBase:
|
||||
self.attn_cp_size,
|
||||
# note: we do not allow updating cp_config after initialization
|
||||
tuple(sorted(self.cp_config.items())),
|
||||
self.auto_parallel,
|
||||
))
|
||||
|
||||
@property
|
||||
@ -339,7 +328,6 @@ class MappingBase:
|
||||
'attn_tp_size': self.attn_tp_size,
|
||||
'attn_cp_size': self.attn_cp_size,
|
||||
'cp_config': self.cp_config,
|
||||
'auto_parallel': self.auto_parallel,
|
||||
'enable_attention_dp': self.enable_attention_dp,
|
||||
'enable_lm_head_tp_in_adp': self.enable_lm_head_tp_in_adp,
|
||||
}
|
||||
@ -463,7 +451,6 @@ class Mapping(MappingBase):
|
||||
moe_ep_size=-1, # -1 means no moe
|
||||
attn_tp_size=-1,
|
||||
attn_cp_size=-1,
|
||||
auto_parallel=False,
|
||||
enable_attention_dp=False,
|
||||
enable_lm_head_tp_in_adp=False):
|
||||
super().__init__(world_size=world_size,
|
||||
@ -478,7 +465,6 @@ class Mapping(MappingBase):
|
||||
moe_ep_size=moe_ep_size,
|
||||
attn_tp_size=attn_tp_size,
|
||||
attn_cp_size=attn_cp_size,
|
||||
auto_parallel=auto_parallel,
|
||||
enable_attention_dp=enable_attention_dp,
|
||||
enable_lm_head_tp_in_adp=enable_lm_head_tp_in_adp)
|
||||
|
||||
@ -516,17 +502,15 @@ class MpiTopology(Mapping):
|
||||
|
||||
@property
|
||||
def tp_rank(self) -> int:
|
||||
return 0 if self.auto_parallel else self.rank % self.tp_size
|
||||
return self.rank % self.tp_size
|
||||
|
||||
@property
|
||||
def pp_rank(self) -> int:
|
||||
return 0 if self.auto_parallel else self.rank // (self.tp_size *
|
||||
self.cp_size)
|
||||
return self.rank // (self.tp_size * self.cp_size)
|
||||
|
||||
@property
|
||||
def cp_rank(self) -> int:
|
||||
return 0 if self.auto_parallel else self.rank % (
|
||||
self.tp_size * self.cp_size) // self.tp_size
|
||||
return self.rank % (self.tp_size * self.cp_size) // self.tp_size
|
||||
|
||||
@property
|
||||
def tp_group(self) -> List[int]:
|
||||
|
||||
@ -739,9 +739,7 @@ class PretrainedModel(Module,
|
||||
config.set_rank(rank)
|
||||
|
||||
rank = config.mapping.rank
|
||||
if config.mapping.auto_parallel:
|
||||
rank = 0
|
||||
elif config.mapping.cp_size > 1:
|
||||
if config.mapping.cp_size > 1:
|
||||
# tp_cp_pp rank -> tp_pp rank: because different cp ranks share the same ckpt
|
||||
tp_size = config.mapping.tp_size
|
||||
cp_size = config.mapping.cp_size
|
||||
@ -1307,7 +1305,7 @@ def unfuse_qkv_gemm(model: PretrainedModel) -> PretrainedModel:
|
||||
|
||||
for name, layer in model.named_modules():
|
||||
if isinstance(layer, Attention) and not layer.cross_attention:
|
||||
assert layer.tp_size == 1, "please disable manual tp when enable auto parallel"
|
||||
assert layer.tp_size == 1, "unfuse_qkv_gemm requires tp_size == 1"
|
||||
if layer.qkv is None:
|
||||
continue
|
||||
qkv_params = get_init_params(layer.qkv, ColumnLinear)
|
||||
|
||||
@ -151,7 +151,6 @@ class Network(object):
|
||||
self._strongly_typed = trt.INetworkDefinition.get_flag(
|
||||
self._trt_network, trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
|
||||
self._unfilled_weights: Dict[str, Tuple[np.array, np.array]] = {}
|
||||
self._auto_parallel_config: Dict[str, Any] = None
|
||||
|
||||
return self
|
||||
|
||||
@ -209,10 +208,6 @@ class Network(object):
|
||||
def strongly_typed(self) -> bool:
|
||||
return self._strongly_typed
|
||||
|
||||
@property
|
||||
def auto_parallel_config(self) -> Dict[str, Any]:
|
||||
return self._auto_parallel_config
|
||||
|
||||
def _add_input(self,
|
||||
tensor,
|
||||
name,
|
||||
|
||||
@ -27,13 +27,12 @@ def read_config(config_path: Path):
|
||||
plugin_config = builder_config['plugin_config']
|
||||
pretrained_config = config['pretrained_config']
|
||||
lora_config = builder_config['lora_config']
|
||||
auto_parallel_config = builder_config['auto_parallel_config']
|
||||
use_gpt_attention_plugin = plugin_config["gpt_attention_plugin"]
|
||||
remove_input_padding = plugin_config["remove_input_padding"]
|
||||
use_lora_plugin = plugin_config["lora_plugin"]
|
||||
tp_size = pretrained_config['mapping']['tp_size']
|
||||
pp_size = pretrained_config['mapping']['pp_size']
|
||||
gpus_per_node = auto_parallel_config['gpus_per_node']
|
||||
gpus_per_node = pretrained_config['mapping']['gpus_per_node']
|
||||
world_size = tp_size * pp_size
|
||||
assert world_size == mpi_world_size(), \
|
||||
f'Engine world size ({world_size}) != Runtime world size ({mpi_world_size()})'
|
||||
|
||||
@ -50,17 +50,12 @@ BASE_EXAMPLE_CLASSES = {
|
||||
"tensorrt_llm._torch.models.modeling_siglip": ["SiglipVisionModel"],
|
||||
"tensorrt_llm._torch.models.modeling_vila": ["VilaModel"],
|
||||
"tensorrt_llm._torch.models.modeling_gpt_oss": ["GptOssForCausalLM"],
|
||||
### ending import of torch models classes
|
||||
"tensorrt_llm._torch.pyexecutor.config": ["PyTorchConfig", "LoadFormat"],
|
||||
"tensorrt_llm._torch.pyexecutor.llm_request":
|
||||
["LogitsStorage", "PyResult", "LlmResult", "LlmResponse", "LogProbStorage"],
|
||||
"tensorrt_llm._torch.speculative.mtp": ["MTPConfig"],
|
||||
"tensorrt_llm._torch.speculative.interface": ["SpeculativeDecodingMode"],
|
||||
"tensorrt_llm._torch.pyexecutor.config": ["PyTorchConfig", "LoadFormat"],
|
||||
"tensorrt_llm.auto_parallel.config": ["AutoParallelConfig", "CostModel"],
|
||||
"tensorrt_llm.auto_parallel.cluster_info":
|
||||
["ClusterInfo", "MathThroughput"],
|
||||
"tensorrt_llm._torch.pyexecutor.config": ["PyTorchConfig", "LoadFormat"],
|
||||
### ending import of torch models classes
|
||||
"tensorrt_llm.bindings.executor": [
|
||||
"BatchingType", "CacheTransceiverConfig", "CapacitySchedulerPolicy",
|
||||
"ContextPhaseParams", "ContextChunkingPolicy", "DynamicBatchConfig",
|
||||
@ -70,8 +65,6 @@ BASE_EXAMPLE_CLASSES = {
|
||||
"KvCacheRetentionConfig.TokenRangeRetentionConfig", "PeftCacheConfig",
|
||||
"SchedulerConfig"
|
||||
],
|
||||
"tensorrt_llm._torch.pyexecutor.config": ["PyTorchConfig"],
|
||||
"tensorrt_llm._torch.model_config": ["MoeLoadBalancerConfig"],
|
||||
"tensorrt_llm.builder": ["BuildConfig"],
|
||||
"tensorrt_llm.disaggregated_params": ["DisaggregatedParams"],
|
||||
"tensorrt_llm.inputs.multimodal": ["MultimodalInput"],
|
||||
|
||||
@ -633,11 +633,10 @@
|
||||
"examples/test_llama.py::test_llm_llama_lookahead_xqa_fp8_1gpu[llama-3.1-8b]": 121.70169674104545,
|
||||
"examples/test_llama.py::test_llm_llama_lookahead_xqa_fp8_1gpu[llama-3.2-1b]": 79.68221819098108,
|
||||
"examples/test_llama.py::test_llm_llama_v1_1gpu_kv_cache_reuse_with_prompt_table[llama-7b]": 167.92376559507102,
|
||||
"examples/test_llama.py::test_llm_llama_v1_2gpu_summary[llama-7b-nb:4-enable_auto_parallel]": 317.7816583644599,
|
||||
"examples/test_llama.py::test_llm_llama_v1_2gpu_summary[llama-7b-nb:4]": 317.7816583644599,
|
||||
"examples/test_llama.py::test_llm_llama_v1_4gpu_paged_kv_cache[llama-3.1-8b]": 122.89023815206019,
|
||||
"examples/test_llama.py::test_llm_llama_v1_multiple_lora_1gpu[luotuo_japan-llama-7b-lora_fp16-base_fp16]": 119.51703953905962,
|
||||
"examples/test_llama.py::test_llm_llama_v1_multiple_lora_1gpu[luotuo_japan-llama-7b-lora_fp16-base_fp8]": 176.65850483701797,
|
||||
"examples/test_llama.py::test_llm_llama_v2_1gpu_auto_parallel[llama-v2-7b-hf]": 535.973838724196,
|
||||
"examples/test_llama.py::test_llm_llama_v2_lora_1gpu[chinese-llama-2-lora-13b-llama-v2-13b-hf-lora_fp16-base_awq]": 420.1779588930076,
|
||||
"examples/test_llama.py::test_llm_llama_v2_lora_1gpu[chinese-llama-2-lora-13b-llama-v2-13b-hf-lora_fp16-base_fp16]": 895.7611340929288,
|
||||
"examples/test_llama.py::test_llm_llama_v2_lora_1gpu[chinese-llama-2-lora-13b-llama-v2-13b-hf-lora_fp16-base_fp8]": 314.3205590210273,
|
||||
|
||||
@ -894,32 +894,6 @@ def test_llm_llama_v2_gather_logits_2gpu_pp2(llama_example_root,
|
||||
summary_cmd)
|
||||
|
||||
|
||||
@skip_post_blackwell
|
||||
@pytest.mark.parametrize("llama_model_root", ['llama-v2-7b-hf'], indirect=True)
|
||||
def test_llm_llama_v2_1gpu_auto_parallel(llama_example_root, llama_model_root,
|
||||
llm_venv, cmodel_dir, engine_dir):
|
||||
model_name = 'llama_v2'
|
||||
data_type = 'float16'
|
||||
model_dir = convert_weights(llm_venv=llm_venv,
|
||||
example_root=llama_example_root,
|
||||
cmodel_dir=cmodel_dir,
|
||||
model=model_name,
|
||||
model_path=llama_model_root,
|
||||
data_type=data_type)
|
||||
|
||||
print("Build engines...")
|
||||
build_cmd = [
|
||||
"trtllm-build",
|
||||
f"--checkpoint_dir={model_dir}",
|
||||
f"--max_batch_size={1}",
|
||||
f"--max_input_len={1024}",
|
||||
f"--output_dir={engine_dir}",
|
||||
"--auto_parallel=8",
|
||||
]
|
||||
|
||||
check_call(" ".join(build_cmd), shell=True, env=llm_venv._new_env)
|
||||
|
||||
|
||||
@skip_post_blackwell
|
||||
@pytest.mark.skip_less_device(2)
|
||||
@pytest.mark.parametrize("num_beams", [1, 4],
|
||||
@ -1078,24 +1052,21 @@ def test_llm_llama_v3_1_autoq_2gpu_mmlu(llama_example_root, llama_model_root,
|
||||
@skip_post_blackwell
|
||||
@pytest.mark.skip_less_device(2)
|
||||
@pytest.mark.skip_less_device_memory(80000)
|
||||
@pytest.mark.parametrize("use_auto_parallel", [True, False],
|
||||
ids=["enable_auto_parallel", "disable_auto_parallel"])
|
||||
@pytest.mark.parametrize("num_beams", [4],
|
||||
ids=lambda num_beams: f'nb:{num_beams}')
|
||||
@pytest.mark.parametrize("llama_model_root", ['llama-7b', 'llama-30b'],
|
||||
indirect=True)
|
||||
def test_llm_llama_v1_2gpu_summary(llama_example_root, llama_model_root,
|
||||
llm_datasets_root, llm_rouge_root, llm_venv,
|
||||
cmodel_dir, engine_dir, num_beams,
|
||||
use_auto_parallel):
|
||||
cmodel_dir, engine_dir, num_beams):
|
||||
model_name = 'llama_v1_2gpu'
|
||||
model_dir = convert_weights(llm_venv=llm_venv,
|
||||
example_root=llama_example_root,
|
||||
cmodel_dir=cmodel_dir,
|
||||
model=model_name,
|
||||
model_path=llama_model_root,
|
||||
gpus=1 if use_auto_parallel else 2,
|
||||
tp_size=1 if use_auto_parallel else 2,
|
||||
gpus=2,
|
||||
tp_size=2,
|
||||
pp_size=1)
|
||||
|
||||
print("Build engines...")
|
||||
@ -1108,8 +1079,6 @@ def test_llm_llama_v1_2gpu_summary(llama_example_root, llama_model_root,
|
||||
"--remove_input_padding=enable",
|
||||
f"--max_beam_width={num_beams}",
|
||||
]
|
||||
if use_auto_parallel:
|
||||
build_cmd += ["--auto_parallel=2"]
|
||||
|
||||
check_call(" ".join(build_cmd), shell=True, env=llm_venv._new_env)
|
||||
|
||||
|
||||
@ -102,11 +102,10 @@ examples/test_llama.py::test_llm_llama_lookahead_single_gpu_summary[llama-3.1-8b
|
||||
examples/test_llama.py::test_llm_llama_lookahead_xqa_fp8_1gpu[llama-3.1-8b]
|
||||
examples/test_llama.py::test_llm_llama_lookahead_xqa_fp8_1gpu[llama-3.2-1b]
|
||||
examples/test_llama.py::test_llm_api_lookahead_decoding_1gpu[Llama-3.1-8B-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct]
|
||||
examples/test_llama.py::test_llm_llama_v1_2gpu_summary[llama-7b-nb:4-enable_auto_parallel]
|
||||
examples/test_llama.py::test_llm_llama_v1_2gpu_summary[llama-7b-nb:4]
|
||||
examples/test_llama.py::test_llm_llama_v1_4gpu_paged_kv_cache[llama-3.1-8b]
|
||||
examples/test_llama.py::test_llm_llama_v1_multiple_lora_1gpu[luotuo_japan-llama-7b-lora_fp16-base_fp16]
|
||||
examples/test_llama.py::test_llm_llama_v1_multiple_lora_1gpu[luotuo_japan-llama-7b-lora_fp16-base_fp8]
|
||||
examples/test_llama.py::test_llm_llama_v2_1gpu_auto_parallel[llama-v2-7b-hf]
|
||||
examples/test_llama.py::test_llm_llama_v2_lora_1gpu[chinese-llama-2-lora-13b-llama-v2-13b-hf-lora_fp16-base_awq]
|
||||
examples/test_llama.py::test_llm_llama_v2_lora_1gpu[chinese-llama-2-lora-13b-llama-v2-13b-hf-lora_fp16-base_fp16]
|
||||
examples/test_llama.py::test_llm_llama_v2_lora_1gpu[chinese-llama-2-lora-13b-llama-v2-13b-hf-lora_fp16-base_fp8]
|
||||
|
||||
@ -157,7 +157,6 @@ l0_a30:
|
||||
- examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-bart-large-cnn-bfloat16-enable_gemm_plugin-enable_attention_plugin-disable_paged_kv_cache-tp:1-pp:1-nb:1-disable_fp8] # 1 mins
|
||||
- accuracy/test_cli_flow.py::TestLlama7B::test_manage_weights # 2 mins
|
||||
- accuracy/test_llm_api.py::TestQwen2_7BInstruct::test_weight_only
|
||||
- examples/test_llama.py::test_llm_llama_v2_1gpu_auto_parallel[llama-v2-7b-hf] # 15 mins
|
||||
- accuracy/test_cli_flow.py::TestMistral7B::test_beam_search # 5 mins
|
||||
- examples/test_mistral.py::test_llm_mistral_v1_1gpu[mistral-7b-v0.1-float16-max_attention_window_size_4096-summarization_long]
|
||||
- examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-t5-small-float32-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:2-disable_fp8]
|
||||
|
||||
@ -253,7 +253,6 @@ examples/test_qwen2audio.py::test_llm_qwen2audio_single_gpu[qwen2_audio_7b_instr
|
||||
examples/test_nemotron_nas.py::test_nemotron_nas_summary_2gpu[DeciLM-7B] SKIP (https://nvbugs/5444636)
|
||||
examples/test_multimodal.py::test_llm_fp8_multimodal_general[fp8-fp8-cnn_dailymail-Qwen2-VL-7B-Instruct-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False] SKIP (https://nvbugs/5453709)
|
||||
examples/test_multimodal.py::test_llm_multimodal_general[VILA1.5-3b-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5453709)
|
||||
examples/test_llama.py::test_llm_llama_v2_1gpu_auto_parallel[llama-v2-7b-hf] SKIP (https://nvbugs/5453742)
|
||||
triton_server/test_triton.py::test_gpt_ib[gpt-ib] SKIP (https://nvbugs/5431116)
|
||||
accuracy/test_llm_api.py::TestMistralNemo12B::test_fp8 SKIP (https://nvbugs/5413197)
|
||||
triton_server/test_triton.py::test_gpt_ib_streaming[gpt-ib-streaming] SKIP (https://nvbugs/5371349)
|
||||
@ -267,7 +266,6 @@ examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-3-mini-128k-instruct] SKIP
|
||||
examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-3-small-128k-instruct] SKIP (https://nvbugs/5465143)
|
||||
examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-3.5-mini-instruct] SKIP (https://nvbugs/5465143)
|
||||
examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-4-mini-instruct] SKIP (https://nvbugs/5465143)
|
||||
examples/test_llama.py::test_llm_llama_v1_2gpu_summary[llama-7b-nb:4-enable_auto_parallel] SKIP (https://nvbugs/5453742)
|
||||
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen1.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
|
||||
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
|
||||
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_vl_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5359696)
|
||||
|
||||
@ -26,9 +26,8 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
|
||||
# check expected parallel config
|
||||
world_size = expected_ad_config.world_size
|
||||
expected_parallel_config = _ParallelConfig(
|
||||
auto_parallel=True, gpus_per_node=expected_llm_args.gpus_per_node
|
||||
tp_size=world_size, gpus_per_node=expected_llm_args.gpus_per_node
|
||||
)
|
||||
expected_parallel_config.world_size = world_size
|
||||
assert llm_args._parallel_config == expected_parallel_config, (
|
||||
f"Expected parallel_config {expected_parallel_config}, got {llm_args._parallel_config}"
|
||||
)
|
||||
|
||||
@ -53,7 +53,6 @@ def build_engine(model_name, request):
|
||||
|
||||
llm = LLM(model_path,
|
||||
tensor_parallel_size=tp_size,
|
||||
auto_parallel_world_size=tp_size,
|
||||
build_config=build_config)
|
||||
|
||||
engine_dir = TemporaryDirectory(suffix="-engine_dir")
|
||||
|
||||
@ -48,7 +48,6 @@ def engine_from_fp8_quantization(model_name):
|
||||
|
||||
llm = LLM(model_path,
|
||||
tensor_parallel_size=tp_size,
|
||||
auto_parallel_world_size=tp_size,
|
||||
quant_config=quant_config,
|
||||
calib_config=calib_config,
|
||||
build_config=build_config)
|
||||
|
||||
@ -459,21 +459,12 @@ def test_llm_generate_async():
|
||||
|
||||
def _test_llm_generate_async(model_name=default_model_name,
|
||||
tp_size: int = 1,
|
||||
use_auto_parallel: bool = False,
|
||||
tokenizer=None):
|
||||
if "Mixtral" in model_name and use_auto_parallel:
|
||||
pytest.skip("Auto parallel is not supported for Mixtral models")
|
||||
|
||||
tp_size = tp_size if not use_auto_parallel else 1
|
||||
world_size = tp_size if use_auto_parallel else None
|
||||
|
||||
llm = LLM(
|
||||
model=get_model_path(model_name),
|
||||
tokenizer=tokenizer,
|
||||
kv_cache_config=global_kvcache_config,
|
||||
tensor_parallel_size=tp_size,
|
||||
auto_parallel=use_auto_parallel,
|
||||
auto_parallel_world_size=world_size,
|
||||
fast_build=True,
|
||||
)
|
||||
|
||||
|
||||
@ -6,7 +6,6 @@ import yaml
|
||||
|
||||
import tensorrt_llm.bindings.executor as tle
|
||||
from tensorrt_llm import LLM as TorchLLM
|
||||
from tensorrt_llm import AutoParallelConfig
|
||||
from tensorrt_llm._tensorrt_engine import LLM
|
||||
from tensorrt_llm.builder import LoraConfig
|
||||
from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
|
||||
@ -334,10 +333,6 @@ def test_update_llm_args_with_extra_dict_with_nested_dict():
|
||||
plugin_config = PluginConfig(dtype='float16', nccl_plugin=None)
|
||||
build_config = BuildConfig(max_input_len=1024,
|
||||
lora_config=LoraConfig(lora_ckpt_source='hf'),
|
||||
auto_parallel_config=AutoParallelConfig(
|
||||
world_size=1,
|
||||
same_buffer_io={},
|
||||
debug_outputs=[]),
|
||||
plugin_config=plugin_config)
|
||||
extra_llm_args_dict = {
|
||||
"build_config": build_config.to_dict(),
|
||||
|
||||
@ -135,25 +135,17 @@ def test_llm_return_logprobs_tp2(prompt_logprobs: Optional[int],
|
||||
tp_size=2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_auto_parallel", [True, False],
|
||||
ids=["enable_auto_parallel", "disable_auto_parallel"])
|
||||
@pytest.mark.parametrize("from_ckpt", [True, False],
|
||||
ids=["from_ckpt", "from_hf"])
|
||||
@pytest.mark.gpu2
|
||||
@pytest.mark.part0
|
||||
def test_llm_generate_async_tp2(
|
||||
engine_from_checkpoint: tempfile.TemporaryDirectory, from_ckpt: bool,
|
||||
use_auto_parallel: bool):
|
||||
if use_auto_parallel and from_ckpt:
|
||||
pytest.skip("Skip auto parallel for TP2 checkpoint")
|
||||
engine_from_checkpoint: tempfile.TemporaryDirectory, from_ckpt: bool):
|
||||
model_dir = engine_from_checkpoint.name if from_ckpt else get_model_path(
|
||||
llama_model_path)
|
||||
tokenizer_dir = get_model_path(llama_model_path)
|
||||
tokenizer = TransformersTokenizer.from_pretrained(tokenizer_dir)
|
||||
_test_llm_generate_async(model_dir,
|
||||
tp_size=2,
|
||||
use_auto_parallel=use_auto_parallel,
|
||||
tokenizer=tokenizer)
|
||||
_test_llm_generate_async(model_dir, tp_size=2, tokenizer=tokenizer)
|
||||
|
||||
|
||||
@skip_gpu_memory_less_than(70 * 1024**3)
|
||||
@ -201,7 +193,6 @@ def test_llm_pp2():
|
||||
prompts, ["D E F G H I J K"],
|
||||
sampling_params=SamplingParams(max_tokens=8),
|
||||
pipeline_parallel_size=2,
|
||||
auto_parallel=False,
|
||||
kv_cache_config=global_kv_cache_config)
|
||||
|
||||
|
||||
|
||||
@ -1,12 +1,6 @@
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
|
||||
from tensorrt_llm import serialization
|
||||
from tensorrt_llm.auto_parallel.config import AutoParallelConfig
|
||||
from tensorrt_llm.auto_parallel.parallelization import ParallelConfig
|
||||
from tensorrt_llm.auto_parallel.simplifier import GraphConfig, StageType
|
||||
|
||||
|
||||
class TestClass:
|
||||
@ -83,38 +77,5 @@ def test_serialization_complex_object_disallowed_class():
|
||||
excep) == "Import torch._utils | _rebuild_tensor_v2 is not allowed"
|
||||
|
||||
|
||||
def test_parallel_config_serialization():
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Create a ParallelConfig instance with some test data
|
||||
config = ParallelConfig()
|
||||
config.version = "test_version"
|
||||
config.network_hash = "test_hash"
|
||||
config.auto_parallel_config = AutoParallelConfig(
|
||||
world_size=2, gpus_per_node=2, cluster_key="test_cluster")
|
||||
config.graph_config = GraphConfig(num_micro_batches=2,
|
||||
num_blocks=3,
|
||||
num_stages=2)
|
||||
config.cost = 1.5
|
||||
config.stage_type = StageType.START
|
||||
|
||||
config_path = os.path.join(tmpdir, "parallel_config.pkl")
|
||||
config.save(config_path)
|
||||
|
||||
loaded_config = ParallelConfig.from_file(config_path)
|
||||
|
||||
# Verify the loaded config matches the original
|
||||
assert loaded_config.version == config.version
|
||||
assert loaded_config.network_hash == config.network_hash
|
||||
assert loaded_config.auto_parallel_config.world_size == config.auto_parallel_config.world_size
|
||||
assert loaded_config.auto_parallel_config.gpus_per_node == config.auto_parallel_config.gpus_per_node
|
||||
assert loaded_config.auto_parallel_config.cluster_key == config.auto_parallel_config.cluster_key
|
||||
assert loaded_config.graph_config.num_micro_batches == config.graph_config.num_micro_batches
|
||||
assert loaded_config.graph_config.num_blocks == config.graph_config.num_blocks
|
||||
assert loaded_config.graph_config.num_stages == config.graph_config.num_stages
|
||||
assert loaded_config.cost == config.cost
|
||||
assert loaded_config.stage_type == config.stage_type
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_serialization_allowed_class()
|
||||
test_parallel_config_serialization()
|
||||
|
||||
@ -15,7 +15,6 @@ from utils.util import force_ampere
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm import BuildConfig, Mapping, SamplingParams
|
||||
from tensorrt_llm._utils import mpi_barrier
|
||||
from tensorrt_llm.auto_parallel import AutoParallelConfig, infer_cluster_config
|
||||
from tensorrt_llm.executor import GenerationExecutor
|
||||
from tensorrt_llm.models import LLaMAForCausalLM
|
||||
|
||||
@ -56,7 +55,7 @@ def get_batch_output_text_expected(model_name):
|
||||
|
||||
|
||||
# 76s on ipp1-1197, loading weights 18s (varies based on network speed), network/engine creation 27s
|
||||
def build_and_run_tp2(rank, model_name, engine_dir, use_auto_parallel):
|
||||
def build_and_run_tp2(rank, model_name, engine_dir):
|
||||
'''Do not save the engine, all in one LLaMAForCausalLM object
|
||||
'''
|
||||
batch_output_text_expected = get_batch_output_text_expected(model_name)
|
||||
@ -67,26 +66,10 @@ def build_and_run_tp2(rank, model_name, engine_dir, use_auto_parallel):
|
||||
max_batch_size, max_isl, max_osl = 8, 256, 256
|
||||
hf_model_dir = str(llm_models_root() / model_name)
|
||||
mapping = Mapping(world_size=TP_SIZE, rank=rank, tp_size=TP_SIZE)
|
||||
auto_parallel_config = AutoParallelConfig()
|
||||
if use_auto_parallel:
|
||||
mapping = Mapping()
|
||||
mapping.rank = rank
|
||||
auto_parallel_config = AutoParallelConfig(
|
||||
world_size=TP_SIZE,
|
||||
sharded_io_allowlist=[
|
||||
"past_key_value_\\d+",
|
||||
"present_key_value_\\d*",
|
||||
],
|
||||
same_buffer_io={
|
||||
"past_key_value_(\\d+)": "present_key_value_\\1",
|
||||
},
|
||||
**infer_cluster_config(),
|
||||
)
|
||||
build_config = BuildConfig(max_batch_size=max_batch_size,
|
||||
max_input_len=max_isl,
|
||||
max_seq_len=max_osl + max_isl,
|
||||
strongly_typed=True,
|
||||
auto_parallel_config=auto_parallel_config)
|
||||
strongly_typed=True)
|
||||
# build and run by one llama object
|
||||
llama = LLaMAForCausalLM.from_hugging_face(hf_model_dir, mapping=mapping)
|
||||
engine = tensorrt_llm.build(llama, build_config)
|
||||
@ -117,21 +100,17 @@ def build_and_run_tp2(rank, model_name, engine_dir, use_auto_parallel):
|
||||
|
||||
|
||||
@force_ampere
|
||||
@pytest.mark.parametrize("use_auto_parallel", [True, False],
|
||||
ids=["enable_auto_parallel", "disable_auto_parallel"])
|
||||
@pytest.mark.parametrize("model_name",
|
||||
["llama-models/llama-7b-hf", "Mixtral-8x7B-v0.1"])
|
||||
def test_multi_gpu(model_name, use_auto_parallel):
|
||||
def test_multi_gpu(model_name):
|
||||
if torch.cuda.device_count() < TP_SIZE:
|
||||
print(f"The test needs at least ${TP_SIZE} GPUs, skipping")
|
||||
return
|
||||
if "Mixtral" in model_name and use_auto_parallel:
|
||||
pytest.skip("Auto parallel is not supported for Mixtral models")
|
||||
engine_dir = tempfile.TemporaryDirectory().name
|
||||
|
||||
with MPIPoolExecutor(max_workers=TP_SIZE) as executor:
|
||||
results = executor.map(build_and_run_tp2, (0, 1), [model_name] * 2,
|
||||
[engine_dir] * 2, [use_auto_parallel] * 2)
|
||||
[engine_dir] * 2)
|
||||
for r in results:
|
||||
assert r is True
|
||||
|
||||
|
||||