[TRTLLM-8682][chore] Remove auto_parallel module (#8329)

Signed-off-by: Anish Shanbhag <ashanbhag@nvidia.com>
Anish Shanbhag 2025-10-22 17:53:08 -07:00 committed by GitHub
parent e5865de518
commit 15de45d782
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
80 changed files with 100 additions and 12633 deletions

View File

@@ -93,36 +93,6 @@
],
"trtllm_modules_to_hf_modules": {}
},
"auto_parallel_config": {
"world_size": 1,
"gpus_per_node": 8,
"cluster_key": "A100-PCIe-80GB",
"cluster_info": null,
"sharding_cost_model": "alpha_beta",
"comm_cost_model": "alpha_beta",
"enable_pipeline_parallelism": false,
"enable_shard_unbalanced_shape": false,
"enable_shard_dynamic_shape": false,
"enable_reduce_scatter": true,
"builder_flags": null,
"debug_mode": false,
"infer_shape": true,
"validation_mode": false,
"same_buffer_io": {
"past_key_value_(\\d+)": "present_key_value_\\1"
},
"same_spec_io": {},
"sharded_io_allowlist": [
"past_key_value_\\d+",
"present_key_value_\\d*"
],
"fast_reduce": true,
"fill_weights": false,
"parallel_config_cache": null,
"profile_cache": null,
"dump_path": null,
"debug_outputs": []
},
"weight_sparsity": false,
"weight_streaming": false,
"plugin_config": {

View File

@@ -132,16 +132,6 @@ trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_fp16_wq \
--output_dir ./tmp/llama/7B/trt_engines/weight_only/1-gpu/ \
--gemm_plugin auto
# Build LLaMA 7B using 2-way auto parallelism (deprecated).
python convert_checkpoint.py --model_dir ./tmp/llama/7B/ \
--output_dir ./tllm_checkpoint_1gpu_fp16 \
--dtype float16
trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_fp16 \
--output_dir ./tmp/llama/7B/trt_engines/fp16/2-gpu/ \
--gemm_plugin auto \
--auto_parallel 2
# Build LLaMA 7B using 2-way tensor parallelism.
python convert_checkpoint.py --model_dir ./tmp/llama/7B/ \
--output_dir ./tllm_checkpoint_2gpu_tp2 \

View File

@@ -77,7 +77,6 @@ from ._utils import (default_gpus_per_node, local_mpi_rank, local_mpi_size,
mpi_barrier, mpi_comm, mpi_rank, mpi_world_size,
set_mpi_comm, str_dtype_to_torch, str_dtype_to_trt,
torch_dtype_to_trt)
from .auto_parallel import AutoParallelConfig, auto_parallel
from .builder import BuildConfig, Builder, BuilderConfig, build
from .disaggregated_params import DisaggregatedParams
from .functional import Tensor, constant
@@ -130,8 +129,6 @@ __all__ = [
'Module',
'functional',
'models',
'auto_parallel',
'AutoParallelConfig',
'quantization',
'tools',
'LLM',

View File

@@ -391,10 +391,13 @@ class LlmArgs(AutoDeployConfig, BaseLlmArgs, BaseSettings):
rank to automatically shard the model. This is just to ensure that other objects in the
runtime that may read parallel_config can do so.
"""
# Set tp_size = self.world_size so that _ParallelConfig.world_size will return the
# correct value (computed as tp_size * pp_size * cp_size). This does not necessarily
# mean that TP will actually be used.
self._parallel_config = _ParallelConfig(
auto_parallel=True, gpus_per_node=self.gpus_per_node
tp_size=self.world_size, gpus_per_node=self.gpus_per_node
)
self._parallel_config.world_size = self.world_size
return self
@model_validator(mode="after")
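The comment above relies on _ParallelConfig deriving world_size as tp_size * pp_size * cp_size. A minimal sketch of that relationship (an illustrative stand-in only, not the real class, which has more fields in TensorRT-LLM):

from dataclasses import dataclass

@dataclass
class ParallelConfigSketch:  # hypothetical stand-in for _ParallelConfig
    tp_size: int = 1
    pp_size: int = 1
    cp_size: int = 1
    gpus_per_node: int = 8

    @property
    def world_size(self) -> int:  # computed as tp_size * pp_size * cp_size
        return self.tp_size * self.pp_size * self.cp_size

# Setting tp_size to the requested world_size makes world_size read back correctly,
# even though tensor parallelism is not necessarily what gets used at runtime.
assert ParallelConfigSketch(tp_size=4).world_size == 4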

View File

@@ -74,18 +74,15 @@ class DeviceMeshTopologyImpl(_MappingBaseForTypeCheck):
# Access rank
@property
def tp_rank(self) -> int:
assert not self.auto_parallel, "Auto parallel is not currently supported in Ray mode."
return self.tp_group_pg.rank()
@property
def pp_rank(self) -> int:
assert not self.auto_parallel, "Auto parallel is not currently supported in Ray mode."
return self.pp_group_pg.rank()
@property
def cp_rank(self) -> int:
# TODO: WIP
assert not self.auto_parallel, "Auto parallel is not currently supported in Ray mode."
return self.cp_group_pg.rank()
# Access group ranks

View File

@@ -25,7 +25,6 @@ import tempfile
import trace
import weakref
from contextlib import contextmanager
from dataclasses import asdict
from enum import EnumMeta
from functools import lru_cache, partial, wraps
from pathlib import Path
@@ -799,38 +798,6 @@ def trace_func(func):
return wrapper
class DictConversion:
@classmethod
def from_dict(cls, config: Dict[str, Any]):
obj = cls()
fields = obj.__dataclass_fields__
for key, value in config.items():
assert hasattr(obj, key), f"cannot find {key} in {obj}"
field_cls = fields[key].type
if (isinstance(field_cls, type)
and issubclass(field_cls, DictConversion)
and isinstance(value, dict)):
value = field_cls.from_dict(value)
setattr(obj, key, value)
return obj
def to_dict(self):
return asdict(self)
@classmethod
def from_json_file(cls, file):
with open(file) as f:
return cls.from_dict(json.load(f))
def set_defaults(self, **kwargs):
for key, default in kwargs.items():
value = getattr(self, key)
if (value is None
or (isinstance(value, (list, dict)) and len(value) == 0)):
setattr(self, key, default)
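For reference, the removed DictConversion mixin gave plain dataclasses dict round-tripping; a minimal usage sketch (the Point dataclass below is hypothetical, not from the codebase):

from dataclasses import dataclass

@dataclass
class Point(DictConversion):  # hypothetical example dataclass
    x: int = 0
    y: int = 0

p = Point.from_dict({"x": 1, "y": 2})
assert p.to_dict() == {"x": 1, "y": 2}
p.set_defaults(y=5)  # y already has a non-empty value, so it is left unchanged
assert p.y == 2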
class BaseEnumMeta(EnumMeta):
def __contains__(cls, item):

View File

@@ -1,9 +0,0 @@
from .auto_parallel import auto_parallel
from .cluster_info import infer_cluster_config
from .config import AutoParallelConfig
__all__ = [
'auto_parallel',
'AutoParallelConfig',
'infer_cluster_config',
]

View File

@@ -1,266 +0,0 @@
import gc
import os
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import tensorrt as trt
import torch
from filelock import FileLock
from tensorrt_llm.functional import DimRange, Tensor
from tensorrt_llm.logger import logger
from tensorrt_llm.network import Network, net_guard
from .config import AutoParallelConfig
from .device_mesh import LogicalDeviceMesh, PhysicalDeviceMesh
from .node_graph import NodeGraph
from .parallelization import ParallelConfig, parallelize
from .pipeline_graph import PipelineGraph
from .simplifier import GraphConfig, Simplifier, StageType
from .utils import current_flags
def to_network(graph: PipelineGraph, network: Network):
logger.debug("Converting graph to network")
trt_network = graph.as_trt()
trt_network.name = network.trt_network.name
new_network = Network()
new_network._init(trt_network)
new_network._dtype = network._dtype
new_network._plugin_config = network._plugin_config
new_network._unfilled_weights = graph._unfilled_weights
new_network._auto_parallel_config = graph._auto_parallel_config
with net_guard(network):
for i in range(trt_network.num_inputs):
input = trt_network.get_input(i)
tensor = Tensor(is_network_input=False)
if input.name in network._inputs:
profiles = network._inputs[input.name].profiles
elif len(network._inputs) == 0:
profiles = []
else:
shape = input.shape
num_profiles = len(list(network._inputs.values())[0].profiles)
profile = DimRange(shape, [None] * len(shape))
profiles = [profile] * num_profiles
tensor.profiles = profiles
tensor.trt_tensor = input
new_network._inputs[input.name] = tensor
return new_network
def find_solution(
node_graph: NodeGraph,
graph_config: GraphConfig,
lmesh: LogicalDeviceMesh,
memory_budget: int,
flags: list,
device: int,
dump_path: str,
) -> ParallelConfig:
torch.cuda.set_device(device)
with current_flags(*flags):
cost_graph = node_graph.get_cost_graph(lmesh)
num_stages = graph_config.num_stages
if num_stages == 1:
stage_types = [None]
elif num_stages == 2:
stage_types = [StageType.START, StageType.END]
else:
stage_types = [StageType.START, StageType.BLOCK, StageType.END]
best_config, best_solution = None, None
for stage_type in stage_types:
if stage_type is not None:
node_graph.set_slowest_stage(stage_type, graph_config)
solution = node_graph.find_solution(
cost_graph,
memory_budget,
)
cost = solution.total_cost
if best_config is None or cost < best_config.cost:
best_config = ParallelConfig()
best_config.graph_config = graph_config
best_config.lmesh = lmesh
best_config.cost = cost
best_config.graph_strategy = solution.node_best_strategy
best_config.stage_type = stage_type
best_solution = solution
if dump_path is not None:
lock = FileLock(f"{dump_path}/path.lock", thread_local=False)
vlz_name = f"{dump_path}/solution."
if graph_config.num_micro_batches != 1:
vlz_name += f"mbs{graph_config.num_micro_batches}."
if graph_config.num_stages != 1:
vlz_name += f"stages{graph_config.num_stages}."
vlz_name += lmesh.cluster_key
with lock:
node_graph.visualize_solution(
best_solution,
vlz_name,
ignore_shape_io=True,
)
return best_config
def infer_builder_flags(network):
fp16_enabled = False
bf16_enabled = False
int8_enabled = False
fp8_enabled = False
def check_dtype(tensor):
nonlocal fp16_enabled
nonlocal bf16_enabled
nonlocal int8_enabled
nonlocal fp8_enabled
if tensor.dtype == trt.DataType.HALF:
fp16_enabled = True
elif tensor.dtype == trt.DataType.BF16:
bf16_enabled = True
elif tensor.dtype == trt.DataType.INT8:
int8_enabled = True
elif tensor.dtype == trt.DataType.FP8:
fp8_enabled = True
trt_network = network.trt_network
for i in range(trt_network.num_inputs):
input = trt_network.get_input(i)
check_dtype(input)
for i in range(trt_network.num_layers):
layer = trt_network.get_layer(i)
for j in range(layer.num_outputs):
output = layer.get_output(j)
check_dtype(output)
builder_flags = 0
if fp16_enabled:
builder_flags |= 1 << int(trt.BuilderFlag.FP16)
builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
if bf16_enabled:
builder_flags |= 1 << int(trt.BuilderFlag.BF16)
builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
if int8_enabled:
builder_flags |= 1 << int(trt.BuilderFlag.INT8)
if fp8_enabled:
builder_flags |= 1 << int(trt.BuilderFlag.FP8)
builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
return builder_flags
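The returned value is a plain bitmask indexed by trt.BuilderFlag positions, so a caller can test an individual precision the same way the function sets it (illustrative only):

fp16_requested = bool(builder_flags & (1 << int(trt.BuilderFlag.FP16)))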
def auto_parallel(network: Network, config: AutoParallelConfig):
logger.warning(
"auto_parallel is deprecated, "
"please use explicit parallelism like tp_size/pp_size instead.")
debug_mode = config.debug_mode
memory_budget = config.get_cluster_info(
).memory_budget_per_device * 1024 * 1024 * 1024
enable_pipeline_parallelism = config.enable_pipeline_parallelism
if config.world_size < config.gpus_per_node:
num_hosts = 1
num_devices_per_host = config.world_size
else:
assert config.world_size % config.gpus_per_node == 0
num_hosts = config.world_size // config.gpus_per_node
num_devices_per_host = config.gpus_per_node
parallel_config_cache = config.parallel_config_cache
dump_path = config.dump_path if debug_mode else None
fill_weights = config.fill_weights
if num_hosts == 1 and num_devices_per_host == 1:
return [network]
if dump_path is not None:
if not os.path.exists(dump_path):
os.makedirs(dump_path)
builder_flags = config.builder_flags or infer_builder_flags(network)
flags = [builder_flags, network.strongly_typed]
with current_flags(*flags):
simplifier = Simplifier(network, config)
network_hash = simplifier.get_network_hash()
best_config = None
if parallel_config_cache is not None and Path(
parallel_config_cache).exists():
parallel_config = ParallelConfig.from_file(parallel_config_cache)
if (ParallelConfig.VERSION == parallel_config.version
and network_hash == parallel_config.network_hash
and config == parallel_config.auto_parallel_config):
logger.info(
f"use cache of parallel config from {parallel_config_cache}"
)
best_config = parallel_config
if best_config is None:
num_devices = num_hosts * num_devices_per_host
phy_ids = [[
i + j * num_devices_per_host
for i in range(num_devices_per_host)
] for j in range(num_hosts)]
phy_mesh = PhysicalDeviceMesh(phy_ids, config)
if enable_pipeline_parallelism:
num_micro_batches_list = simplifier.list_all_num_micro_batches()
else:
num_micro_batches_list = [1]
jobs = []
for num_micro_batches in num_micro_batches_list:
simplifier.infer_shapes(num_micro_batches)
if enable_pipeline_parallelism:
pipeline_configs = phy_mesh.list_all_pipeline_configs()
else:
pipeline_configs = [(1, num_devices)]
for num_stages, num_devices_per_stage in pipeline_configs:
# TODO: add fallback path that allows num_micro_batches >= num_stages
# if no solution satisfies memory budget
if num_micro_batches < num_stages:
continue
simplified_graph, graph_config = simplifier.simplify_graph(
phy_mesh,
num_stages,
num_devices_per_stage,
)
if simplified_graph is None:
continue
node_graph = NodeGraph(simplified_graph)
node_graph.assign_cost_weights(graph_config)
lmeshes = graph_config.stage_phy_meshes[
0].get_logical_meshes()
for lmesh in lmeshes:
jobs.append(
(node_graph, graph_config, lmesh, memory_budget *
(num_devices / num_devices_per_stage)))
try:
with ThreadPoolExecutor() as executor:
best_config = sorted(
executor.map(
lambda x: find_solution(
*x,
flags,
torch.cuda.current_device(),
dump_path,
),
jobs,
),
key=lambda x: x.cost,
)[0]
finally:
phy_mesh.close()
if parallel_config_cache is not None:
best_config.network_hash = network_hash
best_config.auto_parallel_config = config
best_config.save(parallel_config_cache)
new_graphs = parallelize(simplifier, best_config)
networks = [to_network(new_graph, network) for new_graph in new_graphs]
if debug_mode and fill_weights:
networks[0]._fill_weights()
gc.collect()
torch.cuda.empty_cache()
return networks
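For context, the removed entry point was driven roughly as sketched below; network is assumed to come from a tensorrt_llm builder, and the call returns a list of Network objects (a single-element list when no parallelism is requested):

config = AutoParallelConfig(world_size=2, gpus_per_node=8, cluster_key="A100-SXM-80GB")
sharded_networks = auto_parallel(network, config)  # emits the deprecation warning above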

View File

@@ -1,539 +0,0 @@
import copy
import re
from dataclasses import dataclass, field
from typing import Dict, Tuple, Union
import pynvml
import torch
try:
from cuda.bindings import runtime as cudart
except ImportError:
from cuda import cudart
from tensorrt_llm._utils import DictConversion
from tensorrt_llm.logger import logger
from tensorrt_llm.profiler import PyNVMLContext, _device_get_memory_info_fn
@dataclass
class MathThroughput(DictConversion):
int4: int = 0 # Tflops
int8: int = 0 # Tflops
fp8: int = 0 # Tflops
float16: int = 0 # Tflops
bfloat16: int = 0 # Tflops
float32: int = 0 # Tflops
@staticmethod
def to_tflops(
ipc_per_sm: "MathThroughput",
sm_count: int,
clock_mhz: int,
) -> "MathThroughput":
tflops = MathThroughput()
for name in ipc_per_sm.__dataclass_fields__:
setattr(
tflops, name,
getattr(ipc_per_sm, name) * sm_count * clock_mhz // int(1e6))
return tflops
@dataclass
class ClusterInfo(DictConversion):
inter_node_bw_per_device: int = 25 # GBps
intra_node_bw_per_device: int = 0 # GBps
inter_node_latency: int = 10 # us
intra_node_latency: int = 10 # us
intra_node_sharp: bool = False
inter_node_sharp: bool = True
memory_bw: int = 0 # GBps
memory_budget_per_device: int = 0 # GB
math_throughput: MathThroughput = field(default_factory=MathThroughput)
memory_efficiency: float = 1.0
math_efficiency: float = 1.0
communication_efficiency: float = 1.0
_math_throughputs = {
"A100": MathThroughput(
int8=624,
float16=312,
bfloat16=312,
float32=156,
),
}
_bandwidths = {
"PCIe-3": 16,
"PCIe-4": 32,
"PCIe-5": 64,
}
cluster_infos = {
# from https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
"A100-SXM-80GB":
ClusterInfo(
intra_node_bw_per_device=300,
memory_bw=2039,
memory_budget_per_device=80,
math_throughput=_math_throughputs["A100"],
),
"A100-SXM-40GB":
ClusterInfo(
intra_node_bw_per_device=300,
memory_bw=1555,
memory_budget_per_device=40,
math_throughput=_math_throughputs["A100"],
),
"A100-PCIe-80GB":
ClusterInfo(
intra_node_bw_per_device=_bandwidths["PCIe-4"],
memory_bw=1935,
memory_budget_per_device=80,
math_throughput=_math_throughputs["A100"],
),
"A100-PCIe-40GB":
ClusterInfo(
intra_node_bw_per_device=_bandwidths["PCIe-4"],
memory_bw=1555,
memory_budget_per_device=40,
math_throughput=_math_throughputs["A100"],
),
# from https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
"H100-SXM":
ClusterInfo(
inter_node_bw_per_device=50,
intra_node_bw_per_device=450,
intra_node_sharp=True,
memory_bw=3350,
memory_budget_per_device=80,
math_throughput=MathThroughput(
int8=1979,
fp8=1979,
float16=989,
bfloat16=989,
float32=495,
),
),
"H100-PCIe":
ClusterInfo(
inter_node_bw_per_device=50,
intra_node_bw_per_device=_bandwidths["PCIe-5"],
memory_bw=2000,
memory_budget_per_device=80,
math_throughput=MathThroughput(
int8=1513,
fp8=1513,
float16=756,
bfloat16=756,
float32=378,
),
),
"H20":
ClusterInfo(
inter_node_bw_per_device=50,
intra_node_bw_per_device=450,
memory_bw=4000,
memory_budget_per_device=96,
math_throughput=MathThroughput(
int8=293,
fp8=293,
float16=147,
bfloat16=147,
float32=74,
),
),
# from https://nvdam.widen.net/s/nb5zzzsjdf/hpc-datasheet-sc23-h200-datasheet-3002446
"H200-SXM":
ClusterInfo(
inter_node_bw_per_device=50,
intra_node_bw_per_device=450,
memory_bw=4800,
memory_budget_per_device=141,
math_throughput=MathThroughput(
int8=3958,
fp8=3958,
float16=1979,
bfloat16=1979,
float32=67,
),
),
"H200-NVL":
ClusterInfo(
inter_node_bw_per_device=50,
intra_node_bw_per_device=450,
memory_bw=4800,
memory_budget_per_device=141,
math_throughput=MathThroughput(
int8=3341,
fp8=3341,
float16=1671,
bfloat16=1671,
float32=60,
),
),
# from https://images.nvidia.com/content/Solutions/data-center/a40/nvidia-a40-datasheet.pdf
"A40":
ClusterInfo(
intra_node_bw_per_device=_bandwidths["PCIe-4"],
memory_bw=696,
memory_budget_per_device=48,
math_throughput=MathThroughput(
int4=600,
int8=300,
float16=150,
bfloat16=150,
float32=75,
),
),
# from https://www.nvidia.com/content/dam/en-zz/Solutions/data-center/products/a30-gpu/pdf/a30-datasheet.pdf
"A30":
ClusterInfo(
intra_node_bw_per_device=_bandwidths["PCIe-4"],
memory_bw=933,
memory_budget_per_device=24,
math_throughput=MathThroughput(
int4=661,
int8=330,
float16=165,
bfloat16=165,
float32=82,
),
),
# from https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a10/pdf/datasheet-new/nvidia-a10-datasheet.pdf
"A10":
ClusterInfo(
intra_node_bw_per_device=_bandwidths["PCIe-4"],
memory_bw=600,
memory_budget_per_device=24,
math_throughput=MathThroughput(
int4=500,
int8=250,
float16=125,
bfloat16=125,
float32=62.5,
),
),
"A10G":
ClusterInfo(
intra_node_bw_per_device=_bandwidths["PCIe-4"],
memory_bw=600,
memory_budget_per_device=24,
math_throughput=MathThroughput(
int4=280,
int8=140,
float16=70,
bfloat16=70,
float32=35,
),
),
# from https://resources.nvidia.com/en-us-l40s/l40s-datasheet-28413
"L40S":
ClusterInfo(
intra_node_bw_per_device=_bandwidths["PCIe-4"],
memory_bw=864,
memory_budget_per_device=48,
math_throughput=MathThroughput(
int4=733,
int8=733,
fp8=733,
float16=362,
bfloat16=362,
float32=183,
),
),
# from https://images.nvidia.cn/content/Solutions/data-center/vgpu-L40-datasheet.pdf
"L40":
ClusterInfo(
intra_node_bw_per_device=_bandwidths["PCIe-4"],
memory_bw=864,
memory_budget_per_device=48,
math_throughput=MathThroughput(
int4=724,
int8=362,
fp8=362,
float16=181,
bfloat16=181,
float32=90,
),
),
"L20":
ClusterInfo(
intra_node_bw_per_device=_bandwidths["PCIe-4"],
memory_bw=864,
memory_budget_per_device=48,
math_throughput=MathThroughput(
int8=238,
fp8=238,
float16=119,
bfloat16=119,
float32=60,
),
),
# from https://nvdam.widen.net/s/rvq98gbwsw/l4-datasheet-2595652
"L4":
ClusterInfo(
intra_node_bw_per_device=_bandwidths["PCIe-4"],
memory_bw=300,
memory_budget_per_device=24,
math_throughput=MathThroughput(
int8=242,
fp8=242,
float16=120,
bfloat16=120,
float32=60,
),
),
"L2":
ClusterInfo(
intra_node_bw_per_device=_bandwidths["PCIe-4"],
memory_bw=300,
memory_budget_per_device=24,
math_throughput=MathThroughput(
int8=193,
fp8=193,
float16=97,
bfloat16=97,
float32=48,
),
),
}
def infer_cluster_key() -> str:
def match(product, name):
# Using A100 as an example, the regex pattern matches:
# - NVIDIA A100 80GB
# - NVIDIA A100-PCIE
# - NVIDIA A100
# And does not match A1000 etc.
return re.match(f".*{product}([ -]|$).*", name) is not None
def is_sxm():
return "SXM" in device_name
def is_80gb():
return "80GB" in device_name
def is_32gb():
return "32GB" in device_name
device_name = torch.cuda.get_device_name(torch.cuda.current_device())
if match("A100", device_name):
if is_sxm():
if is_80gb():
return "A100-SXM-80GB"
else:
return "A100-SXM-40GB"
else:
if is_80gb():
return "A100-PCIe-80GB"
else:
return "A100-PCIe-40GB"
elif match("A10G", device_name):
return "A10G"
elif match("A10", device_name):
return "A10"
elif match("A30", device_name):
return "A30"
elif match("A40", device_name):
return "A40"
elif match("H100", device_name):
if is_sxm():
return "H100-SXM"
else:
return "H100-PCIe"
elif match("H200", device_name):
if is_sxm():
return "H200-SXM"
else:
return "H200-NVL"
elif match("L40S", device_name):
return "L40S"
elif match("L40", device_name):
return "L40"
elif match("L4", device_name):
return "L4"
return None
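A quick illustration of the matching rule described in the comment at the top of this function:

import re
pattern = ".*A100([ -]|$).*"
assert re.match(pattern, "NVIDIA A100 80GB") is not None
assert re.match(pattern, "NVIDIA A100-PCIE") is not None
assert re.match(pattern, "NVIDIA A1000") is None  # the trailing digit breaks the match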
def ipc_per_sm(compute_cap: Tuple[int, int]) -> MathThroughput:
ipc_table = {
(9, 0):
MathThroughput(
int8=16384,
fp8=16384,
float16=8192,
bfloat16=8192,
float32=4096,
),
(8, 0):
MathThroughput(
int4=8192,
int8=4096,
float16=2048,
bfloat16=2048,
float32=1024,
),
(8, 6):
MathThroughput(
int4=4096,
int8=2048,
float16=1024,
bfloat16=1024,
float32=512,
),
(8, 9):
MathThroughput(
int4=2048,
int8=1024,
fp8=1024,
float16=512,
bfloat16=512,
float32=256,
),
(7, 0):
MathThroughput(
float16=1024,
float32=128,
),
(7, 5):
MathThroughput(
int4=4096,
int8=2048,
float16=1024,
float32=128,
),
}
return ipc_table.get(compute_cap, MathThroughput())
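Combining ipc_per_sm with MathThroughput.to_tflops gives a rough peak estimate. As a sanity check with published A100 figures (108 SMs, 1410 MHz boost clock, assumed here for illustration):

a100 = MathThroughput.to_tflops(ipc_per_sm((8, 0)), sm_count=108, clock_mhz=1410)
assert a100.float16 == 311  # 2048 * 108 * 1410 // 1e6, in line with the 312 TFLOPS listed above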
def nvlink_version(version_enum: int) -> int:
nvl_version_table = {
1: 1,
2: 2,
3: 2, # 2.2
4: 3,
5: 3, # 3.1
6: 4,
7: 5,
}
return nvl_version_table[version_enum]
def nvlink_bandwidth(nvlink_version: int) -> int:
nvl_bw_table = {
1: 80,
2: 150,
3: 300,
4: 450,
5: 900,
}
return nvl_bw_table[nvlink_version]
def infer_cluster_info() -> ClusterInfo:
device = torch.cuda.current_device()
index = device.index if isinstance(device, torch.device) else device
with PyNVMLContext():
handle = pynvml.nvmlDeviceGetHandleByIndex(index)
compute_cap = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
logger.info(f"Compute capability: {compute_cap}")
err, properties = cudart.cudaGetDeviceProperties(index)
sm_count = properties.multiProcessorCount
logger.info(f"SM count: {sm_count}")
sm_clock = pynvml.nvmlDeviceGetMaxClockInfo(
handle,
pynvml.NVML_CLOCK_SM,
)
logger.info(f"SM clock: {sm_clock} MHz")
math_throughput = MathThroughput.to_tflops(
ipc_per_sm(compute_cap),
sm_count,
sm_clock,
)
for name in math_throughput.__dataclass_fields__:
tflops = getattr(math_throughput, name)
logger.info(f"{name} TFLOPS: {tflops}")
mem_info = _device_get_memory_info_fn(handle)
memory_budget = mem_info.total // (1024**3)
logger.info(f"Total Memory: {memory_budget} GiB")
mem_clock = pynvml.nvmlDeviceGetMaxClockInfo(
handle,
pynvml.NVML_CLOCK_MEM,
)
logger.info(f"Memory clock: {mem_clock} MHz")
mem_bus_width = pynvml.nvmlDeviceGetMemoryBusWidth(handle)
logger.info(f"Memory bus width: {mem_bus_width}")
memory_bw = mem_bus_width * mem_clock * 2 // int(8e3)
logger.info(f"Memory bandwidth: {memory_bw} GB/s")
try:
is_nvl_active = bool(pynvml.nvmlDeviceGetNvLinkState(handle, 0))
logger.info(f"NVLink is active: {is_nvl_active}")
except pynvml.NVMLError:
is_nvl_active = False
intra_node_sharp = False
if is_nvl_active:
nvl_version_enum = pynvml.nvmlDeviceGetNvLinkVersion(handle, 0)
nvl_version = nvlink_version(nvl_version_enum)
logger.info(f"NVLink version: {nvl_version}")
nvl_bw = nvlink_bandwidth(nvl_version)
logger.info(f"NVLink bandwidth (unidirectional): {nvl_bw} GB/s")
intra_node_bw = nvl_bw
if nvl_version >= 4:
intra_node_sharp = True
else:
pcie_speed = pynvml.nvmlDeviceGetPcieSpeed(handle)
logger.info(f"PCIe speed: {pcie_speed} Mbps")
pcie_link_width = pynvml.nvmlDeviceGetCurrPcieLinkWidth(handle)
logger.info(f"PCIe link width: {pcie_link_width}")
pcie_bw = pcie_speed * pcie_link_width // int(8e3)
logger.info(f"PCIe bandwidth: {pcie_bw} GB/s")
intra_node_bw = pcie_bw
cluster_info = ClusterInfo(
math_throughput=math_throughput,
memory_bw=memory_bw,
memory_budget_per_device=memory_budget,
intra_node_bw_per_device=intra_node_bw,
intra_node_sharp=intra_node_sharp,
)
return cluster_info
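As a sanity check on the bus-width formula above, an A100-SXM-80GB (assumed here: 5120-bit memory bus, 1593 MHz memory clock as reported by NVML) works out to the memory_bw value listed in cluster_infos:

mem_bus_width, mem_clock = 5120, 1593  # bits, MHz (assumed A100-SXM-80GB values)
assert mem_bus_width * mem_clock * 2 // int(8e3) == 2039  # GB/s, matches the table above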
def infer_cluster_config() -> Dict[str, Union[str, ClusterInfo]]:
device_name = torch.cuda.get_device_name(torch.cuda.current_device())
cluster_key = infer_cluster_key()
if cluster_key is not None:
return dict(cluster_key=cluster_key)
else:
try:
cluster_info = infer_cluster_info()
except pynvml.NVMLError:
fallback_cluster_key = "L40"
cluster_info = copy.copy(cluster_infos[fallback_cluster_key])
memory_budget = torch.cuda.mem_get_info()[1] // (1024**3)
cluster_info.memory_budget_per_device = memory_budget
logger.warning(
f"Failed to infer cluster info for {device_name}, "
f"treat it as a {fallback_cluster_key} node with {memory_budget} GB memory. "
"This setting makes no effect if you do not use auto parallel.")
return dict(
cluster_key=device_name.replace(" ", "-"),
cluster_info=cluster_info,
)
if __name__ == "__main__":
logger.set_level("info")
infer_cluster_info()

View File

@@ -1,61 +0,0 @@
from dataclasses import dataclass, field
from enum import auto
from typing import Dict, List, Optional, Union
from strenum import LowercaseStrEnum
from tensorrt_llm._utils import BaseEnumMeta, DictConversion
from .cluster_info import ClusterInfo, cluster_infos
class CostModel(LowercaseStrEnum, metaclass=BaseEnumMeta):
ALPHA_BETA = auto()
PROFILE = auto()
S_CURVE = auto()
# The zero cost model is for test purposes.
# Using the zero cost model for communication makes the solver prefer sharding;
# using it for computation makes the solver prefer replication.
ZERO = auto()
@dataclass
class AutoParallelConfig(DictConversion):
# cluster configuration
world_size: int = 1
gpus_per_node: int = 8
cluster_key: str = None
cluster_info: Optional[ClusterInfo] = None
# cost model configuration
sharding_cost_model: str = CostModel.ALPHA_BETA
comm_cost_model: str = CostModel.ALPHA_BETA
# strategy configuration
enable_pipeline_parallelism: bool = False
enable_shard_unbalanced_shape: bool = False
enable_shard_dynamic_shape: bool = False
enable_reduce_scatter: bool = True
# parallelization configuration
builder_flags: Optional[int] = None
debug_mode: bool = False
infer_shape: bool = True
validation_mode: bool = False
same_buffer_io: Dict[str, str] = field(default_factory=dict)
same_spec_io: Dict[str, str] = field(default_factory=dict)
sharded_io_allowlist: List[str] = field(default_factory=list)
fill_weights: bool = False
# debug configuration
parallel_config_cache: Optional[str] = None
profile_cache: Optional[str] = None
dump_path: Optional[str] = None
debug_outputs: Union[List[str], str] = field(default_factory=list)
def get_cluster_info(self) -> ClusterInfo:
return self.cluster_info or cluster_infos[self.cluster_key]
@property
def enabled(self) -> bool:
return self.world_size > 1
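For reference, a minimal sketch of how the removed config was typically used (values are illustrative):

config = AutoParallelConfig(world_size=2, cluster_key="A100-PCIe-80GB")
assert config.enabled  # world_size > 1
assert config.get_cluster_info().memory_budget_per_device == 80  # from cluster_infos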

View File

@@ -1,612 +0,0 @@
import os
import re
from abc import ABC, abstractmethod
from typing import List
import h5py
import numpy as np
from filelock import FileLock
from .config import AutoParallelConfig, CostModel
from .tensor_parallel.shape_consistency import ShapeConsistencyManager
class ProfileDB(ABC):
"""A database that stores profiling results for multiple device mesh
shapes."""
@abstractmethod
def query(self, cluster_key, data_key):
...
@abstractmethod
def update(self, cluster_key, data_key, mesh_result):
...
def close(self):
pass
class MemDB(ProfileDB):
def __init__(self):
self.data = {}
def query(self, cluster_key, data_key):
key = (cluster_key, data_key)
mesh_result = self.data.get(key, None)
if mesh_result is None:
return None
else:
return mesh_result[0]
def update(self, cluster_key, data_key, mesh_result):
key = (cluster_key, data_key)
self.data[key] = mesh_result
class Hdf5DB(ProfileDB):
def __init__(self, name):
self.name = name
lock_name = self.name + ".lock"
self.lock = FileLock(lock_name, thread_local=False)
def query(self, cluster_key, data_key):
file_name = f"{self.name}.hdf5"
key = str((cluster_key, data_key))
self.lock.acquire()
mesh_result = None
with h5py.File(file_name, 'a') as f:
if key in f:
self.lock.release()
mesh_result = f[key]
return mesh_result[0]
else:
return None
def update(self, cluster_key, data_key, mesh_result):
key = str((cluster_key, data_key))
file_name = f"{self.name}.hdf5"
with h5py.File(file_name, 'a') as f:
f[key] = mesh_result
def close(self):
self.lock.release(force=True)
class LogicalDeviceMesh(object):
def __init__(self,
phy_mesh_shape,
mesh_shape,
phy_ids,
config: AutoParallelConfig,
alpha,
beta,
sharp,
prof_database=None,
shape_consistency_manager=None,
host_ips=None):
self.phy_mesh_shape = phy_mesh_shape
self.mesh_shape = mesh_shape
self.phy_ids = phy_ids
self.host_ips = host_ips
self.cluster_key = config.cluster_key + '_mesh_shape{}'.format('_'.join(
[str(i) for i in mesh_shape]))
self.prof_min_max_size = [1, 2**34]
self.prof_comm_dtypes = [
"int8", "uint8", "int32", "uint32", "int64", "uint64", "float16",
"float32", "float64", "bfloat16"
]
self.devices_group = {
(0, ): [self.phy_ids.transpose(), self.mesh_shape[1] - 1],
(1, ): [self.phy_ids, self.mesh_shape[1]],
(0, 1): [self.phy_ids.reshape([1, self.phy_ids.size]), 0]
}
self.prof_database = prof_database
self.shape_consistency_manager = shape_consistency_manager
self.config = config
self.cluster_info = config.get_cluster_info()
self.hw_alpha = alpha
self.hw_beta = beta
self.hw_sharp = sharp
self.algo_alpha_beta = self._estimate_algo_alpha_beta()
self.comm_op_to_nccl_test_func_name = {
'all_reduce': 'all_reduce_perf_mpi',
'all_gather': 'all_gather_perf_mpi',
'all_to_all': 'alltoall_perf_mpi',
'reduce_scatter': 'reduce_scatter_perf_mpi',
'split': 'split',
}
@property
def size(self) -> int:
return self.phy_ids.size
def _estimate_algo_alpha_beta(self):
ret = {}
ar_alpha, ar_beta = {}, {}
ag_alpha, ag_beta = {}, {}
rs_alpha, rs_beta = {}, {}
a2a_alpha, a2a_beta = {}, {}
phy_num_hosts, phy_num_devices_per_host = self.phy_mesh_shape
if phy_num_hosts == 1 or phy_num_devices_per_host == 1:
for dims in [(0, ), (1, ), (0, 1), (1, 0)]:
num_devices = 1
for dim in dims:
num_devices = self.mesh_shape[dim] * num_devices
if num_devices != 1:
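# Ring all-reduce moves 2*(n-1)/n of the data per device, so the effective bandwidth
# is alpha * n / (2*(n-1)); with SHARP (in-network reduction) the full link bandwidth applies.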
ar_alpha[dims] = self.hw_alpha[0] if self.hw_sharp[
0] else self.hw_alpha[0] * num_devices / 2 / (
num_devices - 1)
ar_beta[dims] = self.hw_beta[0]
ag_alpha[dims] = self.hw_alpha[0] * num_devices / (
num_devices - 1)
ag_beta[dims] = self.hw_beta[0]
rs_alpha[dims] = self.hw_alpha[0] * num_devices / (
num_devices - 1)
rs_beta[dims] = self.hw_beta[0]
a2a_alpha[dims] = self.hw_alpha[0] * num_devices / (
num_devices - 1)
a2a_beta[dims] = self.hw_beta[0]
# phy and logical have the same mesh shape if num_hosts > 1 and num_devices_per_host > 1
else:
for dims in [(0, ), (1, ), (0, 1), (1, 0)]:
num_devices = 1
for dim in dims:
num_devices = self.mesh_shape[dim] * num_devices
if num_devices != 1:
if len(dims) == 1:
dim = dims[0]
ar_alpha[dims] = self.hw_alpha[dim] if self.hw_sharp[
dim] else self.hw_alpha[dim] * num_devices / 2 / (
num_devices - 1)
ar_beta[dims] = self.hw_beta[dim]
ag_alpha[dims] = self.hw_alpha[dim] * num_devices / (
num_devices - 1)
ag_beta[dims] = self.hw_beta[dim]
rs_alpha[dims] = self.hw_alpha[dim] * num_devices / (
num_devices - 1)
rs_beta[dims] = self.hw_beta[dim]
a2a_alpha[dims] = self.hw_alpha[dim] * num_devices / (
num_devices - 1)
a2a_beta[dims] = self.hw_beta[dim]
elif len(dims) == 2: # two level communication
num_hosts, num_devices_per_host = phy_num_hosts, phy_num_devices_per_host
inter_node_col_alpha = self.hw_alpha[
0] * num_devices_per_host
inter_node_ar_alpha = inter_node_col_alpha if self.hw_sharp[
0] else inter_node_col_alpha * num_hosts / 2 / (
num_hosts - 1)
intra_node_ar_alpha = self.hw_alpha[1]
intra_node_ar_alpha = intra_node_ar_alpha if self.hw_sharp[
1] else intra_node_ar_alpha * num_devices_per_host / 2 / (
num_devices_per_host - 1)
ar_alpha[dims] = min(inter_node_ar_alpha,
intra_node_ar_alpha)
ar_beta[dims] = max(self.hw_beta)
ag_alpha[dims] = min(
inter_node_col_alpha * num_hosts / (num_hosts - 1),
self.hw_alpha[1] * num_devices_per_host /
(num_devices_per_host - 1))
ag_beta[dims] = max(self.hw_beta)
rs_alpha[dims] = ag_alpha[dims]
rs_beta[dims] = ag_beta[dims]
a2a_alpha[dims] = min(
num_hosts * self.hw_alpha[0] / (num_hosts - 1),
self.hw_alpha[1] * num_hosts)
a2a_beta[dims] = max(self.hw_beta)
else:
pass
ret['all_to_all'] = [a2a_alpha, a2a_beta]
ret['all_reduce'] = [ar_alpha, ar_beta]
ret['all_gather'] = [ag_alpha, ag_beta]
ret['reduce_scatter'] = [rs_alpha, rs_beta]
ret['p2p_cross_device'] = [
self.cluster_info.intra_node_bw_per_device,
self.cluster_info.intra_node_latency
]
ret['p2p_cross_host'] = [
self.cluster_info.inter_node_bw_per_device,
self.cluster_info.inter_node_latency
]
return ret
#[ToDo] stub functions here
def _profile_split(self, min_max_comm_size):
comm_size, elapsed_time = [], []
size = min_max_comm_size[0]
while size <= min_max_comm_size[1]:
time = size * 2 / self.cluster_info.memory_bw
comm_size.append(size)
elapsed_time.append(time)
size = size * 2
return np.array([comm_size, elapsed_time])
def _prase_nccl_test_results(self, f_nccl_test_out_log):
'''[ToDo] Some dtypes may not be supported by nccl-tests; fall back to the default dtype (float).'''
start_parse = False
comm_size, elapsed_time = [], []
try:
with open(f_nccl_test_out_log, 'r') as lines:
for line in lines:
if start_parse:
prof_data = re.split(r"[ ]+", line.strip())
if len(prof_data) != 13:
continue
comm_size.append(float(prof_data[0]))
elapsed_time.append(float(prof_data[5]))
if 'GB/s' in line and 'us' in line:
start_parse = True
except Exception:
print(f'failed to parse {f_nccl_test_out_log}')
return comm_size, elapsed_time
def _profile_with_nccl_test(self, min_max_comm_size, dtype, device_group,
func_name, step, workload_key):
if func_name == 'split':
if 2 == step:
return self._profile_split(min_max_comm_size)
else:
return None
workspace_dir = self.config['profiling_workspace'] + f'/{workload_key}'
os.makedirs(workspace_dir, exist_ok=True)
outfile, errfile = workspace_dir + '/profile.out', workspace_dir + '/profile.err'
if 1 == step:
num_nodes = len(self.host_ips)
num_gpus = self.mesh_shape[0] * self.mesh_shape[1]
ntasks_per_node = num_gpus // num_nodes
nccl_test_command = '"export NCCL_TESTS_SPLIT_MASK={} && export NCCL_COLLNET_ENABLE=1 && {} -b {} -e {} -g 1 -d {} -f {}"'.format(
device_group[1], func_name, min_max_comm_size[0],
min_max_comm_size[1], dtype, 2)
sbatch_command = '#!/bin/bash\n'
sbatch_command += '#SBATCH -p {}\n'.format(self.config['partition'])
sbatch_command += '#SBATCH -A {}\n'.format(self.config['account'])
sbatch_command += '#SBATCH -J {}\n'.format(self.config['jobname'])
sbatch_command += '#SBATCH -N {}\n'.format(num_nodes)
sbatch_command += '#SBATCH -t {}\n'.format(self.config['time'])
sbatch_command += '#SBATCH --ntasks-per-node={}\n'.format(
ntasks_per_node)
sbatch_command += '#SBATCH --exclusive\n'
sbatch_command += '#SBATCH --mem=0\n'
sbatch_command += '#SBATCH --network=sharp\n'
sbatch_command += '#SBATCH --mail-type=FAIL\n'
srun_command = 'srun --nodes={} --mpi=pmix --ntasks-per-node={} --network=sharp -o {} -e {} --container-image={} bash -c '.format(
num_nodes, ntasks_per_node, outfile, errfile,
self.config['container'])
command = sbatch_command + srun_command + nccl_test_command
with open(workspace_dir + '/workload.sub', 'w') as f:
f.write(command)
with open('./preprofiling_step1.sh', 'a') as f:
f.write(f'sbatch {workspace_dir}/workload.sub\n')
return None
else:
comm_size, elapsed_time = self._prase_nccl_test_results(outfile)
if len(comm_size) < 2:
assert 0, 'the profiling for {} failed at step 1, please try again'.format(
workload_key)
else:
print(workload_key, comm_size, elapsed_time)
return np.array([comm_size, elapsed_time])
def _profile_single_comm_perf(self, device_group, comm_op, step, data_key):
results = {}
func_name = self.comm_op_to_nccl_test_func_name[comm_op]
for dtype in self.prof_comm_dtypes:
size_time = self._profile_with_nccl_test(
self.prof_min_max_size, dtype, device_group, func_name, step,
data_key + f'_dtype{dtype}')
results[dtype] = size_time
return results
def profile_all_comms_perf(self, step):
if self.mesh_shape == (1, 1):
return None
mesh_results = self.prof_database.query(self.cluster_key,
self.mesh_shape)
if mesh_results:
return mesh_results
mesh_results = {}
data_key = self.cluster_key + f'_mesh_shape{self.mesh_shape[0]}x{self.mesh_shape[1]}'
for comm_op in [
'all_reduce', 'all_to_all', 'all_gather', 'reduce_scatter',
'split'
]:
comm_perf = {}
for dim, device_group in self.devices_group.items():
# don't need to profile for mesh dim == 1
if len(dim) == 1 and self.mesh_shape[dim[0]] == 1:
continue
comm_perf[dim] = self._profile_single_comm_perf(
device_group, comm_op, step, data_key +
'_comm_op{}_dim{}'.format(comm_op, ''.join(map(str, dim))))
mesh_results[comm_op] = comm_perf
if 2 == step:
self.prof_database.update(self.cluster_key, self.mesh_shape,
mesh_results)
return mesh_results
def _model_comm_cost_from_s_curve(self, size_time_array, realsize):
assert size_time_array[0][0] <= realsize <= size_time_array[0][-1],\
'the comm_size: {} is not in the profile range: [{}, {}]'\
.format(realsize, size_time_array[0][0], size_time_array[0][-1])
return np.interp(realsize, size_time_array[0], size_time_array[1])
def _model_comm_cost_from_alpha_beta(self, comm_op, dim_key, size_in_bytes):
elapsed_time = 0.0
if 'split' == comm_op:
elapsed_time = size_in_bytes * 2 / (
self.cluster_info.memory_bw *
self.cluster_info.memory_efficiency) * 1e-3
else:
dict_alpha, dict_beta = self.algo_alpha_beta[comm_op]
alpha, beta = dict_alpha[dim_key], dict_beta[dim_key]
elapsed_time = (size_in_bytes /
(alpha * self.cluster_info.communication_efficiency)
* 1e-3) + beta
return elapsed_time
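A worked example of the alpha-beta estimate above, with illustrative numbers: moving 64 MiB (67,108,864 bytes) at alpha = 300 GB/s, communication_efficiency = 1.0 and beta = 10 us gives 67108864 / 300 * 1e-3 + 10 ≈ 234 us.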
def _input_size_to_comm_size(self, comm_op, dims, input_size):
ret = input_size
if 'all_gather' == comm_op:
for dim in dims:
ret = ret * self.mesh_shape[dim]
return ret
def estimate_comm_cost(self, comm_op, dim, input_size, dtype):
size = self._input_size_to_comm_size(comm_op, dim, input_size)
if self.config.comm_cost_model == CostModel.S_CURVE:
mesh_perf = self.prof_database.query(self.cluster_key,
self.mesh_shape)
assert mesh_perf is not None, 'the mesh is not profiled, mesh_shape = {}'.format(
self.mesh_shape)
comm_op_perf = mesh_perf.get(comm_op, None)
assert comm_op_perf is not None, '{} is not profiled'.format(
comm_op)
elapsed_time = self._model_comm_cost_from_s_curve(
comm_op_perf[tuple(dim)][dtype], size)
return elapsed_time
elif self.config.comm_cost_model == CostModel.ALPHA_BETA:
elapsed_time = self._model_comm_cost_from_alpha_beta(
comm_op, tuple(dim), size)
elif self.config.comm_cost_model == CostModel.PROFILE:
assert False, 'The profile-based communication cost model is not supported yet'
elif self.config.comm_cost_model == CostModel.ZERO:
elapsed_time = 0.0
return elapsed_time # us
class PhysicalDeviceMesh(object):
def __init__(self,
phy_devices_id,
config: AutoParallelConfig,
prof_database=None,
shape_consistency_manager=None,
host_ips=None):
self.phy_devices_id = np.array(phy_devices_id)
self.num_hosts, self.num_devices_per_host = self.phy_devices_id.shape
self.host_ips = host_ips
if host_ips is None:
self.host_ips = [''] * self.num_hosts
self.config = config
self.cluster_info = config.get_cluster_info()
self.prof_database: ProfileDB = prof_database
self.shape_consistency_manager = shape_consistency_manager
if self.config.comm_cost_model not in CostModel:
raise ValueError(
f'unsupported communication cost model: {self.config.comm_cost_model}'
)
if self.config.sharding_cost_model not in CostModel:
raise ValueError(
f'unsupported sharding cost model: {self.config.sharding_cost_model}'
)
if self.config.comm_cost_model == CostModel.S_CURVE or self.config.sharding_cost_model == CostModel.PROFILE:
if self.prof_database is None:
profile_cache = config.profile_cache
if profile_cache is None:
self.prof_database = MemDB()
else:
self.prof_database = Hdf5DB(profile_cache)
elif self.config.comm_cost_model == CostModel.ALPHA_BETA:
assert self.cluster_info.intra_node_bw_per_device > 0, 'intra_node_bw_per_device is needed for alpha_beta method'
assert self.cluster_info.inter_node_bw_per_device > 0, 'inter_node_bw_per_device is needed for alpha_beta method'
if self.config.sharding_cost_model == CostModel.ALPHA_BETA:
assert self.cluster_info.memory_bw > 0, 'memory_bw is needed for alpha_beta method'
if not shape_consistency_manager:
self.shape_consistency_manager = ShapeConsistencyManager()
@property
def size(self) -> int:
return self.phy_devices_id.size
def close(self):
if self.prof_database is not None:
self.prof_database.close()
def split_pipeline_meshes(
self, num_stages,
num_devices_per_stage) -> List["PhysicalDeviceMesh"]:
sub_meshes = []
if num_devices_per_stage <= self.num_devices_per_host:
assert self.num_devices_per_host % num_devices_per_stage == 0, \
"num_devices_per_host ({}) % num_devices_per_stage ({}) != 0"\
.format(self.num_devices_per_host, num_devices_per_stage)
num_clusters_per_host = self.num_devices_per_host // num_devices_per_stage
num_clusters = self.num_hosts * num_clusters_per_host
assert num_stages % num_clusters == 0, \
"num_stages({}) % num_clusters({}) !=0".format(num_stages, num_clusters)
for mesh_id in range(num_stages):
cluster_id = mesh_id % num_clusters
cluster_col = cluster_id % num_clusters_per_host
cluster_row = cluster_id // num_clusters_per_host
sub_devices_id = [
self.phy_devices_id[cluster_row][cluster_col *
num_devices_per_stage:(
(cluster_col + 1) *
num_devices_per_stage)]
]
sub_meshes.append(
PhysicalDeviceMesh(sub_devices_id, self.config,
self.prof_database,
self.shape_consistency_manager,
[self.host_ips[cluster_row]]))
else:
assert num_devices_per_stage % self.num_devices_per_host == 0, \
"num_devices_per_stage ({}) % num_devices_per_host ({}) != 0"\
.format(num_devices_per_stage, self.num_devices_per_host)
num_host_per_cluster = num_devices_per_stage // self.num_devices_per_host
assert self.num_hosts % num_host_per_cluster == 0, \
"num_hosts ({}) % num_host_per_cluster({}) != 0".format(self.num_hosts, num_host_per_cluster)
num_clusters = self.num_hosts // num_host_per_cluster
for mesh_id in range(num_stages):
cluster_id = mesh_id % num_clusters
cluster_row = cluster_id * num_host_per_cluster
sub_devices_id = self.phy_devices_id[cluster_row:(
cluster_row + num_host_per_cluster)]
host_ips = self.host_ips[cluster_row:(cluster_row +
num_host_per_cluster)]
sub_meshes.append(
PhysicalDeviceMesh(sub_devices_id, self.config,
self.prof_database,
self.shape_consistency_manager,
host_ips))
return sub_meshes
def _profile_logical_meshes(self, logical_meshes, step):
for lmesh in logical_meshes:
lmesh.profile_all_comms_perf(step)
def as_logical_mesh(self) -> LogicalDeviceMesh:
alpha = [
self.cluster_info.inter_node_bw_per_device,
self.cluster_info.intra_node_bw_per_device
]
beta = [
self.cluster_info.inter_node_latency,
self.cluster_info.intra_node_latency
]
sharp = [
self.cluster_info.inter_node_sharp,
self.cluster_info.intra_node_sharp
]
return LogicalDeviceMesh(
self.phy_devices_id.shape,
self.phy_devices_id.shape,
self.phy_devices_id,
self.config,
alpha,
beta,
sharp,
self.prof_database,
self.shape_consistency_manager,
self.host_ips,
)
def get_logical_meshes(self):
logical_meshes = []
# (1, 2) -> (1, 2)
# (1, 4) -> (2, 2)
# (1, 8) -> (2, 4)
# (1, 16) -> (2, 8), (4, 4)
# (1, 32) -> (2, 16), (4, 8)
# (1, 48) -> (2, 24), (3, 16), (4, 12), (6, 8)
# (1, 64) -> (2, 32), (4, 16), (8, 8)
# we traverse the logical shape's axes in the sharding spec, so (2, 8) covers (8, 2)
# we can merge the logical shape's axes, so (2, 8) covers (1, 16) and (16, 1)
if self.num_hosts == 1:
alpha = [self.cluster_info.intra_node_bw_per_device]
beta = [self.cluster_info.intra_node_latency]
sharp = [self.cluster_info.intra_node_sharp]
for i in range(2, self.num_devices_per_host):
if self.num_devices_per_host % i == 0 and i * i <= self.num_devices_per_host:
lmesh_shape = (i, self.num_devices_per_host // i)
lmesh_phy_ids = self.phy_devices_id.reshape(lmesh_shape)
logical_meshes.append(
LogicalDeviceMesh(self.phy_devices_id.shape,
lmesh_shape, lmesh_phy_ids,
self.config, alpha, beta, sharp,
self.prof_database,
self.shape_consistency_manager,
self.host_ips))
# (8, 1) -> (2, 4)
# (16, 1) -> (2, 8), (4, 4)
elif self.num_devices_per_host == 1:
alpha = [self.cluster_info.inter_node_bw_per_device]
beta = [self.cluster_info.inter_node_latency]
sharp = [self.cluster_info.inter_node_sharp]
for i in range(2, self.num_hosts):
if self.num_hosts % i == 0 and i * i <= self.num_hosts:
lmesh_shape = (i, self.num_hosts // i)
lmesh_phy_ids = self.phy_devices_id.reshape(lmesh_shape)
logical_meshes.append(
LogicalDeviceMesh(self.phy_devices_id.shape,
lmesh_shape, lmesh_phy_ids,
self.config, alpha, beta, sharp,
self.prof_database,
self.shape_consistency_manager,
self.host_ips))
# (2, 1) -> (2, 1)
# (2, 8) -> (2, 8)
# (1, 2) -> (1, 2)
# (1, 3) -> (1, 3)
# (1, 5) -> (1, 5)
if 0 == len(logical_meshes):
logical_meshes.append(self.as_logical_mesh())
return logical_meshes
'''
we assume we can evenly split the pipeline and deviceMesh
'''
def _list_all_sub_meshes(self):
sub_meshes = []
for num_devices_per_stage in range(1, self.num_devices_per_host + 1):
if self.num_devices_per_host % num_devices_per_stage == 0:
num_stages = self.num_hosts * self.num_devices_per_host // num_devices_per_stage
sub_meshes.append(
self.split_pipeline_meshes(num_stages,
num_devices_per_stage)[0])
for num_hosts_per_stage in range(2, self.num_hosts + 1):
if self.num_hosts % num_hosts_per_stage == 0:
num_stages = self.num_hosts // num_hosts_per_stage
sub_meshes.append(
self.split_pipeline_meshes(
num_stages,
num_hosts_per_stage * self.num_devices_per_host)[0])
return sub_meshes
def list_all_pipeline_configs(self):
configs = []
for num_devices_per_stage in range(1, self.num_devices_per_host + 1):
if self.num_devices_per_host % num_devices_per_stage == 0:
num_stages = self.num_hosts * self.num_devices_per_host // num_devices_per_stage
configs.append((num_stages, num_devices_per_stage))
for num_hosts_per_stage in range(2, self.num_hosts + 1):
if self.num_hosts % num_hosts_per_stage == 0:
num_stages = self.num_hosts // num_hosts_per_stage
configs.append(
(num_stages,
num_hosts_per_stage * self.num_devices_per_host))
return configs
def profile_s_curve(self, step):
sub_phy_device_meshes = self._list_all_sub_meshes()
for phy_mesh in sub_phy_device_meshes:
lmeshes = phy_mesh.get_logical_meshes()
self._profile_logical_meshes(lmeshes, step)
if 2 == step:
self.save_profile_database()
def profile_alpha_beta(self):
alpha = [250, 25]
beta = [100, 100]
return alpha, beta

View File

@@ -1,347 +0,0 @@
from typing import List
import pandas as pd
import tensorrt as trt
from .pipeline_graph import PipelineGraph
from .runtime_profiling import RuntimeProfiler
from .simplifier import GraphConfig, StageType
from .solver import CostGraph, Solver
from .tensor_parallel.activation_node import Activation
from .tensor_parallel.assertion_node import Assertion
from .tensor_parallel.cast_node import Cast
from .tensor_parallel.concatenation_node import Concatenation
from .tensor_parallel.constant_node import Constant
from .tensor_parallel.elementwise_node import ElementWise
from .tensor_parallel.fill_node import Fill
from .tensor_parallel.gather_node import Gather
from .tensor_parallel.identity_node import Identity
from .tensor_parallel.input_node import InputNode
from .tensor_parallel.matmul_node import MatrixMultiply
from .tensor_parallel.node import Node
from .tensor_parallel.normalization_node import Normalization
from .tensor_parallel.output_node import OuputNode
from .tensor_parallel.p2p_node import P2PNode, P2PType
from .tensor_parallel.plugin_node import PluginNode
from .tensor_parallel.plugin_nodes.gemm_node import GemmPlugin
from .tensor_parallel.plugin_nodes.gpt_attention_node import GPTAttentionPlugin
from .tensor_parallel.plugin_nodes.identity_node import IdentityPlugin
from .tensor_parallel.plugin_nodes.look_up_node import LookupPlugin
from .tensor_parallel.plugin_nodes.normalization_node import (LayernormPlugin,
RMSnormPlugin)
from .tensor_parallel.reduce_node import Reduce
from .tensor_parallel.select_node import Select
from .tensor_parallel.shape_node import Shape
from .tensor_parallel.shuffle_node import Shuffle
from .tensor_parallel.slice_node import Slice
from .tensor_parallel.softmax_node import SoftMax
from .tensor_parallel.unary_node import Unary
LAYER_TYPE_2_NODE_TYPE = {
trt.LayerType.ACTIVATION: Activation,
trt.LayerType.ASSERTION: Assertion,
trt.LayerType.CAST: Cast,
trt.LayerType.CONCATENATION: Concatenation,
trt.LayerType.CONSTANT: Constant,
trt.LayerType.ELEMENTWISE: ElementWise,
trt.LayerType.FILL: Fill,
trt.LayerType.GATHER: Gather,
trt.LayerType.IDENTITY: Identity,
trt.LayerType.MATRIX_MULTIPLY: MatrixMultiply,
trt.LayerType.NORMALIZATION: Normalization,
trt.LayerType.PLUGIN_V2: PluginNode,
trt.LayerType.REDUCE: Reduce,
trt.LayerType.SELECT: Select,
trt.LayerType.SHAPE: Shape,
trt.LayerType.SHUFFLE: Shuffle,
trt.LayerType.SLICE: Slice,
trt.LayerType.SOFTMAX: SoftMax,
trt.LayerType.UNARY: Unary,
}
# TODO: BertAttention/All Quant plugins
PLUGIN_LAYER_TYPE_2_NODE_TYPE = {
'GPTAttention': GPTAttentionPlugin,
'Gemm': GemmPlugin,
'Layernorm': LayernormPlugin,
'Rmsnorm': RMSnormPlugin,
'Lookup': LookupPlugin,
'Identity': IdentityPlugin,
}
class NodeGraph:
def __init__(self, graph: PipelineGraph):
self._nodes = {}
# construct nodes
for input in graph.inputs:
self._nodes[input.name] = InputNode(input)
for layer in graph.layers:
layer.to_base_class()
if "p2p_type" in layer.attrs:
self._nodes[layer.name] = P2PNode(layer)
elif layer.type == trt.LayerType.PLUGIN_V2:
layer.to_subclass()
plugin_type = layer.as_trt().plugin.plugin_type
layer.to_base_class()
if plugin_type in PLUGIN_LAYER_TYPE_2_NODE_TYPE:
node = PLUGIN_LAYER_TYPE_2_NODE_TYPE[plugin_type](layer)
else:
node = LAYER_TYPE_2_NODE_TYPE[layer.type](layer)
self._nodes[layer.name] = node
else:
node = LAYER_TYPE_2_NODE_TYPE[layer.type](layer)
self._nodes[layer.name] = node
for output in graph.outputs:
self._nodes[output.name] = OuputNode(output)
for node in self.nodes:
node.post_init(self)
node.node_runtime_profiler = RuntimeProfiler()
def get_node(self, name):
return self._nodes[name]
@property
def nodes(self) -> List[Node]:
return [*self._nodes.values()]
def assign_cost_weights(self, graph_config: GraphConfig):
layer_mapping = graph_config.graph_mapping.layer_mapping
for layer_name in layer_mapping.values():
node = self.get_node(layer_name)
node.sharding_weight += 1
node.resharding_weight += 1
same_spec_layer_mapping = graph_config.graph_mapping.same_spec_layer_mapping
for same_spec_layer_name, layer_name in same_spec_layer_mapping.items():
node = self.get_node(layer_name)
same_spec_node = self.get_node(same_spec_layer_name)
same_spec_node.sharding_weight = node.sharding_weight
same_spec_node.resharding_weight = node.resharding_weight
def set_slowest_stage(self, stage_type: StageType,
graph_config: GraphConfig):
num_micro_batches = graph_config.num_micro_batches
block_per_stage = graph_config.num_blocks // graph_config.num_stages
block_pipeline_weight = block_per_stage * (num_micro_batches - 1)
for node in self.nodes:
node.pipeline_weight = 0
node.cost_level = -1
if node.stage_type == StageType.START:
if stage_type == StageType.START:
node.pipeline_weight = num_micro_batches - 1
node.cost_level = 1
else:
node.cost_level = 0
if stage_type == StageType.START and node.in_start_block:
node.pipeline_weight = block_pipeline_weight
if node.stage_type == StageType.END:
if stage_type == StageType.END:
node.pipeline_weight = num_micro_batches - 1
node.cost_level = 1
else:
node.cost_level = 0
if stage_type == StageType.END and node.in_end_block:
node.pipeline_weight = block_pipeline_weight
if isinstance(node, P2PNode):
if (graph_config.has_cross_host
and node.p2p_type == P2PType.CROSS_HOST) or (
not graph_config.has_cross_host
and node.p2p_type == P2PType.CROSS_DEVICE):
if stage_type == StageType.BLOCK:
node.pipeline_weight += num_micro_batches - 1
node.cost_level = 1
else:
node.cost_level = 0
elif (graph_config.has_cross_device
and node.p2p_type == P2PType.CROSS_DEVICE) or (
not graph_config.has_cross_device
and node.p2p_type == P2PType.CROSS_HOST):
node.pipeline_weight += num_micro_batches - 1
if stage_type == StageType.BLOCK and node.in_slowest_block:
node.pipeline_weight = block_pipeline_weight
def get_cost_graph(self, lmesh):
leaf_strategies = []
for node in self.nodes:
if node.is_replicated:
node.set_strategy(None, lmesh)
else:
node.collect_strategies(lmesh)
for node in self.nodes:
strategies_vector = node.update_resharding_cost()
if len(strategies_vector) != 0:
leaf_strategies.append(strategies_vector)
cost_graph = CostGraph(leaf_strategies)
return cost_graph
def find_solution(self, cost_graph, memory_budget):
solver = Solver(cost_graph, memory_budget=memory_budget)
solution = solver.find_solution()[1]
graph_strategy = solution.node_best_strategy
for node_name, strategy in graph_strategy.items():
node = self._nodes[node_name]
for idx, pre_node in enumerate(node.predecessor_nodes):
if pre_node is None:
continue
if pre_node.node_name not in strategy.best_resharding_cost:
continue
strategy.best_resharding_cost[
idx] = strategy.best_resharding_cost[pre_node.node_name]
strategy.node_names[idx] = pre_node.node_name
for key in list(strategy.best_resharding_cost.keys()):
if isinstance(key, str):
del strategy.best_resharding_cost[key]
return solution
def visualize(self, name='pp_graph'):
with open(name + '.dot', 'w') as f:
f.write("digraph {\n")
'''
f.write(" // Value Nodes\n")
for name, tensor in self._tensors.items():
f.write(" \"{}\" [fillcolor = \"green\", label = \"{}\", shape = \"box\", style = \"filled\"];\n".format(name, tensor.shape))
'''
f.write(" // Operation Nodes\n")
for name, node in self._nodes.items():
fillcolor = 'white'
if 'MATRIX_MULTIPLY' in name:
fillcolor = 'green'
label = name
if len(node.outputs) > 0:
label = name + '\\n' + str(node.outputs[0].shape)
f.write(
" \"{}\" [fillcolor = \"{}\", label = \"{}\", shape = \"box\", style = \"filled\"];\n"
.format(name, fillcolor, label))
f.write(" // Edges\n")
for name, node in self._nodes.items():
for successor_node in node.successor_nodes:
if successor_node:
f.write(" \"{}\" ->\"{}\";\n".format(
name, successor_node.node_name))
f.write(" }\n")
def visualize_solution(self,
solution,
fname='pp_graph_solution',
ignore_shape_io=True):
with open(fname + '.dot', 'w') as f:
names, costs, block_ids = [], [], []
f.write("digraph {\n")
f.write(" // Operation Nodes\n")
for name, node in self._nodes.items():
if ignore_shape_io and node.layer is not None and node.layer.is_shape_io:
continue
cost = 0.0
fillcolor = 'white'
if 'MATRIX_MULTIPLY' in name or 'PLUGIN_V2_Gemm' in name:
fillcolor = 'orange'
elif '_same_spec' in name:
fillcolor = 'gray'
elif 'p2p_block' in name:
fillcolor = 'blue'
elif 'PLUGIN' in name:
fillcolor = 'yellow'
shape = 'box'
if 'output_node' == node.node_type or 'input_node' == node.node_type:
shape = 'ellipse'
fillcolor = 'green'
label = name + f'_block{node.building_block_id}_weight{node.sharding_weight}'
if len(node.inputs) > 0:
for idx, input in enumerate(node.inputs):
if not input:
continue
label = label + f'\\ninput{idx}_' + str(
input.shape) + f'_{input.dtype_str_size[0]}_'
if node.node_name in solution.node_best_strategy:
best_strategy = solution.node_best_strategy[
node.node_name]
shard_seq = str(
best_strategy.sharding_specs[f'input{idx}'].
sharding_sequence)
label = label + shard_seq
if idx not in best_strategy.best_resharding_cost:
continue
rcosts = best_strategy.best_resharding_cost[idx][0]
comm_action_sequence, resharding_cost = rcosts[
1], rcosts[2]
if len(comm_action_sequence) > 0:
label = label + '|'
for commspec in comm_action_sequence:
comm = [
commspec.comm_pattern, commspec.gather_dim,
commspec.shard_dim,
commspec.logical_process_axis
]
label = label + '->' + str(comm)
if resharding_cost > 0:
label = label + '_rcost{:.2}'.format(
resharding_cost)
cost = cost + resharding_cost
if len(node.outputs) > 0:
best_strategy = None
for idx, output in enumerate(node.outputs):
label = label + f'\\noutput{idx}_' + str(
output.shape) + f'_{output.dtype_str_size[0]}'
if node.node_name in solution.node_best_strategy:
best_strategy = solution.node_best_strategy[
node.node_name]
shard_seq = str(
best_strategy.sharding_specs[f'output{idx}'].
sharding_sequence)
comm = None
if f'output{idx}' in best_strategy.communication_actions:
commspec = best_strategy.communication_actions[
f'output{idx}']
comm = [
commspec.comm_pattern, commspec.gather_dim,
commspec.shard_dim,
commspec.logical_process_axis
]
label = label + '_' + shard_seq
if comm:
label = label + f' | {comm}'
if best_strategy:
cost = cost + best_strategy.sharding_cost + best_strategy.communication_cost
label = label + '| scost{:.2}'.format(
best_strategy.sharding_cost)
if best_strategy.communication_cost > 0:
label = label + ' | ccost{:.2}'.format(
best_strategy.communication_cost)
names.append(name)
costs.append(cost)
block_ids.append([
node.building_block_id, node.cost_level,
node.sharding_weight + node.pipeline_weight,
node.same_spec_id
])
f.write(
" \"{}\" [fillcolor = \"{}\", label = \"{}\", shape = \"{}\", style = \"filled\"];\n"
.format(name, fillcolor, label, shape))
f.write(" // Edges\n")
for name, node in self._nodes.items():
if ignore_shape_io and node.layer is not None and node.layer.is_shape_io:
continue
for successor_node in node.successor_nodes:
if successor_node:
if ignore_shape_io and successor_node.layer is not None and successor_node.layer.is_shape_io:
continue
f.write(" \"{}\" ->\"{}\";\n".format(
name, successor_node.node_name))
f.write(" }\n")
df = pd.DataFrame.from_dict({
'node':
names,
'cost':
costs,
'block_id': [block[0] for block in block_ids],
'cost_level': [block[1] for block in block_ids],
'sharding_weight': [block[2] for block in block_ids],
'same_spec_id': [block[3] for block in block_ids]
})
df['weight_cost'] = df['sharding_weight'] * df['cost']
df.to_csv(fname + '.csv')

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -1,150 +0,0 @@
import numpy as np
import tensorrt as trt
import torch
from tensorrt_llm.logger import logger
from tensorrt_llm.network import get_plugin_info
from .shape_info import get_per_layer_graph
from .utils import get_cache_key, get_trt_network, get_updated_plugin
class NvtxProfiler(object):
def __init__(self, nvtx_name, enable=True):
self.nvtx_name = nvtx_name
self.enable = enable
def __enter__(self):
if self.enable:
torch.cuda.nvtx.range_push(self.nvtx_name)
def __exit__(self, exc_type, exc_val, exc_tb):
if self.enable:
torch.cuda.nvtx.range_pop()
class LayerProfiler(trt.IProfiler):
def __init__(self):
trt.IProfiler.__init__(self)
self.layer_count = 0
self.time = 0
def report_layer_time(self, layer_name, ms):
logger.debug(f'{layer_name=}, {self.layer_count=}, time = {ms} ms')
self.time += ms
self.layer_count += 1
class RuntimeProfiler(object):
def __init__(self):
self.timing_cache = None
def _profile(self, layer, layer_attrs, shapes, values, io_buffer_mapping):
is_plugin = layer.type == trt.LayerType.PLUGIN_V2
if is_plugin and len(layer_attrs) > 0:
plugin_info = get_plugin_info(
get_trt_network(layer),
layer.name,
)
new_plugin, _ = get_updated_plugin(plugin_info, layer_attrs)
layer_attrs = {"plugin": new_plugin}
graph, output_mapping = get_per_layer_graph(layer, shapes, values,
layer_attrs)
graph._io_buffer_mapping = io_buffer_mapping
network = graph.as_trt()
if network.num_outputs > 0 and np.all([
network.get_output(i).is_shape_tensor
for i in range(network.num_outputs)
]):
return 0.0
for proxy_output, output in output_mapping.items():
shapes[proxy_output] = shapes[output]
if not self.timing_cache:
self.timing_cache = network.builder.create_builder_config(
).create_timing_cache(b"")
runner = graph.get_runner(
shapes,
values,
timing_cache=self.timing_cache,
)
context = runner.session.context
context.profiler = LayerProfiler()
runner.run()
profiler_time_first_run = context.profiler.time
runner.run()
return (context.profiler.time - profiler_time_first_run) * 1000.0
def runtime_profile(self, layer, layer_attrs, input_values, strategy,
device_mesh):
logger.debug(f"start to profile layer {layer.name}")
shapes = {}
values = {}
dtypes = {}
trt_layer = layer.as_trt()
sharding_sequences = ()
for i in range(layer.num_inputs):
input = trt_layer.get_input(i)
if input is not None:
shapes[input.name] = strategy.sharding_specs[
f'input{i}'].get_sharded_shape_per_device()
dtypes[input.name] = input.dtype
sharding_sequences += (str(
strategy.sharding_specs[f"input{i}"].sharding_sequence), )
if i in input_values:
values[input.name] = input_values[i]
else:
value = layer.get_input(i).value
if value is not None:
values[input.name] = value
else:
sharding_sequences += (None, )
for i in range(layer.num_outputs):
output = trt_layer.get_output(i)
if f'output{i}' in strategy.communication_actions:
shapes[output.name] = strategy.communication_actions[
f'output{i}'].sharding_spec.get_sharded_shape_per_device()
else:
shapes[output.name] = strategy.sharding_specs[
f'output{i}'].get_sharded_shape_per_device()
dtypes[output.name] = output.dtype
sharding_sequences += (str(
strategy.sharding_specs[f"output{i}"].sharding_sequence), )
data_key = get_cache_key(
trt_layer,
shapes,
values,
dtypes=dtypes,
updated_attrs=layer_attrs,
)
data_key += (sharding_sequences, )
elapsed_time = device_mesh.prof_database.query(
device_mesh.cluster_key,
data_key,
)
if elapsed_time:
logger.debug(
f'runtime profiling cache hit {data_key}: {elapsed_time} us')
return elapsed_time
with NvtxProfiler(f'{layer.name}_{data_key}', enable=True):
elapsed_time = self._profile(
layer.as_trt(),
layer_attrs,
shapes,
values,
layer.graph._io_buffer_mapping,
)
logger.debug(
f'runtime profiling cache miss {data_key}: {elapsed_time} us')
device_mesh.prof_database.update(
device_mesh.cluster_key,
data_key,
(elapsed_time, strategy.alpha_beta_cost),
)
return elapsed_time

View File

@ -1,362 +0,0 @@
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Set
import numpy as np
import tensorrt as trt
import torch
from tensorrt_llm._common import _is_building
from tensorrt_llm._utils import (trt_dtype_to_np, trt_dtype_to_str,
trt_dtype_to_torch)
from tensorrt_llm.logger import logger
from .pipeline_graph import PipelineGraph
from .utils import (get_builder_flags, get_cache_key, get_sorted_layer_ids,
set_trt_network, to_base_class_layer, to_subclass_layer,
to_trt_weights)
class ShapeType(Enum):
MIN = 0
OPT = 1
MAX = 2
_trt_to_type_dict = {
trt.int64: int,
trt.bool: bool,
}
def get_shape_layers(trt_network):
shape_layers = set()
for i in range(trt_network.num_layers):
layer = trt_network.get_layer(i)
if (layer.num_inputs > 0 and np.all([
layer.get_input(j).is_shape_tensor
for j in range(layer.num_inputs)
if layer.get_input(j) is not None
])) or (layer.num_outputs > 0 and np.all([
layer.get_output(j).is_shape_tensor
for j in range(layer.num_outputs)
])):
shape_layers.add(layer.name)
return shape_layers
def get_layers_in_shape_network(trt_network, shape_layers, sorted_layer_ids):
layers = set()
shape_tensors = set()
for layer_id in reversed(sorted_layer_ids):
layer = trt_network.get_layer(layer_id)
in_shape_network = False
if layer.name in shape_layers:
in_shape_network = True
else:
for j in range(layer.num_outputs):
output = layer.get_output(j)
if output.name in shape_tensors:
in_shape_network = True
break
if in_shape_network:
layers.add(layer.name)
for j in range(layer.num_inputs):
input = layer.get_input(j)
if input is not None:
shape_tensors.add(input.name)
return layers
def get_shape_network(trt_network,
shapes,
values,
sorted_layer_ids,
profile=None,
shape_type: ShapeType = ShapeType.OPT):
shape_layers = get_shape_layers(trt_network)
layers_in_shape_network = get_layers_in_shape_network(
trt_network, shape_layers, sorted_layer_ids)
shape_graph = PipelineGraph.create_graph()
shape_network = shape_graph.as_trt()
shape_builder = shape_network.builder
shape_profile = shape_builder.create_optimization_profile()
for i in range(trt_network.num_inputs):
input = trt_network.get_input(i)
shapes[input.name] = input.shape
new_input = shape_graph.add_input(input)
if profile is not None:
if -1 in input.shape:
shape = profile.get_shape(input.name)
shape = shape[shape_type.value]
shapes[input.name] = shape
new_input.raw_shape = shape
if input.is_shape_tensor:
shape_values = profile.get_shape_input(input.name)
value = shape_values[shape_type.value]
values[input.name] = value
shape_profile.set_shape_input(input.name, value, value, value)
output_mapping = {}
for layer_id in sorted_layer_ids:
layer = trt_network.get_layer(layer_id)
if layer.name in shape_layers:
new_layer = shape_graph.add_layer(layer)
for i in range(layer.num_outputs):
output = layer.get_output(i)
if output.dtype == trt.DataType.BOOL:
proxy_layer = shape_network.add_cast(
new_layer.as_trt().get_output(i),
trt.DataType.INT32,
)
proxy_output = proxy_layer.get_output(0)
shape_graph.register_layer(proxy_layer)
shape_graph.add_output_shape(proxy_output)
output_mapping[proxy_output.name] = (output.name,
output.dtype)
else:
shape_graph.add_output_shape(output)
elif layer.name in layers_in_shape_network:
if layer.type == trt.LayerType.CONSTANT:
shape_graph.add_input(layer.get_output(0))
else:
shape_graph.add_layer(layer)
return shape_network, shape_profile, shape_layers, output_mapping
def get_per_layer_graph(
layer,
shapes,
values,
updated_attrs=None,
is_shape_io: bool = None,
):
graph = PipelineGraph.create_graph()
network = graph.as_trt()
is_shape_layer = layer.num_inputs != 0
for i in range(layer.num_inputs):
input = layer.get_input(i)
if input is not None:
shape = shapes[input.name]
if (values.get(input.name) is not None
and not isinstance(values[input.name], torch.Tensor)):
value = values[input.name]
weights = np.asarray(value, dtype=trt_dtype_to_np(input.dtype))
weights = to_trt_weights(weights)
input_layer = network.add_constant(shape, weights)
new_input = input_layer.get_output(0)
new_input.name = input.name
graph.register_layer(input_layer)
elif graph.get_input(input.name) is None:
new_input = graph.add_input(input)
new_input.raw_shape = shapes[input.name]
is_shape_layer = False
new_layer = graph.add_layer(
layer,
updated_attrs=updated_attrs,
)
output_mapping = {}
if layer.type == trt.LayerType.SHAPE:
is_shape_layer = True
if layer.num_inputs == 0:
is_shape_layer = False
if is_shape_io is not None:
is_shape_layer = is_shape_io
for i in range(layer.num_outputs):
output = layer.get_output(i)
value = values.get(output.name)
if value is not None and isinstance(value, torch.Tensor):
is_output_shape = False
elif is_shape_layer:
is_output_shape = True
else:
is_output_shape = False
if is_output_shape:
if output.dtype == trt.DataType.BOOL:
proxy_layer = network.add_cast(
new_layer.as_trt().get_output(i),
trt.DataType.INT32,
)
proxy_output = proxy_layer.get_output(0)
graph.register_layer(proxy_layer)
output_mapping[proxy_output.name] = (output.name, output.dtype)
output = proxy_output
graph.add_output_shape(output)
else:
graph.add_output(output)
return graph, output_mapping
@_is_building
def infer_shapes(network, shapes, values, profile=None):
if network.num_outputs == 0:
return
builder = network.builder
config = builder.create_builder_config()
config.builder_optimization_level = 0
config.flags = get_builder_flags()
profile = profile or builder.create_optimization_profile()
config.add_optimization_profile(profile)
plan = builder.build_serialized_network(network, config)
if plan is None:
raise RuntimeError(
'Engine building failed when inferring shapes, please check the error log.'
)
runtime = trt.Runtime(logger.trt_logger)
engine = runtime.deserialize_cuda_engine(plan)
context = engine.create_execution_context()
for i in range(network.num_inputs):
input = network.get_input(i)
if input.is_shape_tensor:
value = values[input.name]
context.set_shape_input(engine[input.name], value)
for i in range(network.num_outputs):
output = network.get_output(i)
shape = context.get_tensor_shape(output.name)
shapes[output.name] = shape
if output.is_shape_tensor:
if shape == [0]:
values[output.name] = []
else:
if shape == []:
shape = [1]
value = torch.empty(
list(shape),
dtype=trt_dtype_to_torch(output.dtype),
device="cpu",
)
values[output.name] = value
context.set_tensor_address(output.name, value.data_ptr())
context.infer_shapes()
assert context.all_binding_shapes_specified
for i in range(network.num_outputs):
output = network.get_output(i)
if isinstance(values.get(output.name), torch.Tensor):
values[output.name] = values[output.name].tolist()
@dataclass
class ShapeInfo:
shapes: Dict[str, trt.Dims]
values: Dict[str, List[int]]
shape_layers: Set[str]
max_shapes: Dict[str, trt.Dims] = None
def set_constant_value(layer, values):
to_subclass_layer(layer)
output_name = layer.get_output(0).name
weights = layer.weights
if isinstance(weights, trt.Weights):
weights = weights.numpy()
values[output_name] = list(weights)
to_base_class_layer(layer)
def infer_per_layer_shapes(
layer: trt.ILayer,
shapes,
values,
cache=None,
is_shape_io=False,
):
if layer.type == trt.LayerType.CONSTANT:
to_subclass_layer(layer)
output_name = layer.get_output(0).name
shape = layer.shape
shapes[output_name] = shape
if is_shape_io:
set_constant_value(layer, values)
to_base_class_layer(layer)
return
elif layer.type == trt.LayerType.SHAPE:
input_name = layer.get_input(0).name
output_name = layer.get_output(0).name
shape = [*shapes[input_name]]
shapes[output_name] = trt.Dims([len(shape)])
values[output_name] = shape
return
if cache is not None:
cache_key = get_cache_key(layer, shapes, values)
if cache_key in cache:
output_shapes, output_values = cache[cache_key]
for i in range(layer.num_outputs):
output = layer.get_output(i)
shapes[output.name] = output_shapes[i]
if output_values[i] is not None:
values[output.name] = output_values[i]
return
graph, output_mapping = get_per_layer_graph(layer, shapes, values)
dtypes = [
trt_dtype_to_str(layer.get_input(i).dtype)
for i in range(layer.num_inputs)
]
layer_info = (f"type={cache_key[0]}, "
f"attrs={dict(cache_key[1])}, "
f"dtypes={dtypes}, "
f"shapes={list(cache_key[2])}, "
f"values={list(cache_key[3])}")
logger.debug(f"infer shapes for layer {layer.name} ({layer_info})")
try:
infer_shapes(graph.as_trt(), shapes, values)
except RuntimeError as e:
raise RuntimeError(
f"infer shapes failed for layer {layer.name} ({layer_info})") from e
for proxy_output, (output, dtype) in output_mapping.items():
shapes[output] = shapes[proxy_output]
del shapes[proxy_output]
if proxy_output in values:
values[output] = [
*map(_trt_to_type_dict[dtype], values[proxy_output])
]
del values[proxy_output]
if cache is not None:
logger.debug(
f"shape inference cache miss, layer: {layer.name}, cache key: {cache_key}"
)
output_shapes = []
output_values = []
for i in range(layer.num_outputs):
output = layer.get_output(i)
output_shapes.append(shapes[output.name])
output_values.append(values.get(output.name))
cache[cache_key] = (output_shapes, output_values)
def get_shape_info(trt_network, profile, shape_type: ShapeType = ShapeType.OPT):
shapes = {}
values = {}
sorted_layer_ids = get_sorted_layer_ids(trt_network)
infer_shape_layers = False
shape_network, shape_profile, shape_layers, output_mapping = get_shape_network(
trt_network,
shapes,
values,
sorted_layer_ids,
profile=profile,
shape_type=shape_type)
try:
infer_shapes(shape_network, shapes, values, shape_profile)
for proxy_output, (output, dtype) in output_mapping.items():
shapes[output] = shapes[proxy_output]
values[output] = [
*map(_trt_to_type_dict[dtype], values[proxy_output])
]
del shapes[proxy_output]
del values[proxy_output]
except RuntimeError:
infer_shape_layers = True
cache = {}
for layer_id in sorted_layer_ids:
layer = trt_network.get_layer(layer_id)
is_shape_io = layer.name in shape_layers
if is_shape_io and not infer_shape_layers:
continue
set_trt_network(layer, trt_network)
infer_per_layer_shapes(layer,
shapes,
values,
cache,
is_shape_io=is_shape_io)
return ShapeInfo(shapes, values, shape_layers)

View File

@ -1,837 +0,0 @@
import math
import re
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Tuple
import numpy as np
from tensorrt_llm.network import Network
from .config import AutoParallelConfig
from .device_mesh import PhysicalDeviceMesh
from .pipeline_graph import PipelineGraph
from .shape_info import ShapeInfo, ShapeType, get_shape_info
from .tensor_parallel.p2p_node import P2PType
from .utils import get_cache_key, get_sorted_layer_ids, silent_trt_logger
class StageType(Enum):
START = 0
BLOCK = 1
END = 2
class BuildingBlock:
def __init__(self, graph, layer_range) -> None:
self.graph = graph
self.layer_range = layer_range
self.network = graph.as_trt()
self.owned_inputs = {}
self.is_edges_collected = False
self.intra_edges = []
self.src_inter_edges = []
self.dst_inter_edges = []
self.relative_src_inter_edges = []
self.relative_dst_inter_edges = []
self.relative_inter_edges = set()
self.edge_hash = None
self.outputs = None
self.type_id = -1
self.block_id = -1
self.p2p_type = None
self.is_superset = False
self.is_subset = False
self.sorted_layer_ids = []
def collect_edges(self):
if self.is_edges_collected:
return
for layer_index in self.layer_range:
trt_layer = self.network.get_layer(layer_index)
layer = self.graph.get_layer(trt_layer.name)
layer_offset = layer.index - self.layer_range.start
for input_index, input in enumerate(layer.inputs):
if input is not None:
if input.is_graph_input:
is_owned = input.graph_input_index in self.owned_inputs
if not is_owned and np.all([
layer.index in self.layer_range or np.all([
output.as_trt().is_shape_tensor
for output in layer.outputs
]) for layer, _ in input.consumers
]):
self.owned_inputs[input.graph_input_index] = len(
self.owned_inputs)
is_owned = True
if is_owned:
self.intra_edges.append(
(-1, self.owned_inputs[input.graph_input_index],
layer_offset, input_index))
else:
self.dst_inter_edges.append(
(-1, input.graph_input_index, layer_offset,
input_index))
else:
src_layer_index = input.producer.index
if src_layer_index < self.layer_range.start or src_layer_index >= self.layer_range.stop:
self.dst_inter_edges.append(
(src_layer_index, input.output_index,
layer_offset, input_index))
else:
src_layer_offset = src_layer_index - self.layer_range.start
self.intra_edges.append(
(src_layer_offset, input.output_index,
layer_offset, input_index))
for output_index, output in enumerate(layer.outputs):
for dst_layer, dst_input_index in output.consumers:
dst_layer_index = dst_layer.index
if dst_layer_index < self.layer_range.start or dst_layer_index >= self.layer_range.stop:
self.src_inter_edges.append(
(layer_offset, output_index, dst_layer_index,
dst_input_index))
self.edge_hash = tuple(self.intra_edges)
self.outputs = sorted(
set((edge[0], edge[1]) for edge in self.src_inter_edges))
self.is_edges_collected = True
def collect_relative_inter_edges(self, layer_to_block):
self.collect_edges()
for src_layer_index, src_output_index, dst_layer_index, dst_input_index in self.dst_inter_edges:
if src_layer_index in layer_to_block:
src_block = layer_to_block[src_layer_index]
src_layer_offset = src_layer_index - src_block.layer_range.start
dst = (self.type_id, dst_layer_index, dst_input_index)
self.relative_dst_inter_edges.append(
(src_block.type_id, src_layer_offset, src_output_index,
*dst))
else:
self.relative_dst_inter_edges.append(
(-1, src_layer_index, src_output_index, self.type_id,
dst_layer_index, dst_input_index))
self.relative_inter_edges = set(self.relative_dst_inter_edges +
self.outputs)
def get_input_names(self):
self.collect_edges()
input_tensor_names = []
for edge in self.dst_inter_edges:
layer_index = edge[0]
output_index = edge[1]
if layer_index == -1:
tensor_name = self.network.get_input(output_index).name
else:
tensor_name = self.network.get_layer(layer_index).get_output(
output_index).name
input_tensor_names.append(tensor_name)
return input_tensor_names
def get_input_mapping(self, last_blocks):
input_mapping = {}
for tensor_name, relative_edge in zip(self.get_input_names(),
self.relative_dst_inter_edges):
type_id = relative_edge[0]
output_index = relative_edge[2]
if type_id >= 0:
last_block = last_blocks[type_id]
layer_offset = relative_edge[1]
mapped_layer_index = last_block.layer_range.start + layer_offset
mapped_tensor_name = self.network.get_layer(
mapped_layer_index).get_output(output_index).name
input_mapping[tensor_name] = mapped_tensor_name
else:
input_mapping[tensor_name] = tensor_name
return input_mapping
@dataclass
class GraphMapping:
layer_mapping: Dict[int, int] = None
block_mapping: Dict[int, int] = None
p2p_types: Dict[int, P2PType] = None
p2p_tensors: Dict[int, List[str]] = None
block_to_stage: Dict[int, int] = None
same_spec_layer_mapping: Dict[str, str] = None
@dataclass
class GraphConfig:
num_micro_batches: int = 1
num_blocks: int = 1
num_stages: int = 1
has_cross_device: bool = False
has_cross_host: bool = False
graph_mapping: GraphMapping = None
phy_mesh: PhysicalDeviceMesh = None
stage_phy_meshes: List[PhysicalDeviceMesh] = None
class Simplifier:
def __init__(self, network: Network, config: AutoParallelConfig):
self.config = config
self.sharded_io_allowlist = config.sharded_io_allowlist
self.same_buffer_io = config.same_buffer_io
self.same_spec_io = config.same_spec_io.copy()
for key, value in self.same_buffer_io.items():
if key not in self.same_spec_io:
self.same_spec_io[key] = value
self.llm_network = network
self.network = network.trt_network
self.module_to_layer_range_map = network._module_call_stack.module_to_layer_range_map
self.graph = self.get_graph()
self.init_layer_hash()
module_tree = self.get_module_tree()
building_blocks = self.collect_building_blocks(module_tree)
blocks_by_module_hash = self.get_blocks_by_module_hash(building_blocks)
self.blocks_by_edge_hash = self.get_blocks_by_edge_hash(
blocks_by_module_hash)
self.layer_to_block = self.get_layer_to_block()
self.blocks = self.get_all_blocks()
self.backbone_blocks = self.get_backbone_blocks()
self.graph_mapping_for_shape = self.get_graph_mapping_for_shape()
self.graph_for_shape = self.create_simplified_graph_for_shape()
self.shape_info = None
self.num_micro_batches = None
def infer_shapes(self, num_micro_batches):
if self.num_micro_batches == num_micro_batches:
return
with silent_trt_logger():
self.shape_info = self.get_full_shape_info(num_micro_batches)
self.graph.assign_shapes(self.shape_info)
self.num_micro_batches = num_micro_batches
def list_all_num_micro_batches(self):
opt_batch_size = self.get_opt_batch_size()
candidates = []
for num_micro_batches in range(1, self.get_opt_batch_size() + 1):
if opt_batch_size % num_micro_batches == 0:
candidates.append(num_micro_batches)
return candidates
def get_graph(self):
graph = PipelineGraph.from_trt(self.network)
graph._unfilled_weights = self.llm_network._unfilled_weights.copy()
graph._io_buffer_mapping
for input in graph.inputs:
input_name = input.name
for pattern, repl in self.same_buffer_io.items():
if re.match(pattern, input_name):
output_name = re.sub(pattern, repl, input_name)
output = graph.get_output(output_name)
if output is not None:
graph._io_buffer_mapping[output_name] = input_name
return graph
def get_opt_batch_size(self):
input_tensors = self.llm_network._inputs
num_profiles = len(list(input_tensors.values())[0].profiles)
opt_batch_sizes = []
for i in range(num_profiles):
for input_tensor in input_tensors.values():
shape_profile = input_tensor.profiles[i]
opt_shape = shape_profile.opt
for j in range(len(input_tensor.shape)):
name = input_tensor.trt_tensor.get_dimension_name(j)
if name == 'batch_size':
opt_batch_sizes.append(opt_shape[j])
return min(opt_batch_sizes)
def get_module_hash(self, layer_range):
module_hash = ()
for i in layer_range:
assert i < self.network.num_layers, f"layer index {i} in {layer_range} out of range of {self.network.num_layers}"
layer_name = self.network.get_layer(i).name
layer = self.graph.get_layer(layer_name)
module_hash += (layer.attrs["hash"], )
return module_hash
def get_network_hash(self) -> str:
return str(self.get_module_hash(range(self.network.num_layers)))
def collect_building_blocks(self, module_tree):
building_blocks = {}
queue = []
for tree in module_tree["children"].values():
queue.append(tree)
while len(queue) > 0:
while len(queue) > 0:
tree = queue.pop(0)
module_name = tree["name"]
if module_name is None:
for child in tree["children"].values():
queue.append(child)
continue
layer_range = self.module_to_layer_range_map[module_name]
module_hash = self.get_module_hash(layer_range)
if module_hash in building_blocks:
building_blocks[module_hash].append(tree)
else:
building_blocks[module_hash] = [tree]
for module_hash in [*building_blocks.keys()]:
if len(building_blocks[module_hash]) == 1:
tree = building_blocks[module_hash][0]
for child in tree["children"].values():
queue.append(child)
del building_blocks[module_hash]
blocks_by_module_hash = {
module_hash: [
BuildingBlock(self.graph,
self.module_to_layer_range_map[tree["name"]])
for tree in trees
]
for module_hash, trees in building_blocks.items()
}
building_blocks = []
for block_list in blocks_by_module_hash.values():
for block in block_list:
building_blocks.append(block)
building_blocks = sorted(building_blocks,
key=lambda x: x.layer_range.start)
if len(building_blocks) >= 2:
for block, next_block in zip(building_blocks[:-1],
building_blocks[1:]):
block.layer_range = range(block.layer_range.start,
next_block.layer_range.start)
return building_blocks
def get_all_blocks(self):
building_blocks = []
for block_list in self.blocks_by_edge_hash.values():
for block in block_list:
building_blocks.append(block)
building_blocks = sorted(building_blocks,
key=lambda x: x.layer_range.start)
all_blocks = []
current_layer_index = 0
block_id = 0
for block in building_blocks:
assert current_layer_index <= block.layer_range.start
if current_layer_index < block.layer_range.start:
new_block = BuildingBlock(
self.graph,
range(current_layer_index, block.layer_range.start))
new_block.block_id = block_id
block_id += 1
all_blocks.append(new_block)
block.block_id = block_id
block_id += 1
all_blocks.append(block)
current_layer_index = block.layer_range.stop
if current_layer_index < self.graph.num_layers:
new_block = BuildingBlock(
self.graph, range(current_layer_index, self.graph.num_layers))
new_block.block_id = block_id
all_blocks.append(new_block)
sorted_layer_ids = get_sorted_layer_ids(self.network)
for block in all_blocks:
block.collect_relative_inter_edges(self.layer_to_block)
for layer_id in sorted_layer_ids:
if layer_id in block.layer_range:
block.sorted_layer_ids.append(layer_id)
return all_blocks
def get_backbone_blocks(self):
sorted_blocks = sorted(
self.blocks_by_edge_hash.values(),
key=lambda blocks: (len(blocks), len(blocks[0].layer_range)),
)
if len(sorted_blocks) == 0:
return []
else:
return sorted_blocks[-1]
def get_blocks_by_module_hash(self, blocks):
blocks_by_module_hash = {}
for block in blocks:
module_hash = self.get_module_hash(block.layer_range)
if module_hash not in blocks_by_module_hash:
blocks_by_module_hash[module_hash] = []
blocks_by_module_hash[module_hash].append(block)
for module_hash in [*blocks_by_module_hash.keys()]:
if len(blocks_by_module_hash[module_hash]) == 1:
del blocks_by_module_hash[module_hash]
return blocks_by_module_hash
def get_module_tree(self):
module_tree = {"children": {}, "name": None}
for module_name in self.module_to_layer_range_map.keys():
full_name = module_name.split('.')
current_tree = module_tree["children"]
for depth, name in enumerate(full_name):
if name not in current_tree:
current_tree[name] = {"children": {}, "name": None}
if depth == len(full_name) - 1:
current_tree[name]["name"] = module_name
else:
current_tree = current_tree[name]["children"]
return module_tree
def get_blocks_by_edge_hash(self, blocks_by_module_hash):
blocks_by_edge_hash = {}
for block_list in blocks_by_module_hash.values():
for block in block_list:
block.collect_edges()
edge_hash = block.edge_hash
if edge_hash not in blocks_by_edge_hash:
blocks_by_edge_hash[edge_hash] = []
blocks_by_edge_hash[edge_hash].append(block)
for edge_hash in [*blocks_by_edge_hash.keys()]:
if len(blocks_by_edge_hash[edge_hash]) == 1:
del blocks_by_edge_hash[edge_hash]
else:
block_list = blocks_by_edge_hash[edge_hash]
blocks_by_edge_hash[edge_hash] = sorted(
block_list, key=lambda x: x.layer_range.start)
for type_id, block_list in enumerate(blocks_by_edge_hash.values()):
for block in block_list:
block.type_id = type_id
return blocks_by_edge_hash
def get_layer_to_block(self):
layer_to_block = {}
for block_list in self.blocks_by_edge_hash.values():
for block in block_list:
for layer_index in block.layer_range:
layer_to_block[layer_index] = block
return layer_to_block
def clean_blocks(self):
for block in self.blocks:
block.p2p_type = None
block.is_superset = False
block.is_subset = False
def mark_p2p_type(self, phy_mesh, stage_phy_meshes,
graph_config: GraphConfig):
if len(self.backbone_blocks) == 0 or len(stage_phy_meshes) == 1:
return
assert len(self.backbone_blocks) % len(stage_phy_meshes) == 0
block_per_stage = len(self.backbone_blocks) // len(stage_phy_meshes)
for block in self.backbone_blocks:
block.p2p_type = None
for stage_index, stage_phy_mesh in enumerate(stage_phy_meshes[:-1]):
next_stage_phy_mesh = stage_phy_meshes[stage_index + 1]
last_device_id = stage_phy_mesh.phy_devices_id.flatten()[-1]
next_first_device_id = next_stage_phy_mesh.phy_devices_id.flatten(
)[0]
num_devices_per_host = phy_mesh.num_devices_per_host
next_block = self.backbone_blocks[(stage_index + 1) *
block_per_stage]
if last_device_id // num_devices_per_host != next_first_device_id // num_devices_per_host:
next_block.p2p_type = P2PType.CROSS_HOST
graph_config.has_cross_host = True
else:
next_block.p2p_type = P2PType.CROSS_DEVICE
graph_config.has_cross_device = True
def get_graph_mapping(self):
layer_mapping = {}
block_mapping = {}
p2p_types = {}
p2p_tensors = {}
for block_list in self.blocks_by_edge_hash.values():
superset_blocks = []
superset_block_index = {}
for block in block_list:
block_added = False
for index, superset_block in enumerate(list(superset_blocks)):
if block.p2p_type == superset_block.p2p_type:
if block.relative_inter_edges.issubset(
superset_block.relative_inter_edges):
block.is_subset = True
block.is_superset = False
superset_block_index[id(block)] = index
block_added = True
break
elif superset_block.relative_inter_edges.issubset(
block.relative_inter_edges):
superset_block.is_subset = True
superset_block.is_superset = False
block.is_subset = False
block.is_superset = True
superset_blocks[index] = block
superset_block_index[id(block)] = index
block_added = True
break
if not block_added:
block.is_subset = False
block.is_superset = True
superset_blocks.append(block)
superset_block_index[id(block)] = len(superset_blocks) - 1
for block in block_list:
assert not (block.is_subset and block.is_superset)
if block.is_subset:
superset_block = superset_blocks[superset_block_index[id(
block)]]
block_mapping[block.block_id] = superset_block.block_id
owned_inputs = map(
lambda x: x[0],
sorted(block.owned_inputs.items(), key=lambda x: x[1]))
superset_owned_inputs = map(
lambda x: x[0],
sorted(superset_block.owned_inputs.items(),
key=lambda x: x[1]))
for from_input_id, to_input_id in zip(
owned_inputs, superset_owned_inputs):
from_input_name = self.network.get_input(
from_input_id).name
to_input_name = self.network.get_input(to_input_id).name
layer_mapping[from_input_name] = to_input_name
for from_layer_id, to_layer_id in zip(
block.layer_range, superset_block.layer_range):
from_layer = self.network.get_layer(from_layer_id)
to_layer = self.network.get_layer(to_layer_id)
layer_mapping[from_layer.name] = to_layer.name
for i in range(from_layer.num_outputs):
from_output = from_layer.get_output(i)
if from_output.is_network_output:
to_output = to_layer.get_output(i)
layer_mapping[from_output.name] = to_output.name
if block.p2p_type is not None:
p2p_types[block.block_id] = block.p2p_type
p2p_tensors[block.block_id] = [
*set(block.get_input_names())
]
for from_name, to_name in zip(
block.get_input_names(),
superset_block.get_input_names()):
layer_mapping[
f"p2p_block{block.block_id}_{from_name}"] = f"p2p_block{superset_block.block_id}_{to_name}"
stage_id = 0
block_to_stage = {}
for block in self.blocks:
if block.p2p_type is not None:
stage_id += 1
block_to_stage[block.block_id] = stage_id
return GraphMapping(
layer_mapping,
block_mapping,
p2p_types,
p2p_tensors,
block_to_stage,
)
def create_simplified_graph(self, graph_config: GraphConfig):
new_graph = PipelineGraph.create_graph()
new_graph._io_buffer_mapping = self.graph._io_buffer_mapping
layer_mapping = graph_config.graph_mapping.layer_mapping
for i in range(self.network.num_inputs):
trt_input = self.network.get_input(i)
if trt_input.name not in layer_mapping:
new_graph.add_input(trt_input)
last_blocks = {}
same_spec_mapping = {}
same_spec_layer_mapping = {}
shape_mapping = {}
building_block_id = 0
same_spec_ids = {}
same_spec_count = 0
for block in self.blocks:
if not block.is_subset:
stage_type = None
if not block.is_superset:
if block.block_id == 0:
stage_type = StageType.START
elif block.block_id == len(self.blocks) - 1:
stage_type = StageType.END
input_mapping = block.get_input_mapping(last_blocks)
for from_name, to_name in [*input_mapping.items()]:
if to_name in same_spec_mapping:
input_mapping[from_name] = same_spec_mapping[to_name]
if to_name in layer_mapping:
input_mapping[from_name] = layer_mapping[to_name]
if block.is_superset and block.p2p_type is not None:
for from_name, to_name in [*input_mapping.items()]:
output_tensor = new_graph.get_tensor(to_name)
p2p_layer = new_graph.as_trt().add_identity(
output_tensor.as_trt())
p2p_layer.name = f"p2p_block{block.block_id}_{from_name}"
p2p_layer.metadata = p2p_layer.name
p2p_tensor = p2p_layer.get_output(0)
p2p_tensor.name = f"{p2p_layer.name}_output"
wrapped_layer = new_graph.register_layer(p2p_layer)
wrapped_layer.attrs[
"building_block_id"] = building_block_id
wrapped_layer.attrs["p2p_type"] = block.p2p_type
input_mapping[from_name] = p2p_tensor.name
shape_mapping[p2p_tensor.name] = from_name
building_block_id += 1
for i in block.sorted_layer_ids:
layer = self.network.get_layer(i)
wrapped_layer = new_graph.add_layer(
layer,
input_mapping=input_mapping,
)
wrapped_layer.attrs["building_block_id"] = building_block_id
wrapped_layer.attrs["stage_type"] = stage_type
if block.is_superset:
last_blocks[block.type_id] = block
if block.type_id in same_spec_ids:
same_spec_id = same_spec_ids[block.type_id]
update_same_spec_count = False
else:
same_spec_id = same_spec_count
same_spec_ids[block.type_id] = same_spec_id
update_same_spec_count = True
count = same_spec_id
for i, (layer_offset,
output_index) in enumerate(block.outputs):
layer = self.network.get_layer(block.layer_range.start +
layer_offset)
tensor_name = layer.get_output(output_index).name
output_tensor = new_graph.get_tensor(tensor_name)
same_spec_layer = new_graph.as_trt().add_identity(
output_tensor.as_trt())
same_spec_layer.name = f"{tensor_name}_same_spec"
same_spec_layer.metadata = same_spec_layer.name
same_spec_tensor = same_spec_layer.get_output(0)
same_spec_tensor.name = f"{same_spec_layer.name}_output"
wrapped_layer = new_graph.register_layer(
same_spec_layer)
wrapped_layer.attrs[
"building_block_id"] = building_block_id
wrapped_layer.attrs["same_spec_id"] = count
count += 1
same_spec_mapping[tensor_name] = same_spec_tensor.name
same_spec_layer_mapping[
same_spec_layer.name] = layer.name
shape_mapping[same_spec_tensor.name] = tensor_name
for i, graph_input_index in enumerate(
block.owned_inputs.keys()):
input_name = self.network.get_input(
graph_input_index).name
input_tensor = new_graph.get_input(input_name)
input_tensor.attrs["same_spec_id"] = count
count += 1
if update_same_spec_count:
same_spec_count = count
building_block_id += 1
graph_config.graph_mapping.same_spec_layer_mapping = same_spec_layer_mapping
if len(self.backbone_blocks) >= 2:
start_block = self.backbone_blocks[0]
if start_block.is_subset:
start_block = self.blocks[graph_config.graph_mapping.
block_mapping[start_block.block_id]]
for i in start_block.layer_range:
layer_name = self.network.get_layer(i).name
layer = new_graph.get_layer(layer_name)
layer.attrs["in_start_block"] = True
end_block = self.backbone_blocks[-1]
if end_block.is_subset:
end_block = self.blocks[graph_config.graph_mapping.
block_mapping[end_block.block_id]]
for i in end_block.layer_range:
layer_name = self.network.get_layer(i).name
layer = new_graph.get_layer(layer_name)
layer.attrs["in_end_block"] = True
slowest_p2p_type = None
if graph_config.has_cross_host:
slowest_p2p_type = P2PType.CROSS_HOST
elif graph_config.has_cross_device:
slowest_p2p_type = P2PType.CROSS_DEVICE
if slowest_p2p_type is not None:
for block in self.blocks:
if block.is_superset and block.p2p_type == slowest_p2p_type:
for i in block.layer_range:
layer_name = self.network.get_layer(i).name
layer = new_graph.get_layer(layer_name)
layer.attrs["in_slowest_block"] = True
for i in range(self.network.num_outputs):
trt_output = self.network.get_output(i)
output = self.graph.get_output(trt_output.name)
if output.producer is not None and output.producer.index in self.layer_to_block and self.layer_to_block[
output.producer.index].is_subset:
continue
if trt_output.is_shape_tensor:
new_output = new_graph.add_output_shape(trt_output)
else:
new_output = new_graph.add_output(trt_output)
sharded_io = False
for pattern in self.sharded_io_allowlist:
if re.match(pattern, new_output.name):
sharded_io = True
break
if not sharded_io:
new_output.producer.attrs["is_replicated"] = True
for input in new_graph.inputs:
input_name = input.name
sharded_io = False
for pattern in self.sharded_io_allowlist:
if re.match(pattern, input_name):
sharded_io = True
break
if not sharded_io:
input.attrs["is_replicated"] = True
for pattern, repl in self.same_spec_io.items():
if re.match(pattern, input_name):
output_name = re.sub(pattern, repl, input_name)
output = new_graph.get_output(output_name)
if output is not None:
if "same_spec_id" in input.attrs:
same_spec_id = input.attrs["same_spec_id"]
else:
same_spec_id = same_spec_count
same_spec_count += 1
input.attrs["same_spec_id"] = same_spec_id
output.attrs["same_spec_id"] = same_spec_id
if math.prod(self.graph.get_input(
input_name).shape) < math.prod(
self.graph.get_output(output_name).shape):
input.attrs["no_memory_footprint"] = True
else:
output.attrs["no_memory_footprint"] = True
return new_graph, shape_mapping
def enrich_shape_info(self, shape_mapping):
shapes = self.shape_info.shapes.copy()
max_shapes = self.shape_info.max_shapes.copy()
values = self.shape_info.values.copy()
shape_layers = self.shape_info.shape_layers
for from_name, to_name in shape_mapping.items():
if to_name in shapes:
shapes[from_name] = shapes[to_name]
if to_name in max_shapes:
max_shapes[from_name] = max_shapes[to_name]
if to_name in values:
values[from_name] = values[to_name]
shape_info = ShapeInfo(shapes, values, shape_layers, max_shapes)
return shape_info
def simplify_graph(
self, phy_mesh: PhysicalDeviceMesh, num_stages: int,
num_devices_per_stage: int) -> Tuple[PipelineGraph, GraphConfig]:
num_blocks = len(self.backbone_blocks)
if num_blocks % num_stages != 0:
return None, None
graph_config = GraphConfig()
graph_config.num_micro_batches = self.num_micro_batches
graph_config.num_blocks = num_blocks
graph_config.num_stages = num_stages
graph_config.phy_mesh = phy_mesh
stage_phy_meshes = phy_mesh.split_pipeline_meshes(
num_stages, num_devices_per_stage)
graph_config.stage_phy_meshes = stage_phy_meshes
with silent_trt_logger():
self.clean_blocks()
self.mark_p2p_type(phy_mesh, stage_phy_meshes, graph_config)
graph_config.graph_mapping = self.get_graph_mapping()
new_graph, shape_mapping = self.create_simplified_graph(
graph_config)
shape_info = self.enrich_shape_info(shape_mapping)
new_graph.assign_shapes(shape_info)
return new_graph, graph_config
def get_graph_mapping_for_shape(self):
layer_mapping = {}
tensor_mapping = {}
for block_list in self.blocks_by_edge_hash.values():
head_block = block_list[0]
for block in block_list[1:]:
for from_layer_id, to_layer_id in zip(block.layer_range,
head_block.layer_range):
from_layer = self.network.get_layer(from_layer_id)
to_layer = self.network.get_layer(to_layer_id)
layer_mapping[from_layer.name] = to_layer.name
for i in range(from_layer.num_outputs):
tensor_mapping[from_layer.get_output(
i).name] = to_layer.get_output(i).name
return layer_mapping, tensor_mapping
def create_simplified_graph_for_shape(self):
new_graph = PipelineGraph.create_graph()
for i in range(self.network.num_inputs):
trt_input = self.network.get_input(i)
new_graph.add_input(trt_input)
head_blocks = {}
removed_blocks = set()
removed_layers = set()
for block_list in self.blocks_by_edge_hash.values():
head_block = block_list[0]
head_blocks[head_block.type_id] = head_block
for block in block_list[1:]:
removed_blocks.add(id(block))
for layer_index in block.layer_range:
removed_layers.add(layer_index)
for block in self.blocks:
if not id(block) in removed_blocks:
input_mapping = block.get_input_mapping(head_blocks)
for i in block.sorted_layer_ids:
layer = self.network.get_layer(i)
new_graph.add_layer(
layer,
input_mapping=input_mapping,
)
for i in range(self.network.num_outputs):
trt_output = self.network.get_output(i)
output = self.graph.get_output(trt_output.name)
if output.producer is not None and output.producer.index in removed_layers:
continue
if trt_output.is_shape_tensor:
new_graph.add_output_shape(trt_output)
else:
new_graph.add_output(trt_output)
return new_graph
def get_full_shape_info(self, num_micro_batches):
layer_mapping, tensor_mapping = self.graph_mapping_for_shape
optimization_profiles = self.llm_network._generate_optimization_profiles(
)
if len(optimization_profiles) > 0:
optimization_profile = optimization_profiles[-1]
else:
optimization_profile = None
shape_info = get_shape_info(self.graph_for_shape.as_trt(),
optimization_profile)
max_shape_info = get_shape_info(self.graph_for_shape.as_trt(),
optimization_profile,
shape_type=ShapeType.MAX)
shape_info.max_shapes = max_shape_info.shapes
for removed_tensor_name, tensor_name in tensor_mapping.items():
shape_info.shapes[removed_tensor_name] = shape_info.shapes[
tensor_name]
shape_info.max_shapes[removed_tensor_name] = shape_info.max_shapes[
tensor_name]
if tensor_name in shape_info.values:
shape_info.values[removed_tensor_name] = shape_info.values[
tensor_name]
for removed_layer_name, layer_name in layer_mapping.items():
if layer_name in shape_info.shape_layers:
shape_info.shape_layers.add(removed_layer_name)
return shape_info
def init_layer_hash(self):
with silent_trt_logger():
optimization_profiles = self.llm_network._generate_optimization_profiles(
)
if len(optimization_profiles) > 0:
optimization_profile = optimization_profiles[-1]
else:
optimization_profile = None
shape_info = get_shape_info(self.network, optimization_profile)
dtypes = {tensor.name: tensor.dtype for tensor in self.graph.tensors}
for layer in self.graph.layers:
layer_hash = get_cache_key(
layer.as_trt(),
shape_info.shapes,
shape_info.values,
dtypes,
)
layer.attrs["hash"] = layer_hash

View File

@ -1,641 +0,0 @@
"""This code is adapted from Alpa https://github.com/alpa-projects/alpa/ with some changes.
"""
import multiprocessing
import time
import warnings
from collections import defaultdict
import numpy as np
import pulp
from pulp import LpMinimize, LpProblem, LpVariable, lpDot, lpSum
from ..logger import logger
class Solution:
def __init__(self, leaf_strategies, s_val, e_val, edge_pairs,
node_index_dict, total_cost):
self.leaf_strategies = leaf_strategies
self.nodes = [
strategies_vector.node for strategies_vector in self.leaf_strategies
]
self.s_val = s_val
self.e_val = e_val
self.total_cost = total_cost
self.edge_pairs = list(np.reshape(edge_pairs, (-1, 2)))
self.node_index_dict = node_index_dict
self.index_node_dict = {}
for node, index in self.node_index_dict.items():
self.index_node_dict[index] = node
self.node_best_strategy = {}
self._annotate_strategy()
def _annotate_strategy(self):
self.node_best_strategy = {}
for index, node in enumerate(self.nodes):
best_strategy_id = self.s_val[index]
best_strategy = self.leaf_strategies[index][best_strategy_id]
self.node_best_strategy[node.node_name] = best_strategy
for edge_idx, edge_pair in enumerate(self.edge_pairs):
src_node = self.index_node_dict[edge_pair[0]]
dst_node = self.index_node_dict[edge_pair[1]]
src_node_index = self.node_index_dict[src_node]
for dst_pre_node in dst_node.predecessor_nodes:
if dst_pre_node is None:
continue
if src_node.node_name == dst_pre_node.node_name:
self.node_best_strategy[
dst_node.node_name].best_resharding_cost[
src_node.node_name] = [
self.node_best_strategy[dst_node.node_name].
resharding_costs[src_node.node_name][
self.s_val[src_node_index]]
]
def print_solution(self):
for index, node in enumerate(self.nodes):
best_strategy = self.node_best_strategy[node.node_name]
print(f'\n[{index}]: node_name = {node.node_name}')
best_strategy.print_strategy(best_resharding_cost_only=True)
print(f'solution total cost = {self.total_cost}')
class CostGraph:
'''
    A graph data structure that simplifies the edge cost graph. It has two main functions:
    1. To feed the quadratic resharding costs into the solver, we need to linearize them. We build edge_cost in
    CostGraph, which stores every combination of strategies for a src-dst node pair in a 1D list.
    2. To reduce the search space, we merge computationally trivial operators, such as
    element-wise operators, transpose, and reduction, into their following nodes. The merging information is
    given by the StrategiesVector, depending on the types of the target node and its following nodes.
    Arguments:
    leaf_strategies(List[StrategiesVector]): Stores the StrategiesVector of every node in the graph.
    simplify(bool, optional): The generated cost graph is simplified if True. (defaults to True)
'''
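    # Illustrative example (hypothetical sizes): for a src node with 2 strategies and
    # a dst node with 3, _build_cost_graph stores edge_cost[(j, i)] for j in range(2)
    # and i in range(3), i.e. all 6 strategy pairs; Solver._prepare_data_for_solver
    # later flattens these entries into the 1D resharding-cost list that is paired
    # with the binary edge variables e[idx] in the ILP.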
def __init__(self, leaf_strategies):
self.leaf_strategies = leaf_strategies
self.nodes = [
strategies_vector.node for strategies_vector in leaf_strategies
]
        # stores the StrategiesVector of each node
self.node_strategies_vector = {}
for node, strategies_vector in zip(self.nodes, self.leaf_strategies):
self.node_strategies_vector[node] = strategies_vector
# extra_node_costs will store the extra costs introduced by merging nodes
self.extra_node_costs = {}
self.following_dict = {}
self._build_cost_graph()
def _remove_invalid_node(self, node, attr_name):
remove_list = []
target_node_list = getattr(node, attr_name, [])
for target_node in target_node_list:
if target_node not in self.nodes:
remove_list.append(target_node)
for element in remove_list:
target_node_list.remove(element)
def _build_cost_graph(self):
'''
        This method generates edge_cost for each adjacent node pair. Additionally, the 'parents' and 'children'
        attributes are set on each node.
'''
self.edge_costs = {}
for dst_node, strategies_vector in zip(self.nodes,
self.leaf_strategies):
# build edge_cost
for src_node in dst_node.predecessor_nodes:
if src_node is None:
continue
if src_node not in self.nodes:
continue
node_pair = (src_node, dst_node)
edge_cost = {}
for i in range(len(strategies_vector)):
for j in range(len(self.node_strategies_vector[src_node])):
resharding_cost = strategies_vector[i].resharding_costs[
src_node.node_name][j][-1]
edge_cost[(j, i)] = resharding_cost
self.edge_costs[node_pair] = edge_cost
def get_edge_cost(self, src_node, dst_node):
return self.edge_costs[(src_node, dst_node)]
class Solver:
INFINITY_COST = 1e13
def __init__(self,
cost_graph: CostGraph,
memory_budget: float = -1.0,
solution_numbers: int = 1,
memory_increasing_coefficient: float = 1.3,
verbose=False):
'''
        The Solver class integrates information provided by the components and uses an ILP solver to find a possibly-optimal strategy combination for the target computing graph.
        Arguments:
        graph: The computing graph to be optimized.
        strategies_constructor: Provides all the possible strategies for each node in the computing graph.
        cost_graph: A graph data structure that simplifies the edge cost graph.
        graph_analyser: Analyses the graph to obtain the variable liveness information, which is used to generate memory constraints.
        memory_budget: Memory constraint for the solution.
        solution_numbers: If larger than one, the solver will produce a series of solutions based on different memory budgets.
        memory_increasing_coefficient: If solution_numbers is larger than one, this coefficient is used to generate each new memory budget.
'''
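        # Overview of the ILP formulated below: each node i gets a binary one-hot
        # vector s[i] that selects one strategy (lpSum(s[i]) == 1); each edge (i, j)
        # gets a binary vector e[idx] of length len(s[i]) * len(s[j]) that selects a
        # strategy pair, tied to s[i] and s[j] by row/column constraints; the
        # objective minimizes the weighted sharding, communication, and resharding
        # costs, optionally under a memory budget constraint.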
self.cost_graph = cost_graph
self.leaf_strategies = cost_graph.leaf_strategies
self.nodes = cost_graph.nodes
self.memory_budget = memory_budget
self.solution_numbers = solution_numbers
if self.solution_numbers > 1:
self.memory_increasing_coefficient = memory_increasing_coefficient
else:
self.memory_increasing_coefficient = 1
        # Temporarily we use all nodes as the liveness list; the backward memory cost is counted together with
        # the forward memory cost in the node memory cost, and no activation checkpointing is used in this phase.
# self.liveness_list = self.graph_analyser.liveness_analysis()
self.liveness_list = self.nodes
self.node_index_dict = self._generate_node_index_dict()
# The last solution vector of auto sharding.
self.last_s_val = None
# The last objective value of the best ILP solution.
self.last_objective = None
self.verbose = verbose
def _generate_node_index_dict(self):
node_index_dict = {}
for index, node in enumerate(self.nodes):
node_index_dict[node] = index
return node_index_dict
def _prepare_data_for_solver(self):
'''
Extract information from components for solver.
'''
node_nums = len(self.leaf_strategies)
memory_budget = self.memory_budget
# prepare strategies_len
strategies_len = []
for node in self.nodes:
strategies_len.append(
len(self.cost_graph.node_strategies_vector[node]))
strategies_len = np.array(strategies_len)
# prepare edge_pairs and resharding costs
edge_pairs = []
resharding_costs = []
edge_cost_level = []
edge_resharding_weights = []
for pairs, edge_cost in self.cost_graph.edge_costs.items():
src_node = pairs[0]
dst_node = pairs[1]
src_node_index = self.node_index_dict[src_node]
dst_node_index = self.node_index_dict[dst_node]
edge_pairs.append(src_node_index)
edge_pairs.append(dst_node_index)
edge_cost_level.append(
(dst_node.building_block_id, dst_node.cost_level))
for i in range(strategies_len[src_node_index]):
for j in range(strategies_len[dst_node_index]):
resharding_costs.append(edge_cost[(i, j)])
edge_resharding_weights.append(dst_node.resharding_weight +
dst_node.pipeline_weight)
edge_pairs = np.array(edge_pairs)
resharding_costs = np.array(resharding_costs)
edge_resharding_weights = np.array(edge_resharding_weights)
# prepare compute_costs, communication_costs and memory_costs
compute_costs = []
communication_costs = []
memory_costs = []
peak_act_memory_costs, constant_memory_costs = [], []
node_sharding_weights = []
for node, strategies_vector in zip(self.nodes, self.leaf_strategies):
for index, strategy in enumerate(strategies_vector):
compute_cost = strategy.sharding_cost
origin_communication_cost = strategy.communication_cost
memory_cost = strategy.const_memory_footprint * node.sharding_weight
peak_act_memory = strategy.peak_memory_footprint
# extract the memory cost in float from MemoryCost item and sum them up
compute_costs.append(compute_cost)
                # a node that appears in extra_node_costs has some extra communication
                # cost introduced by node merging, so that extra communication cost
                # needs to be added to the total communication cost
communication_costs.append(origin_communication_cost)
peak_act_memory_costs.append(peak_act_memory)
constant_memory_costs.append(memory_cost)
node_sharding_weights.append(node.sharding_weight +
node.pipeline_weight)
compute_costs = np.array(compute_costs)
communication_costs = np.array(communication_costs)
memory_costs = np.array([constant_memory_costs, peak_act_memory_costs])
node_sharding_weights = np.array(node_sharding_weights)
same_spec_nodes_dict = defaultdict(list)
node_cost_level = []
for idx, node in enumerate(self.nodes):
if node.same_spec_id >= 0:
same_spec_nodes_dict[node.same_spec_id].append(idx)
node_cost_level.append((node.building_block_id, node.cost_level))
# omit initial value for nodes
s_init_np = None
following_nodes = [-1 for i in range(node_nums)]
liveness_set = self.nodes
alias_set = []
alias_convert_costs = None
return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, node_sharding_weights, edge_resharding_weights, same_spec_nodes_dict, node_cost_level, edge_cost_level, alias_convert_costs, s_init_np, self.verbose
def _call_solver_serialized_args(self,
node_nums,
memory_budget,
strategies_len,
following_nodes,
edge_pairs,
alias_set,
liveness_set,
compute_costs,
communication_costs,
memory_costs,
resharding_costs,
node_sharding_weights,
edge_resharding_weights,
same_spec_nodes_dict,
node_cost_level,
edge_cost_level,
alias_convert_costs,
s_init_np=None,
verbose=True):
"""
Call the solver with serialized arguments.
"""
time.time()
for x in [
strategies_len, edge_pairs, compute_costs, communication_costs,
memory_costs, resharding_costs, node_sharding_weights,
edge_resharding_weights
]:
assert isinstance(x, np.ndarray)
assert len(strategies_len) == node_nums, "strategies_len"
def get_non_zero_index(binary_vector):
"""
            Get the index of the non-zero item in a vector.
"""
ct = 0
ret = None
for i, elem in enumerate(binary_vector):
if pulp.value(elem):
ret = i
ct += 1
assert ct == 1
return ret
        # 0. Unpack flattened numpy arrays
s_follow = following_nodes
s_alias = alias_set
E = edge_pairs.reshape((-1, 2)) # noqa
r = []
pt = 0
edge_set = set()
for (i, j) in E:
prod_length = strategies_len[i] * strategies_len[j]
if (i, j) in edge_set:
raise ValueError(f"Duplicated edges: {(i, j)}")
edge_set.add((i, j))
r.append(resharding_costs[pt:pt + prod_length])
pt += prod_length
assert pt == len(resharding_costs)
######################
# omit alias set now #
######################
# A = alias_set.reshape((-1, 2)) # noqa
# for (i, j) in A:
# prod_length = strategies_len[i] * strategies_len[j]
# v.append(alias_convert_costs[pt:pt + prod_length])
# pt += prod_length
# assert pt == len(alias_convert_costs)
# L = [] # noqa
# pt = node_nums
# for i in range(node_nums):
# length = liveness_set[i]
# L.append(liveness_set[pt:pt + length])
# pt += length
# assert pt == len(liveness_set)
pt = 0
c = []
d = []
m = []
peak_m = []
pt = 0
for i in range(node_nums):
length = strategies_len[i]
c.append(compute_costs[pt:pt + length])
d.append(communication_costs[pt:pt + length])
m.append(memory_costs[0][pt:pt + length])
peak_m.append(memory_costs[1][pt:pt + length])
pt += length
assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
assert pt == len(
communication_costs), f"{pt} == {len(communication_costs)}"
assert pt == len(memory_costs[0]), f"{pt} == {len(memory_costs[0])}"
# 1. Create variables
#############################
# create variables for node #
#############################
s = []
num_nodes = 0
reverse_follow_backpatch = []
for i in range(node_nums):
if s_follow[i] < 0:
if strategies_len[i] == 1:
s.append([1])
else:
if i not in s_alias:
num_nodes += 1
s.append(
LpVariable.matrix(f"s[{i}]",
(range(strategies_len[i]), ),
cat="Binary"))
else:
s.append(s[s_alias[i]])
else:
if s_follow[i] < len(s):
s.append(s[s_follow[i]])
else:
s.append(None)
reverse_follow_backpatch.append(i)
for i in reverse_follow_backpatch:
s[i] = s[s_follow[i]]
#############################
# create variables for edge #
#############################
e = []
num_edges = 0
map_edge_to_idx = {}
for (idx, (i, j)) in enumerate(E):
if len(s[i]) == 1:
e.append(s[j])
elif len(s[j]) == 1:
e.append(s[i])
else:
if i in s_alias and j in s_alias and (
s_alias[i], s_alias[j]) in map_edge_to_idx:
e.append(e[map_edge_to_idx[(s_alias[i], s_alias[j])]])
else:
num_edges += 1
e.append(
LpVariable.matrix(f"e[{i},{j}]",
(range(len(s[i]) * len(s[j])), ),
cat="Binary"))
assert len(e[idx]) == len(r[idx])
map_edge_to_idx[(i, j)] = idx
for element in s:
assert len(element) > 0
# 2. Set initial value
        #######################################
        # set an initial value for warm start #
        #######################################
if s_init_np is not None:
s_init = s_init_np.reshape((-1, 3))
for (idx, value, fix) in s_init:
for i in range(len(s[idx])):
s[idx][i].setInitialValue(i == value)
if fix:
s[idx][i].fixValue()
# 3. Objective
prob = LpProblem("myProblem", LpMinimize)
###################################################################
# computing the node cost(computing cost and communication cost) #
###################################################################
obj = 0
block_cost_level_dict = {}
for i in range(node_nums):
assert len(s[i]) == len(c[i])
assert len(s[i]) == len(d[i])
obj += (lpDot(s[i], c[i]) +
lpDot(s[i], d[i])) * node_sharding_weights[i]
cost_level = node_cost_level[i]
if -1 != cost_level[1]:
if cost_level in block_cost_level_dict:
block_cost_level_dict[cost_level] += lpDot(
s[i], c[i]) + lpDot(s[i], d[i])
else:
block_cost_level_dict[cost_level] = lpDot(
s[i], c[i]) + lpDot(s[i], d[i])
#############################################
# computing the edge cost(resharding cost) #
#############################################
for i in range(len(E)):
assert len(e[i]) == len(r[i])
obj += lpDot(e[i], r[i]) * edge_resharding_weights[i]
cost_level = edge_cost_level[i]
if -1 != cost_level[1]:
if cost_level in block_cost_level_dict:
block_cost_level_dict[cost_level] += lpDot(e[i], r[i])
else:
block_cost_level_dict[cost_level] = lpDot(e[i], r[i])
prob += obj
if len(block_cost_level_dict) >= 2:
block_cost_levels = [key for key in block_cost_level_dict.keys()]
for i in range(len(block_cost_levels)):
for j in range(i + 1, len(block_cost_levels)):
if block_cost_levels[i][1] > block_cost_levels[j][1]:
prob += block_cost_level_dict[
block_cost_levels[i]] >= block_cost_level_dict[
block_cost_levels[j]] + 1e-6
elif block_cost_levels[i][1] < block_cost_levels[j][1]:
prob += block_cost_level_dict[
block_cost_levels[j]] >= block_cost_level_dict[
block_cost_levels[i]] + 1e-6
# 4. Constraints
# (a). specified by `cat="Binary"`
# (b)
#################################################
        # make sure each node only chooses one strategy #
#################################################
for i in range(node_nums):
if s_follow[i] < 0:
prob += lpSum(s[i]) == 1
# (c)
#################################################
        # force some nodes to have the same sharding specs #
#################################################
for spec_id, same_spec_nodes_id in same_spec_nodes_dict.items():
num_same_spec_nodes = len(same_spec_nodes_id)
if num_same_spec_nodes >= 2:
src_node_s = s[same_spec_nodes_id[0]]
num_specs = len(src_node_s)
for i in range(1, num_same_spec_nodes):
dst_node_s = s[same_spec_nodes_id[i]]
                assert len(
                    dst_node_s
                ) == num_specs, f'unmatched num_specs when forcing nodes {same_spec_nodes_id[0]} and {same_spec_nodes_id[i]} to share the same specs'
for j in range(num_specs):
prob += (src_node_s[j] == dst_node_s[j])
# (c)
#################################################
# compute memory consumption with liveness set #
#################################################
if memory_budget > 0:
# calculate the constant memory
mem = 0
for node in liveness_set:
if node not in self.node_index_dict:
continue
node_index = self.node_index_dict[node]
mem += lpSum(s[node_index][j] * m[node_index][j]
for j in range(len(s[node_index])))
# calculate the peak activation memory
for node in liveness_set:
if node not in self.node_index_dict:
continue
node_index = self.node_index_dict[node]
cur_peak_mem = lpSum(s[node_index][j] * peak_m[node_index][j]
for j in range(len(s[node_index])))
total_mem = mem + cur_peak_mem
prob += total_mem <= memory_budget
# (d). specified by `cat="Binary"`
for (idx, (i, j)) in enumerate(E):
if strategies_len[i] == 1 or strategies_len[j] == 1:
continue
# (e)
prob += lpSum(e[idx]) == 1
# (f)
for row in range(len(s[i])):
C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col]
for col in range(0, C)) <= s[i][row]
# (g)
for col in range(len(s[j])):
R = len(s[i]) # noqa
C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col]
for row in range(0, R)) <= s[j][col]
if prob.objective.isNumericalConstant():
objective = float(pulp.value(prob.objective))
status = pulp.LpStatusOptimal
else:
msg = verbose
time_limit = 600
solver = pulp.PULP_CBC_CMD(
mip=True,
msg=msg,
timeLimit=time_limit,
threads=multiprocessing.cpu_count(),
)
prob.solve(solver)
status = prob.status
objective = pulp.value(prob.objective)
objective = float(
objective) if objective is not None else self.INFINITY_COST
if prob.status in [pulp.LpStatusInfeasible]:
objective = self.INFINITY_COST
# Get and check results
s_val = np.full((node_nums, ), -1, dtype=np.int32)
for i in range(node_nums):
s_val[i] = get_non_zero_index(s[i])
e_val = np.full((len(E), ), -1, dtype=np.int32)
for (idx, (i, j)) in enumerate(E):
e_val[idx] = get_non_zero_index(e[idx])
i_spec_index = e_val[idx] // len(s[j])
j_spec_index = e_val[idx] % len(s[j])
assert i_spec_index == s_val[i], f"e_val[{i}][{j}]"
assert j_spec_index == s_val[j], f"e_val[{i}][{j}]"
if verbose and r[idx][e_val[idx]] > 0:
print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")
self.last_s_val = list(s_val)
# self._recover_merged_node_strategy()
self.last_objective = objective
if objective >= self.INFINITY_COST:
            warnings.warn(
                f"Cannot find an optimized solution given memory budget {self.memory_budget}. Please consider:\n" + \
                "1. increasing the memory budget if possible\n" + \
                "2. enlarging the mesh shape if possible\n" + \
                "3. decreasing the maximum parameters (e.g., max_batch_size, max_seq_len) in the build config")
if memory_budget > 0:
# calculate the constant memory
mem = 0
for node in liveness_set:
if node not in self.node_index_dict:
continue
node_index = self.node_index_dict[node]
j = self.last_s_val[node_index]
mem += m[node_index][j]
max_peak_mem = 0
for node in liveness_set:
if node not in self.node_index_dict:
continue
node_index = self.node_index_dict[node]
j = self.last_s_val[node_index]
cur_peak_mem = peak_m[node_index][j]
max_peak_mem = max(max_peak_mem, cur_peak_mem)
logger.debug(
f'constant_mem = {mem}, peak_mem = {max_peak_mem}, memory_budget = {memory_budget}'
)
solution = Solution(self.leaf_strategies, self.last_s_val, e_val,
edge_pairs, self.node_index_dict,
self.last_objective)
return status, solution
def find_solution(self):
"""
        Call the solver with serialized arguments and handle Python errors. Additionally,
        we can produce a series of solutions under different memory budgets.
"""
if self.solution_numbers == 1:
args = self._prepare_data_for_solver()
ret = self._call_solver_serialized_args(*args)
return ret
origin_memory_budget = self.memory_budget
memory_budget_list = [
origin_memory_budget * self.memory_increasing_coefficient**i
for i in range(self.solution_numbers)
]
ret_list = []
for memory_budget in memory_budget_list:
self.memory_budget = memory_budget
args = self._prepare_data_for_solver()
ret = self._call_solver_serialized_args(*args)
ret_list.append(ret)
return ret_list
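
A note on the formulation above: it follows the Alpa-style ILP, with one binary vector s[i] per node (one entry per candidate sharding strategy), one binary vector e[i,j] per edge (one entry per strategy pair), a linear objective over sharding/resharding costs, and one-hot plus row/column consistency constraints. Below is a minimal, self-contained sketch of that pattern with PuLP; the toy costs and names are chosen here for illustration only, not taken from the real node/edge data.

from pulp import (LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum,
                  PULP_CBC_CMD, value)

# Toy instance: 2 nodes with 2 candidate strategies each, 1 edge between them.
node_costs = [[1.0, 3.0], [2.0, 0.5]]   # sharding cost per strategy
edge_costs = [[0.0, 5.0, 5.0, 0.0]]     # resharding cost per (row, col) strategy pair

prob = LpProblem("toy_sharding_ilp", LpMinimize)
s = [LpVariable.matrix(f"s[{i}]", (range(2), ), cat="Binary") for i in range(2)]
e = [LpVariable.matrix("e[0,1]", (range(4), ), cat="Binary")]

# Objective: node sharding costs plus edge resharding costs.
prob += (lpDot(s[0], node_costs[0]) + lpDot(s[1], node_costs[1]) +
         lpDot(e[0], edge_costs[0]))

# Each node picks exactly one strategy; each edge picks exactly one strategy pair.
prob += lpSum(s[0]) == 1
prob += lpSum(s[1]) == 1
prob += lpSum(e[0]) == 1

# The edge choice must agree with both endpoint choices (constraints (f)/(g) above).
for row in range(2):
    prob += lpSum(e[0][row * 2 + col] for col in range(2)) <= s[0][row]
for col in range(2):
    prob += lpSum(e[0][row * 2 + col] for row in range(2)) <= s[1][col]

prob.solve(PULP_CBC_CMD(msg=False))
print(LpStatus[prob.status], value(prob.objective))                  # Optimal 3.0
print([int(value(v)) for v in s[0]], [int(value(v)) for v in s[1]])  # [1, 0] [1, 0]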

View File

@ -1,41 +0,0 @@
import copy
from .node import Node
from .sharding_strategy import StrategiesVector
class Activation(Node):
def _collect_strategies(self, device_mesh):
dim_partition_list = []
dim_size = len(self.op_data['input0'].shape)
dim_partition_list.append({})
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
strategies_vector = StrategiesVector(self)
for dim_partition_dict in dim_partition_list:
in0_partition_dict = dim_partition_dict
out_partition_dict = copy.deepcopy(dim_partition_dict)
dim_partition_dict_mapping = {
"input0": in0_partition_dict,
"output0": out_partition_dict,
}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
continue
name = '{} = <activation op> {}'.format(
sharding_spec_mapping['output0'].sharding_sequence,
sharding_spec_mapping['input0'].sharding_sequence)
sharding_strategy = self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
return strategies_vector
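
Activation, like most single-input nodes in this module, enumerates every way of mapping tensor dimensions onto the two logical mesh axes and keeps only the candidates that survive the sharding-spec sanity check. The enumeration helpers live on the Node base class and are not shown in this diff; the sketch below only illustrates, under that assumption, the {tensor_dim: [mesh_axes]} format the candidates take for a rank-3 tensor on a 2-axis mesh.

from itertools import combinations

def enumerate_1d_sharding(mesh_axes, dim_size):
    # Shard a single tensor dim over the given mesh axes, e.g. {1: [0]} or {1: [0, 1]}.
    return [{dim: list(mesh_axes)} for dim in range(dim_size)]

def enumerate_2d_sharding(axes_a, axes_b, dim_size):
    # Shard two different tensor dims over two different mesh axes.
    out = []
    for d0, d1 in combinations(range(dim_size), 2):
        out.append({d0: list(axes_a), d1: list(axes_b)})
        out.append({d0: list(axes_b), d1: list(axes_a)})
    return out

dim_partition_list = [{}]                                 # fully replicated
dim_partition_list += enumerate_1d_sharding([0], 3)       # {0: [0]}, {1: [0]}, {2: [0]}
dim_partition_list += enumerate_1d_sharding([1], 3)
dim_partition_list += enumerate_1d_sharding([0, 1], 3)    # both mesh axes on one dim
dim_partition_list += enumerate_2d_sharding([0], [1], 3)  # e.g. {0: [0], 1: [1]}
print(len(dim_partition_list))  # 16 candidate partitions before the sanity check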

View File

@ -1,34 +0,0 @@
import copy
from .node import Node
from .sharding_strategy import StrategiesVector
class Assertion(Node):
def _collect_strategies(self, device_mesh):
        predecessor = self.predecessor_nodes[0]  # single input for the assertion node
strategies_vector = StrategiesVector(self)
for idx, strategy in enumerate(predecessor.strategies_vector):
global_input_name = self.op_data[
'input0'].name # current node's local name input0 -> global name xxx
prenode_local_name = predecessor.global_to_local_op_name[
global_input_name] # global name xxx -> pre node local output name
dim_partition_dict = copy.deepcopy(
strategy.sharding_specs[prenode_local_name].dim_partition_dict)
in0_partition_dict = dim_partition_dict
dim_partition_dict_mapping = {
"input0": in0_partition_dict,
}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
return strategies_vector
name = '<assertion> {}'.format(
sharding_spec_mapping['input0'].sharding_sequence)
sharding_strategy = self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
return strategies_vector

View File

@ -1,45 +0,0 @@
import copy
from .node import Node
from .sharding_strategy import StrategiesVector
class Cast(Node):
def _collect_strategies(self, device_mesh):
dim_partition_list = []
dim_size = len(self.op_data['input0'].shape)
dim_partition_list.append({})
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
strategies_vector = StrategiesVector(self)
for dim_partition_dict in dim_partition_list:
in0_partition_dict = dim_partition_dict
out_partition_dict = copy.deepcopy(dim_partition_dict)
dim_partition_dict_mapping = {
"input0": in0_partition_dict,
"output0": out_partition_dict,
}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
continue
name = '{} = <cast op> {}'.format(
sharding_spec_mapping['output0'].sharding_sequence,
sharding_spec_mapping['input0'].sharding_sequence)
sharding_strategy = self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
return strategies_vector
def _update_memory_cost(self, strategies):
pass

View File

@ -1,58 +0,0 @@
__all__ = [
'CommSpec',
]
class CommSpec:
def __init__(self,
comm_pattern,
sharding_spec,
gather_dim=None,
shard_dim=None,
logical_process_axis=None,
mix_gather=False,
forward_only=True):
self.comm_pattern = comm_pattern
self.sharding_spec = sharding_spec
self.gather_dim = gather_dim
self.shard_dim = shard_dim
self.logical_process_axis = logical_process_axis
self.device_mesh = self.sharding_spec.device_mesh
self.mix_gather = mix_gather
self.forward_only = forward_only
if self.gather_dim:
assert len(self.gather_dim) == len(
self.logical_process_axis
), f'unmatched gather dim {self.gather_dim} and logical process axis {self.logical_process_axis}'
if self.shard_dim:
assert len(self.shard_dim) == len(
self.logical_process_axis
), f'unmatched shard dim {self.shard_dim} and logical process axis {self.logical_process_axis}'
if self.gather_dim and self.shard_dim:
assert len(self.shard_dim) == len(
self.gather_dim
), f'unmatched gather dim {self.gather_dim} and shard dim {self.shard_dim}'
def get_comm_cost(self):
'''
        For all_gather, all-to-all, and all_reduce operations, the communication cost is computed
        with the alpha-beta model provided by DeviceMesh.
        Shard is an on-chip operation, so its communication cost is zero.
'''
comm_size = self.sharding_spec.get_sharded_size_per_device()
dtype = self.sharding_spec.dtype
# reduce list_of_list to list
comm_dims = sum(self.logical_process_axis, [])
comm_cost = self.device_mesh.estimate_comm_cost(self.comm_pattern,
comm_dims, comm_size,
dtype)
return comm_cost
def get_mem_cost(self):
return self.device_mesh.shape_consistency_manager.mem_cost([self])
def get_max_mem_cost(self):
return self.device_mesh.shape_consistency_manager.mem_cost(
[self], mem_pattern='max')
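
get_comm_cost above delegates to DeviceMesh.estimate_comm_cost, which is defined elsewhere in the removed module. As a rough illustration only (the helper name and constants below are assumed, not taken from the source), an alpha-beta estimate for a ring all_reduce typically looks like:

def alpha_beta_all_reduce_cost(size_bytes, num_ranks, alpha_s, beta_s_per_byte):
    """Rough ring all_reduce time: about 2*(n-1)/n of the buffer crosses each link."""
    if num_ranks <= 1:
        return 0.0
    transported = 2.0 * (num_ranks - 1) / num_ranks * size_bytes
    return alpha_s + transported * beta_s_per_byte

# Example with assumed numbers: a 64 MiB buffer across 8 GPUs,
# alpha = 10 us launch latency, beta from a 200 GB/s effective link.
print(alpha_beta_all_reduce_cost(64 * 2**20, 8, 10e-6, 1.0 / 200e9))  # ~0.6 ms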

View File

@ -1,56 +0,0 @@
import copy
from .node import Node
from .sharding_strategy import StrategiesVector
class Concatenation(Node):
def __init__(self, layer):
super().__init__(layer)
layer.to_subclass()
batch_dims = [i for i in range(len(self.get_output(0).shape))]
self.axis = layer.as_trt().axis
batch_dims.remove(self.axis)
self._generate_bcast_dims(batch_dims, self.get_output(0).shape)
layer.to_base_class()
def _collect_strategies(self, device_mesh):
dim_partition_list = []
dim_size = len(self.op_data['output0'].shape)
dim_partition_list.append({})
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
dim_partition_dict_mapping = {}
strategies_vector = StrategiesVector(self)
for dim_partition_dict in dim_partition_list:
if self.axis in dim_partition_dict:
dim_partition_dict.pop(self.axis)
for idx in range(self.num_inputs):
in_partition_dict = copy.deepcopy(dim_partition_dict)
dim_partition_dict_mapping[f'input{idx}'] = in_partition_dict
out_partition_dict = dim_partition_dict
dim_partition_dict_mapping['output0'] = out_partition_dict
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
continue
            name = '{} = <concat along dim {}> {}'.format(
sharding_spec_mapping['output0'].sharding_sequence, self.axis, [
sharding_spec_mapping[f'input{idx}'].sharding_sequence
for idx in range(self.num_inputs)
])
sharding_strategy = self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
return strategies_vector

View File

@ -1,45 +0,0 @@
from .node import Node
from .sharding_strategy import StrategiesVector
class Constant(Node):
def _update_memory_cost(self, strategies):
super()._update_memory_cost(strategies)
for strategy in strategies:
strategy.inout_memory_footprint = 0.0
strategy.peak_memory_footprint = 0.0
strategy.const_memory_footprint = strategy.sharding_specs[
'output0'].get_max_sharded_size_per_device()
def _collect_strategies(self, device_mesh):
dim_partition_list = []
dim_size = len(self.op_data['output0'].shape)
dim_partition_list.append({})
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([1], dim_size))
strategies_vector = StrategiesVector(self)
for dim_partition_dict in dim_partition_list:
dim_partition_dict_mapping = {'output0': dim_partition_dict}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
continue
sharding_seq = sharding_spec_mapping['output0'].sharding_sequence
sharding_strategy = self._get_sharding_strategy(
name=f'constant-op {sharding_seq}',
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
return strategies_vector
def _profile_sharding_cost(self, strategy, device_mesh):
return 0.0

View File

@ -1,49 +0,0 @@
from .node import Node
from .sharding_strategy import StrategiesVector
class ElementWise(Node):
def __init__(self, layer):
super().__init__(layer)
batch_dims = [i for i in range(len(self.get_output(0).shape))]
self._generate_bcast_dims(batch_dims, self.get_output(0).shape)
def _collect_strategies(self, device_mesh):
dim_partition_list = []
dim_size = len(self.op_data['output0'].shape)
dim_partition_list.append({})
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
strategies_vector = StrategiesVector(self)
for dim_partition_dict in dim_partition_list:
in0_partition_dict = self._recover_bcast_partition_dict(
dim_partition_dict, self.op_data['input0'])
in1_partition_dict = self._recover_bcast_partition_dict(
dim_partition_dict, self.op_data['input1'])
out_partition_dict = dim_partition_dict
dim_partition_dict_mapping = {
"input0": in0_partition_dict,
"input1": in1_partition_dict,
"output0": out_partition_dict,
}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
continue
name = '{} = {} <elementwise> {}'.format(
sharding_spec_mapping['output0'].sharding_sequence,
sharding_spec_mapping['input0'].sharding_sequence,
sharding_spec_mapping['input1'].sharding_sequence)
sharding_strategy = self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
return strategies_vector
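
ElementWise depends on Node._recover_bcast_partition_dict (not shown in this diff) so that an input which is broadcast along a dimension is never asked to shard that dimension. Under that assumption, the recovery behaves roughly like the sketch below; the helper name and signature here are illustrative only.

def recover_bcast_partition_dict(dim_partition_dict, broadcast_dims):
    # Keep only shardings on dims this input actually owns: broadcast dims have
    # size 1 on this input and cannot be split across the mesh.
    return {dim: mesh_dims
            for dim, mesh_dims in dim_partition_dict.items()
            if dim not in broadcast_dims}

# The output is sharded as {0: [0], 1: [1]}; input1 is broadcast along dim 1,
# so it keeps only the dim-0 sharding and stays replicated along dim 1.
print(recover_bcast_partition_dict({0: [0], 1: [1]}, broadcast_dims=[1]))  # {0: [0]}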

View File

@ -1,59 +0,0 @@
import tensorrt as trt
from .node import Node
from .sharding_strategy import StrategiesVector
class Fill(Node):
def __init__(self, layer):
super().__init__(layer)
layer.to_subclass()
self.operation = layer.as_trt().operation
layer.to_base_class()
def _collect_strategies(self, device_mesh):
dim_partition_list = []
dim_size = len(self.op_data['output0'].shape)
dim_partition_list.append({})
if self.num_inputs == 0 and self.operation != trt.FillOperation.LINSPACE:
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
strategies_vector = StrategiesVector(self)
for dim_partition_dict in dim_partition_list:
dim_partition_dict_mapping = {'output0': dim_partition_dict}
for i in range(self.num_inputs):
dim_partition_dict_mapping[f'input{i}'] = {}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
continue
sharding_seq = sharding_spec_mapping['output0'].sharding_sequence
sharding_strategy = self._get_sharding_strategy(
name=f'fill-op {sharding_seq}',
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
return strategies_vector
def _profile_sharding_cost(self, strategy, device_mesh):
updated_layer_attrs = {}
updated_input_values = {}
shape = strategy.sharding_specs['output0'].get_sharded_shape_per_device(
)
if self.layer.num_inputs >= 1:
updated_input_values[0] = shape
else:
updated_layer_attrs['shape'] = shape
elapsed_time = self.node_runtime_profiler.runtime_profile(
self.layer, updated_layer_attrs, updated_input_values, strategy,
device_mesh)
return elapsed_time

View File

@ -1,196 +0,0 @@
import copy
import tensorrt as trt
from .comm_spec import CommSpec
from .node import Node
from .sharding_spec import DimSpec
from .sharding_strategy import StrategiesVector
class Gather(Node):
def __init__(self, layer):
super().__init__(layer)
layer.to_subclass()
self.mode = layer.as_trt().mode
self.axis = layer.as_trt().axis
self.num_elementwise_dims = layer.as_trt().num_elementwise_dims
self.input_id = 0
self.indice_id = 1
self.support_vocab_tp = False
layer.to_base_class()
def _update_memory_cost(self, strategies):
for strategy in strategies:
            # for a gather node, its input0 read volume equals output0's write volume
inout_memory_footprint = (
strategy.sharding_specs['output0'].get_sharded_size_per_device(
) * 2 +
strategy.sharding_specs['input1'].get_sharded_size_per_device())
strategy.inout_memory_footprint = inout_memory_footprint
strategy.peak_memory_footprint = (
strategy.sharding_specs['output0'].
get_max_sharded_size_per_device() + strategy.
sharding_specs['input0'].get_max_sharded_size_per_device() +
strategy.sharding_specs['input1'].
get_max_sharded_size_per_device())
def _collect_strategies(self, device_mesh):
if self.mode == trt.GatherMode.DEFAULT:
return self._default_gather_strategies(device_mesh)
elif self.mode == trt.GatherMode.ELEMENT:
return self._element_gather_strategies(device_mesh)
elif self.mode == trt.GatherMode.ND:
            assert 0, 'unsupported gatherND'
else:
            assert 0, f'unsupported gather mode {self.mode}'
def _element_gather_strategies(self, device_mesh):
dim_partition_list = []
dim_size = len(self.op_data['output0'].shape)
dim_partition_list.append({})
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
strategies_vector = StrategiesVector(self)
for dim_partition_dict in dim_partition_list:
if self.axis in dim_partition_dict:
dim_partition_dict.pop(self.axis)
dim_partition_dict_mapping = {
'input0': dim_partition_dict,
'input1': copy.deepcopy(dim_partition_dict),
'output0': copy.deepcopy(dim_partition_dict),
}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
continue
name = '{} = {} <element gather op axis {}> {}'.format(
sharding_spec_mapping['output0'].sharding_sequence,
sharding_spec_mapping['input0'].sharding_sequence, self.axis,
sharding_spec_mapping['input1'].sharding_sequence)
sharding_strategy = self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
return strategies_vector
    # for the plugin, the indices are input0 and the weights are input1, which differs from the gather node
def _default_gather_strategies(self, device_mesh):
def add_sharding_strategy(dim_partition_dict_mapping,
vocab_tp_dim=None):
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if len(sharding_spec_mapping) > 0:
name = '{} = {} <default gather op axis {}, num_elementwise_dims {}> {}'.format(
sharding_spec_mapping['output0'].sharding_sequence,
sharding_spec_mapping['input0'].sharding_sequence,
self.axis, self.num_elementwise_dims,
sharding_spec_mapping['input1'].sharding_sequence)
communication_action_mapping = {}
if vocab_tp_dim is not None:
name += f'_allreduce{DimSpec(vocab_tp_dim)}'
output0_comm_action = CommSpec(
comm_pattern='all_reduce',
sharding_spec=sharding_spec_mapping['output0'],
logical_process_axis=[vocab_tp_dim],
)
communication_action_mapping[
'output0'] = output0_comm_action
sharding_strategy = self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategies_vector.append(sharding_strategy)
input_id, indice_id = self.input_id, self.indice_id
strategies_vector = StrategiesVector(self)
input_size = len(self.op_data[f'input{input_id}'].shape)
indice_size = len(self.op_data[f'input{indice_id}'].shape)
output_dim = input_size + indice_size - 1 - self.num_elementwise_dims
for strategy in self.predecessor_nodes[input_id].strategies_vector:
# current node's local name input0 -> global name xxx
global_input_name = self.op_data[f'input{input_id}'].name
# global name xxx -> pre node local output name
prenode_local_name = self.predecessor_nodes[
input_id].global_to_local_op_name[global_input_name]
input_dim_partition_dict = copy.deepcopy(
strategy.sharding_specs[prenode_local_name].dim_partition_dict)
vocab_tp_dim = input_dim_partition_dict.pop(self.axis, None)
input_mesh_dims = []
for dim, mesh_dims in input_dim_partition_dict.items():
input_mesh_dims += mesh_dims
input_mesh_dims = set(input_mesh_dims)
for idx_strategy in self.predecessor_nodes[
indice_id].strategies_vector:
# current node's local name input0 -> global name xxx
global_indice_name = self.op_data[f'input{indice_id}'].name
# global name xxx -> pre node local output name
prenode_local_name = self.predecessor_nodes[
indice_id].global_to_local_op_name[global_indice_name]
indice_dim_partition_dict = copy.deepcopy(
idx_strategy.sharding_specs[prenode_local_name].
dim_partition_dict)
for dim, indice_mesh_dims in idx_strategy.sharding_specs[
prenode_local_name].dim_partition_dict.items():
for indice_mesh_dim in indice_mesh_dims:
if indice_mesh_dim in input_mesh_dims:
indice_dim_partition_dict.pop(dim)
break
out_partition_dict = {}
for dim in range(output_dim):
if dim < self.axis:
if dim in input_dim_partition_dict:
out_partition_dict[dim] = \
input_dim_partition_dict[dim]
elif dim >= self.axis and dim < self.axis + indice_size - self.num_elementwise_dims:
indice_dim = dim - self.axis + self.num_elementwise_dims
if indice_dim in indice_dim_partition_dict:
out_partition_dict[dim] = \
indice_dim_partition_dict[indice_dim]
else:
input_dim = dim - (indice_size -
self.num_elementwise_dims) + 1
if input_dim in input_dim_partition_dict:
out_partition_dict[dim] = \
input_dim_partition_dict[input_dim]
dim_partition_dict_mapping = {
f"input{input_id}": input_dim_partition_dict,
f"input{indice_id}": indice_dim_partition_dict,
"output0": out_partition_dict,
}
add_sharding_strategy(dim_partition_dict_mapping)
if self.support_vocab_tp and vocab_tp_dim is not None:
vocab_tp_dim_partition_dict = {
**input_dim_partition_dict,
self.axis: vocab_tp_dim,
}
dim_partition_dict_mapping = {
f"input{input_id}": vocab_tp_dim_partition_dict,
f"input{indice_id}": indice_dim_partition_dict,
"output0": out_partition_dict,
}
add_sharding_strategy(dim_partition_dict_mapping,
vocab_tp_dim)
return strategies_vector
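
The output-dimension bookkeeping in _default_gather_strategies follows the gather layout: output rank = input rank + indices rank - 1 - num_elementwise_dims, where output dims before `axis` come from the input, the next (indices rank - num_elementwise_dims) dims come from the indices, and the remaining dims come from the trailing input dims. A small worked example of that index arithmetic, detached from the class above for illustration:

def gather_output_dim_sources(input_rank, indices_rank, axis, num_elementwise_dims):
    """Return, for each output dim, which operand dim its sharding is inherited from."""
    output_rank = input_rank + indices_rank - 1 - num_elementwise_dims
    sources = []
    for dim in range(output_rank):
        if dim < axis:
            sources.append(("input", dim))
        elif dim < axis + indices_rank - num_elementwise_dims:
            sources.append(("indices", dim - axis + num_elementwise_dims))
        else:
            sources.append(("input", dim - (indices_rank - num_elementwise_dims) + 1))
    return sources

# input rank 3, indices rank 2, axis 1, no elementwise dims -> output rank 4
print(gather_output_dim_sources(3, 2, 1, 0))
# [('input', 0), ('indices', 0), ('indices', 1), ('input', 2)]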

View File

@ -1,56 +0,0 @@
import copy
from .node import Node
from .sharding_strategy import StrategiesVector
class Identity(Node):
def _update_memory_cost(self, strategies):
if not self.is_fake:
super()._update_memory_cost(strategies)
else:
# fake nodes for building block/PP connection
pass
def _collect_strategies(self, device_mesh):
dim_partition_list = []
dim_size = len(self.op_data['input0'].shape)
dim_partition_list.append({})
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
strategies_vector = StrategiesVector(self)
        # dim_partition_dict can reuse the previous node's partitions if solver time becomes a problem
for dim_partition_dict in dim_partition_list:
in0_partition_dict = dim_partition_dict
out_partition_dict = copy.deepcopy(dim_partition_dict)
dim_partition_dict_mapping = {
"input0": in0_partition_dict,
"output0": out_partition_dict,
}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
continue
name = '{} = <identity op> {}'.format(
sharding_spec_mapping['output0'].sharding_sequence,
sharding_spec_mapping['input0'].sharding_sequence)
sharding_strategy = self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
return strategies_vector
def _profile_sharding_cost(self, strategy, device_mesh):
        # if same_spec_id is not -1, this identity node acts as a same-spec placeholder, so its cost is zero
if self.same_spec_id == -1:
return super()._profile_sharding_cost(strategy, device_mesh)
else:
return 0.0

View File

@ -1,79 +0,0 @@
from .node import Node
from .sharding_strategy import StrategiesVector
class InputNode(Node):
def _update_memory_cost(self, strategies):
for strategy in strategies:
if not self.no_memory_footprint:
strategy.const_memory_footprint = strategy.sharding_specs[
'output0'].get_max_sharded_size_per_device()
def __init__(self, tensor):
self._layer = None
self.is_shape_io = False
self._inputs = []
self._outputs = []
self.predecessor_nodes = []
self.predecessor_nodes_out_index = {}
self.successor_nodes = []
self.op_data = {}
self.global_to_local_op_name = {}
self.is_replicated = tensor.attrs.get("is_replicated", False)
self.same_spec_id = tensor.attrs.get("same_spec_id", -1)
self.no_memory_footprint = tensor.attrs.get("no_memory_footprint",
False)
self.building_block_id = -1
self.cost_level = -1
self.stage_type = None
self.in_start_block = None
self.in_end_block = None
self.in_slowest_block = None
output = tensor.copy()
self._outputs.append(output)
self.op_data['output0'] = output
self.global_to_local_op_name[output.name] = 'output0'
self.sharding_weight = 1.0
self.resharding_weight = 1.0
self.pipeline_weight = 0
self.node_name = tensor.name
self.node_type = 'input_node'
self.num_inputs = 0
self.num_outputs = 1
self.dtype = tensor.dtype
self.strategies_vector = []
self.node_runtime_profiler = None
def _collect_strategies(self, device_mesh):
dim_partition_list = []
dim_size = len(self.op_data['output0'].shape)
dim_partition_list.append({})
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
strategies_vector = StrategiesVector(self)
for dim_partition_dict in dim_partition_list:
dim_partition_dict_mapping = {'output0': dim_partition_dict}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
continue
sharding_seq = sharding_spec_mapping['output0'].sharding_sequence
sharding_strategy = self._get_sharding_strategy(
name=f'input-op {sharding_seq}',
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
return strategies_vector
def _profile_sharding_cost(self, strategy, device_mesh):
return 0.0

View File

@ -1,798 +0,0 @@
import copy
import operator
from functools import reduce
import tensorrt as trt
from ..device_mesh import LogicalDeviceMesh
from ..utils import get_builder_flags
from .comm_spec import CommSpec
from .node import Node
from .sharding_spec import DimSpec
from .sharding_strategy import StrategiesVector
class MatrixMultiply(Node):
def __init__(self, layer):
super().__init__(layer)
layer.to_subclass()
batch_dims = [i for i in range(len(self.get_output(0).shape))][:-2]
self._generate_bcast_dims(batch_dims, self.get_output(0).shape)
self.op0_transpose = layer.as_trt().op0 == trt.MatrixOperation.TRANSPOSE
self.op1_transpose = layer.as_trt().op1 == trt.MatrixOperation.TRANSPOSE
self.num_out_dims = len(self.get_output(0).shape)
dtypes_str = [
self.get_input(0).dtype_str,
self.get_input(1).dtype_str,
self.get_output(0).dtype_str
]
dtypes_size = [
self.get_input(0).dtype_size,
self.get_input(1).dtype_size,
self.get_output(0).dtype_size
]
min_idx = dtypes_size.index(min(dtypes_size))
self.dtype = dtypes_str[min_idx]
layer.to_base_class()
def _split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1, device_mesh):
in0_split_dim = -1 if self.op0_transpose else -2
in1_split_dim = -2 if self.op1_transpose else -1
name = (f'{DimSpec(mesh_dim_0)}{DimSpec(mesh_dim_1)} = '
f'{DimSpec(mesh_dim_0)}R x R{DimSpec(mesh_dim_1)}')
dim_partition_dict_mapping = {
"input0": {
in0_split_dim: mesh_dim_0
},
"input1": {
in1_split_dim: mesh_dim_1
},
"output0": {
-2: mesh_dim_0,
-1: mesh_dim_1
},
}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if len(sharding_spec_mapping) == 0:
return None
strategy = self._get_sharding_strategy(name = name, \
sharding_spec_mapping = sharding_spec_mapping, \
communication_action_mapping = {})
return strategy
def _split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1,
device_mesh):
# handle the case SR = SS x SR
name = (
f'{DimSpec(mesh_dim_0)}R = '
f'{DimSpec(mesh_dim_0)}{DimSpec(mesh_dim_1)} x {DimSpec(mesh_dim_1)}R'
f'_allreduce{DimSpec(mesh_dim_1)}')
in0_split_dim = [-1, -2] if self.op0_transpose else [-2, -1]
in1_split_dim = -1 if self.op1_transpose else -2
# get sharding spec mapping
dim_partition_dict_mapping = {
"input0": {
in0_split_dim[0]: mesh_dim_0,
in0_split_dim[1]: mesh_dim_1
},
"input1": {
in1_split_dim: mesh_dim_1
},
"output0": {
-2: mesh_dim_0
},
}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if len(sharding_spec_mapping) == 0:
return None
# get communication action mapping
communication_action_mapping = {}
output0_comm_action = CommSpec(
comm_pattern='all_reduce',
sharding_spec=sharding_spec_mapping['output0'],
logical_process_axis=[mesh_dim_1],
)
communication_action_mapping['output0'] = output0_comm_action
return self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def _split_both_contract_rs(self, name, rs_dim, rs_mesh_dim, src_spec,
dim_partition_dict_mapping, device_mesh):
output0_comm_action = CommSpec(
comm_pattern='reduce_scatter',
sharding_spec=src_spec,
shard_dim=[rs_dim],
logical_process_axis=[rs_mesh_dim],
)
rs_out_partition_dict_mapping = copy.deepcopy(
dim_partition_dict_mapping)
rs_out_partition_dict_mapping["output0"][rs_dim] = rs_mesh_dim
rs_out_sharding_spec_mapping = self._to_sharding_spec_mapping(
rs_out_partition_dict_mapping, device_mesh)
if len(rs_out_sharding_spec_mapping) == 0:
return None
communication_action_mapping = {}
communication_action_mapping['output0'] = output0_comm_action
return self._get_sharding_strategy(
name=name,
sharding_spec_mapping=rs_out_sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def _split_lhs_space_both_contract_rs(self, mesh_dim_0, mesh_dim_1,
device_mesh):
# handle the case SS = SS x SR -> reduce_scatter
in0_split_dim = [-1, -2] if self.op0_transpose else [-2, -1]
in1_split_dim = -1 if self.op1_transpose else -2
# get sharding spec mapping
dim_partition_dict_mapping = {
"input0": {
in0_split_dim[0]: mesh_dim_0,
in0_split_dim[1]: mesh_dim_1
},
"input1": {
in1_split_dim: mesh_dim_1
},
"output0": {
-2: mesh_dim_0,
},
}
mm_out_sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if len(mm_out_sharding_spec_mapping) == 0:
return []
strategies = []
for rs_dim in range(self.num_out_dims):
if rs_dim != self.num_out_dims - 2:
name_in0, name_in1, name_out0 = ['R'] * self.num_out_dims, [
'R'
] * self.num_out_dims, ['R'] * self.num_out_dims
name_in0[-2], name_in0[-1] = str(DimSpec(mesh_dim_0)), str(
DimSpec(mesh_dim_1))
name_in1[-2] = str(DimSpec(mesh_dim_1))
name_out0[-2], name_out0[rs_dim] = str(
DimSpec(mesh_dim_0)), str(DimSpec(mesh_dim_1))
name_in0, name_in1, name_out0 = ', '.join(name_in0), ', '.join(
name_in1), ', '.join(name_out0)
name = (f'[{name_out0}] = [{name_in0}] x [{name_in1}]'
f'_reducescatter{(rs_dim, DimSpec(mesh_dim_1))}')
ret = self._split_both_contract_rs(
name, rs_dim, mesh_dim_1,
mm_out_sharding_spec_mapping['output0'],
dim_partition_dict_mapping, device_mesh)
if ret:
strategies.append(ret)
return strategies
def _split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1,
device_mesh):
name = (
f'R{DimSpec(mesh_dim_1)} = '
f'R{DimSpec(mesh_dim_0)} x {DimSpec(mesh_dim_0)}{DimSpec(mesh_dim_1)}'
f'_allreduce{DimSpec(mesh_dim_0)}')
in0_split_dim = -2 if self.op0_transpose else -1
in1_split_dim = [-1, -2] if self.op1_transpose else [-2, -1]
# get sharding specs
dim_partition_dict_mapping = {
"input0": {
in0_split_dim: mesh_dim_0
},
"input1": {
in1_split_dim[0]: mesh_dim_0,
in1_split_dim[1]: mesh_dim_1
},
"output0": {
-1: mesh_dim_1
},
}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if len(sharding_spec_mapping) == 0:
return None
# get communication actions
communication_action_mapping = {}
output0_comm_action = CommSpec(
comm_pattern='all_reduce',
sharding_spec=sharding_spec_mapping['output0'],
logical_process_axis=[mesh_dim_0],
)
communication_action_mapping['output0'] = output0_comm_action
return self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def _split_rhs_space_both_contract_rs(self, mesh_dim_0, mesh_dim_1,
device_mesh):
in0_split_dim = -2 if self.op0_transpose else -1
in1_split_dim = [-1, -2] if self.op1_transpose else [-2, -1]
# get sharding specs
dim_partition_dict_mapping = {
"input0": {
in0_split_dim: mesh_dim_0
},
"input1": {
in1_split_dim[0]: mesh_dim_0,
in1_split_dim[1]: mesh_dim_1
},
"output0": {
-1: mesh_dim_1
},
}
mm_out_sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if len(mm_out_sharding_spec_mapping) == 0:
return []
strategies = []
for rs_dim in range(self.num_out_dims):
if rs_dim != self.num_out_dims - 1:
name_in0, name_in1, name_out0 = ['R'] * self.num_out_dims, [
'R'
] * self.num_out_dims, ['R'] * self.num_out_dims
name_in1[-2], name_in1[-1] = str(DimSpec(mesh_dim_0)), str(
DimSpec(mesh_dim_1))
name_in0[-1] = str(DimSpec(mesh_dim_0))
name_out0[-1], name_out0[rs_dim] = str(
DimSpec(mesh_dim_1)), str(DimSpec(mesh_dim_0))
name_in0, name_in1, name_out0 = ', '.join(name_in0), ', '.join(
name_in1), ', '.join(name_out0)
name = (f'[{name_out0}] = [{name_in0}] x [{name_in1}]'
f'_reducescatter{(rs_dim, DimSpec(mesh_dim_0))}')
ret = self._split_both_contract_rs(
name, rs_dim, mesh_dim_0,
mm_out_sharding_spec_mapping['output0'],
dim_partition_dict_mapping, device_mesh)
if ret:
strategies.append(ret)
return strategies
def _recompute_split_both_contract(self, mesh_dim, device_mesh):
name = (f'RR = R{DimSpec(mesh_dim)} x {DimSpec(mesh_dim)}R'
f'_allreduce{DimSpec(mesh_dim)}')
in0_split_dim = -2 if self.op0_transpose else -1
in1_split_dim = -1 if self.op1_transpose else -2
dim_partition_dict_mapping = {
"input0": {
in0_split_dim: mesh_dim
},
"input1": {
in1_split_dim: mesh_dim
},
"output0": {},
}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if len(sharding_spec_mapping) == 0:
return None
# get communication action
communication_action_mapping = {}
output0_comm_action = CommSpec(
comm_pattern='all_reduce',
sharding_spec=sharding_spec_mapping['output0'],
logical_process_axis=[mesh_dim],
)
communication_action_mapping['output0'] = output0_comm_action
return self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def _recompute_split_both_contract_rs(self, mesh_dim, device_mesh):
name = (f'{DimSpec(mesh_dim)}R = '
f'R{DimSpec(mesh_dim)} x {DimSpec(mesh_dim)}R'
f'_reducescatter0_{DimSpec(mesh_dim)}')
in0_split_dim = -2 if self.op0_transpose else -1
in1_split_dim = -1 if self.op1_transpose else -2
dim_partition_dict_mapping = {
"input0": {
in0_split_dim: mesh_dim
},
"input1": {
in1_split_dim: mesh_dim
},
"output0": {},
}
mm_out_sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if len(mm_out_sharding_spec_mapping) == 0:
return []
strategies = []
for rs_dim in range(self.num_out_dims):
name_in0, name_in1, name_out0 = ['R'] * self.num_out_dims, [
'R'
] * self.num_out_dims, ['R'] * self.num_out_dims
name_in0[-1], name_in1[-2], name_out0[rs_dim] = str(
DimSpec(mesh_dim)), str(DimSpec(mesh_dim)), str(
DimSpec(mesh_dim))
name_in0, name_in1, name_out0 = ', '.join(name_in0), ', '.join(
name_in1), ', '.join(name_out0)
name = f'[{name_out0}] = [{name_in0}] x [{name_in1}]_reducescatter{(rs_dim, DimSpec(mesh_dim))}'
ret = self._split_both_contract_rs(
name, rs_dim, mesh_dim, mm_out_sharding_spec_mapping['output0'],
dim_partition_dict_mapping, device_mesh)
if ret:
strategies.append(ret)
return strategies
def _split_rhs_space_only(self, mesh_dim, device_mesh):
name = f'R{DimSpec(mesh_dim)} = RR x R{DimSpec(mesh_dim)}'
in1_split_dim = -2 if self.op1_transpose else -1
# get sharding spec
dim_partition_dict_mapping = {
"input0": {},
"input1": {
in1_split_dim: mesh_dim
},
"output0": {
-1: mesh_dim
},
}
# We don't have to do anything special for bias here, because
        # the bias already has the same sharding spec as output0.
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if len(sharding_spec_mapping) == 0:
return None
return self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
def _split_lhs_space_only(self, mesh_dim, device_mesh):
name = f'{DimSpec(mesh_dim)}R = {DimSpec(mesh_dim)}R x RR'
in0_split_dim = -1 if self.op0_transpose else -2
# get sharding spec
dim_partition_dict_mapping = {
"input0": {
in0_split_dim: mesh_dim
},
"input1": {},
"output0": {
-2: mesh_dim
},
}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if len(sharding_spec_mapping) == 0:
return None
return self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
def _non_split(self, device_mesh):
name = 'RR = RR x RR'
# get sharding spec
dim_partition_dict_mapping = {
"input0": {},
"input1": {},
"output0": {},
}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if len(sharding_spec_mapping) == 0:
return None
return self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
def _split_one_batch_dim(self, batch_dim, mesh_dim, device_mesh):
name = (
f'{DimSpec(mesh_dim)}b{batch_dim}RR = '
f'{DimSpec(mesh_dim)}b{batch_dim}RR x {DimSpec(mesh_dim)}b{batch_dim}RR'
)
in0_data = self.op_data['input0']
in1_data = self.op_data['input1']
batch_partition_dict = {batch_dim: mesh_dim}
in0_parition_dict = self._recover_bcast_partition_dict(
batch_partition_dict, in0_data)
in1_parition_dict = self._recover_bcast_partition_dict(
batch_partition_dict, in1_data)
out_partition_dict = {batch_dim: mesh_dim}
        # TODO: double-check whether the MatrixMultiply output has broadcast dims
dim_partition_dict_mapping = {
"input0": in0_parition_dict,
"input1": in1_parition_dict,
"output0": out_partition_dict,
}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if len(sharding_spec_mapping) == 0:
return None
return self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
def _split_two_batch_dims(self, batch_dim0, batch_dim1, mesh_dim0,
mesh_dim1, device_mesh):
name = (
f'{DimSpec(mesh_dim0)}b{batch_dim0}{DimSpec(mesh_dim1)}b{batch_dim1}RR = '
f'{DimSpec(mesh_dim0)}b{batch_dim0}RR x {DimSpec(mesh_dim1)}b{batch_dim1}RR'
)
in0_data = self.op_data['input0']
in1_data = self.op_data['input1']
in0_parition_dict = {}
if batch_dim0 not in in0_data.attrs["broadcast_dims"]:
in0_parition_dict[batch_dim0] = mesh_dim0
if batch_dim1 not in in0_data.attrs["broadcast_dims"]:
in0_parition_dict[batch_dim1] = mesh_dim1
in1_parition_dict = {}
if batch_dim0 not in in1_data.attrs["broadcast_dims"]:
in1_parition_dict[batch_dim0] = mesh_dim0
if batch_dim1 not in in1_data.attrs["broadcast_dims"]:
in1_parition_dict[batch_dim1] = mesh_dim1
batch_partition_dict = {batch_dim0: mesh_dim0, batch_dim1: mesh_dim1}
in0_parition_dict = self._recover_bcast_partition_dict(
batch_partition_dict, in0_data)
in1_parition_dict = self._recover_bcast_partition_dict(
batch_partition_dict, in1_data)
out_partition_dict = {batch_dim0: mesh_dim0, batch_dim1: mesh_dim1}
dim_partition_dict_mapping = {
"input0": in0_parition_dict,
"input1": in1_parition_dict,
"output0": out_partition_dict,
}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if len(sharding_spec_mapping) == 0:
return None
return self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
def _split_batch_dim_lhs_space(self, batch_dim, mesh_dim0, mesh_dim1,
device_mesh):
name = (
f'{DimSpec(mesh_dim0)}b{batch_dim}{DimSpec(mesh_dim1)}R = '
f'{DimSpec(mesh_dim0)}b{batch_dim}{DimSpec(mesh_dim1)}R x {DimSpec(mesh_dim0)}b{batch_dim}RR'
)
in0_data = self.op_data['input0']
in1_data = self.op_data['input1']
in0_parition_dict = {batch_dim: mesh_dim0}
in1_parition_dict = {batch_dim: mesh_dim0}
in0_lhs_split_dim = -1 if self.op0_transpose else -2
in0_parition_dict[in0_lhs_split_dim] = mesh_dim1
in0_parition_dict = self._recover_bcast_partition_dict(
in0_parition_dict, in0_data)
in1_parition_dict = self._recover_bcast_partition_dict(
in1_parition_dict, in1_data)
out_partition_dict = {batch_dim: mesh_dim0, -2: mesh_dim1}
dim_partition_dict_mapping = {
"input0": in0_parition_dict,
"input1": in1_parition_dict,
"output0": out_partition_dict,
}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if len(sharding_spec_mapping) == 0:
return None
return self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
def _split_batch_dim_rhs_space(self, batch_dim, mesh_dim0, mesh_dim1,
device_mesh):
name = (
f'{DimSpec(mesh_dim0)}b{batch_dim}R{DimSpec(mesh_dim1)} = '
f'{DimSpec(mesh_dim0)}b{batch_dim}RR x {DimSpec(mesh_dim0)}b{batch_dim}R{DimSpec(mesh_dim1)}'
)
in0_data = self.op_data['input0']
in1_data = self.op_data['input1']
in0_parition_dict = {batch_dim: mesh_dim0}
in1_parition_dict = {batch_dim: mesh_dim0}
in1_rhs_split_dim = -2 if self.op1_transpose else -1
in1_parition_dict[in1_rhs_split_dim] = mesh_dim1
in0_parition_dict = self._recover_bcast_partition_dict(
in0_parition_dict, in0_data)
in1_parition_dict = self._recover_bcast_partition_dict(
in1_parition_dict, in1_data)
out_partition_dict = {batch_dim: mesh_dim0, -1: mesh_dim1}
dim_partition_dict_mapping = {
"input0": in0_parition_dict,
"input1": in1_parition_dict,
"output0": out_partition_dict,
}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if len(sharding_spec_mapping) == 0:
return None
return self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
def _split_batch_dim_both_contract(self, batch_dim, mesh_dim0, mesh_dim1,
device_mesh):
name = (
f'{DimSpec(mesh_dim0)}b{batch_dim}RR = '
f'{DimSpec(mesh_dim0)}b{batch_dim}R{DimSpec(mesh_dim1)} x '
f'{DimSpec(mesh_dim0)}b{batch_dim}{DimSpec(mesh_dim1)}R_AR{mesh_dim1}'
)
in0_data = self.op_data['input0']
in1_data = self.op_data['input1']
in0_parition_dict = {batch_dim: mesh_dim0}
in1_parition_dict = {batch_dim: mesh_dim0}
in0_contract_dim = -2 if self.op0_transpose else -1
in1_contract_dim = -1 if self.op1_transpose else -2
in0_parition_dict[in0_contract_dim] = mesh_dim1
in1_parition_dict[in1_contract_dim] = mesh_dim1
in0_parition_dict = self._recover_bcast_partition_dict(
in0_parition_dict, in0_data)
in1_parition_dict = self._recover_bcast_partition_dict(
in1_parition_dict, in1_data)
out_partition_dict = {batch_dim: mesh_dim0}
dim_partition_dict_mapping = {
"input0": in0_parition_dict,
"input1": in1_parition_dict,
"output0": out_partition_dict,
}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if len(sharding_spec_mapping) == 0:
return None
# get communication actions
communication_action_mapping = {}
output0_comm_action = CommSpec(
comm_pattern='all_reduce',
sharding_spec=sharding_spec_mapping['output0'],
logical_process_axis=[mesh_dim1],
)
communication_action_mapping['output0'] = output0_comm_action
return self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def _split_batch_dim_both_contract_rs(self, batch_dim, mesh_dim0, mesh_dim1,
device_mesh):
name = (
f'{DimSpec(mesh_dim0)}b{batch_dim}RR = '
f'{DimSpec(mesh_dim0)}b{batch_dim}R{DimSpec(mesh_dim1)} x '
f'{DimSpec(mesh_dim0)}b{batch_dim}{DimSpec(mesh_dim1)}R_AR{mesh_dim1}'
)
in0_data = self.op_data['input0']
in1_data = self.op_data['input1']
in0_parition_dict = {batch_dim: mesh_dim0}
in1_parition_dict = {batch_dim: mesh_dim0}
in0_contract_dim = -2 if self.op0_transpose else -1
in1_contract_dim = -1 if self.op1_transpose else -2
in0_parition_dict[in0_contract_dim] = mesh_dim1
in1_parition_dict[in1_contract_dim] = mesh_dim1
in0_parition_dict = self._recover_bcast_partition_dict(
in0_parition_dict, in0_data)
in1_parition_dict = self._recover_bcast_partition_dict(
in1_parition_dict, in1_data)
out_partition_dict = {batch_dim: mesh_dim0}
dim_partition_dict_mapping = {
"input0": in0_parition_dict,
"input1": in1_parition_dict,
"output0": out_partition_dict,
}
mm_out_sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if len(mm_out_sharding_spec_mapping) == 0:
return []
strategies = []
for rs_dim in range(self.num_out_dims):
if rs_dim != batch_dim:
name_in0, name_in1, name_out0 = ['R'] * self.num_out_dims, [
'R'
] * self.num_out_dims, ['R'] * self.num_out_dims
name_in0[batch_dim], name_in0[-1] = str(
DimSpec(mesh_dim0)), str(DimSpec(mesh_dim1))
name_in1[batch_dim], name_in1[-2] = str(
DimSpec(mesh_dim0)), str(DimSpec(mesh_dim1))
                name_out0[batch_dim], name_out0[rs_dim] = str(
                    DimSpec(mesh_dim0)), str(DimSpec(mesh_dim1))
name_in0, name_in1, name_out0 = ', '.join(name_in0), ', '.join(
name_in1), ', '.join(name_out0)
name = f'[{name_out0}] = [{name_in0}] x [{name_in1}]_reducescatter{(rs_dim, DimSpec(mesh_dim1))}'
ret = self._split_both_contract_rs(
name, rs_dim, mesh_dim1,
mm_out_sharding_spec_mapping['output0'],
dim_partition_dict_mapping, device_mesh)
if ret:
strategies.append(ret)
return strategies
def _dp_strategies(self, device_mesh):
strategies = []
# S0R = S0R x RR
strategies.append(self._split_lhs_space_only([0], device_mesh))
# S1R = S1R x RR
strategies.append(self._split_lhs_space_only([1], device_mesh))
# S01R = S01R x RR
strategies.append(self._split_lhs_space_only([0, 1], device_mesh))
return strategies
def _tp_strategies(self, device_mesh: LogicalDeviceMesh):
strategies = []
# RR = RS x SR _ AR
strategies.append(self._recompute_split_both_contract([0], device_mesh))
strategies.append(self._recompute_split_both_contract([1], device_mesh))
strategies.append(
self._recompute_split_both_contract([0, 1], device_mesh))
if device_mesh.config.enable_reduce_scatter:
# RS x SR _ reduce scatter
strategies.extend(
self._recompute_split_both_contract_rs([0], device_mesh))
strategies.extend(
self._recompute_split_both_contract_rs([1], device_mesh))
strategies.extend(
self._recompute_split_both_contract_rs([0, 1], device_mesh))
# RS = RR x RS
strategies.append(self._split_rhs_space_only([0], device_mesh))
strategies.append(self._split_rhs_space_only([1], device_mesh))
strategies.append(self._split_rhs_space_only([0, 1], device_mesh))
# RS = RS x SS _ AR
strategies.append(
self._split_rhs_space_both_contract([0], [1], device_mesh))
strategies.append(
self._split_rhs_space_both_contract([1], [0], device_mesh))
if device_mesh.config.enable_reduce_scatter:
# RS x SS _ reduce scatter
strategies.extend(
self._split_rhs_space_both_contract_rs([0], [1], device_mesh))
strategies.extend(
self._split_rhs_space_both_contract_rs([1], [0], device_mesh))
return strategies
def _mix_strategies(self, device_mesh):
strategies = []
# SR = SS x SR_AR
strategies.append(
self._split_lhs_space_both_contract([0], [1], device_mesh))
strategies.append(
self._split_lhs_space_both_contract([1], [0], device_mesh))
if device_mesh.config.enable_reduce_scatter:
# RS x SS _ reduce scatter
strategies.extend(
self._split_lhs_space_both_contract_rs([0], [1], device_mesh))
strategies.extend(
self._split_lhs_space_both_contract_rs([1], [0], device_mesh))
# SS = SR x RS
strategies.append(self._split_lhs_space_rhs_space([0], [1],
device_mesh))
        strategies.append(self._split_lhs_space_rhs_space([1], [0],
                                                          device_mesh))
# RR = RR x RR
strategies.append(self._non_split(device_mesh))
return strategies
def _bmm_strategies(self, device_mesh: LogicalDeviceMesh):
strategies = []
bmm_dim = len(self.op_data['output0'].shape)
if bmm_dim >= 3:
for batch_dim in range(0, bmm_dim - 2):
strategies.append(
self._split_one_batch_dim(batch_dim, [0], device_mesh))
strategies.append(
self._split_one_batch_dim(batch_dim, [1], device_mesh))
strategies.append(
self._split_one_batch_dim(batch_dim, [0, 1], device_mesh))
strategies.append(
self._split_batch_dim_lhs_space(batch_dim, [0], [1],
device_mesh))
strategies.append(
self._split_batch_dim_lhs_space(batch_dim, [1], [0],
device_mesh))
strategies.append(
self._split_batch_dim_rhs_space(batch_dim, [0], [1],
device_mesh))
strategies.append(
self._split_batch_dim_rhs_space(batch_dim, [1], [0],
device_mesh))
strategies.append(
self._split_batch_dim_both_contract(batch_dim, [0], [1],
device_mesh))
strategies.append(
self._split_batch_dim_both_contract(batch_dim, [1], [0],
device_mesh))
if device_mesh.config.enable_reduce_scatter:
strategies.extend(
self._split_batch_dim_both_contract_rs(
batch_dim, [0], [1], device_mesh))
strategies.extend(
self._split_batch_dim_both_contract_rs(
batch_dim, [1], [0], device_mesh))
if bmm_dim >= 4:
for batch_dim0 in range(0, bmm_dim - 2):
for batch_dim1 in range(0, bmm_dim - 2):
if batch_dim0 != batch_dim1:
strategies.append(
self._split_two_batch_dims(
batch_dim0, batch_dim1, [0], [1],
device_mesh))
return strategies
def _collect_strategies(self, device_mesh):
strategies_vector = StrategiesVector(self)
dp_strategies = self._dp_strategies(device_mesh)
tp_strategies = self._tp_strategies(device_mesh)
mix_strategies = self._mix_strategies(device_mesh)
bmm_strategies = self._bmm_strategies(device_mesh)
strategies_vector.extend(dp_strategies)
strategies_vector.extend(tp_strategies)
strategies_vector.extend(mix_strategies)
strategies_vector.extend(bmm_strategies)
return strategies_vector
def is_fp16(self):
builder_flags = get_builder_flags()
return builder_flags & (1 << int(trt.BuilderFlag.FP16)) != 0
def _get_math_time(self, strategy, device_mesh):
shape_in0 = strategy.sharding_specs[
'input0'].get_sharded_shape_per_device()
shape_out = strategy.sharding_specs[
'output0'].get_sharded_shape_per_device()
m, n = shape_out[-2], shape_out[-1]
batches = shape_out[:-2]
k = shape_in0[-2] if self.op0_transpose else shape_in0[-1]
macs_shape = batches + [m, n, k]
macs = reduce(operator.mul, macs_shape, 1) * 2
config = device_mesh.config
cluster_info = device_mesh.cluster_info
dtype = self.dtype
        # For fp16 matmul ops with use_fp32_acc=True:
        # they are mistaken for fp32 ops since all of their IO tensors use the fp32 dtype.
if self.is_fp16() and self.dtype == "float32":
dtype = "float16"
math_throughput_tflops = getattr(cluster_info.math_throughput, dtype)
assert math_throughput_tflops != 0, \
"Undefined {} math throughput of cluster {}".format(dtype, config.cluster_key)
math_time = macs / math_throughput_tflops * 1e-6 * cluster_info.math_efficiency
return math_time
def _update_memory_cost(self, strategies):
super()._update_memory_cost(strategies)
        # For fp16 matmul ops with use_fp32_acc=True:
        # their memory footprints are calculated from fp32 IO tensors,
        # but they will actually use fp16 IO tensors after fusion,
        # so we divide all the memory footprints by 2.
if self.is_fp16() and self.dtype == "float32":
for strategy in strategies:
strategy.inout_memory_footprint /= 2
strategy.peak_memory_footprint /= 2
strategy.comm_buff_memory_footprint /= 2
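
_get_math_time above estimates matmul latency from the MAC count, macs = 2 * prod(batch dims) * m * n * k, scaled by the cluster's peak math throughput and efficiency factor. Below is a small sketch of that arithmetic with assumed numbers (the real throughput/efficiency values come from the cluster_info configuration removed elsewhere in this commit); the result only needs to be consistent across strategies, since the solver compares them relative to each other.

import operator
from functools import reduce

def matmul_math_time(out_shape, k, throughput_tflops, efficiency):
    """Mirror of the formula in _get_math_time, fed with illustrative inputs."""
    m, n = out_shape[-2], out_shape[-1]
    batches = list(out_shape[:-2])
    macs = reduce(operator.mul, batches + [m, n, k], 1) * 2
    return macs / throughput_tflops * 1e-6 * efficiency

# Assumed numbers: an [8, 1024, 4096] output with k = 4096 on a GPU with
# ~300 TFLOPS of fp16 throughput and a 0.8 math-efficiency factor.
print(matmul_math_time([8, 1024, 4096], 4096, 300.0, 0.8))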

View File

@ -1,376 +0,0 @@
from abc import ABC
from ..config import CostModel
from ..device_mesh import LogicalDeviceMesh
from .comm_spec import CommSpec
from .sharding_spec import ShardingSpec
from .sharding_strategy import ShardingStrategy, StrategiesVector
class Node(ABC):
def __init__(self, layer):
self._layer = layer
self.is_shape_io = self._layer.is_shape_io
self._inputs = []
self._outputs = []
self.predecessor_nodes = []
self.predecessor_nodes_out_index = {}
self.successor_nodes = []
self.op_data = {}
self.global_to_local_op_name = {}
self.num_inputs = 0
self.is_replicated = layer.attrs.get("is_replicated", False)
self.same_spec_id = layer.attrs.get("same_spec_id", -1)
self.is_fake = self.same_spec_id != -1
self.building_block_id = layer.attrs.get("building_block_id", -1)
self.cost_level = -1
self.stage_type = layer.attrs.get("stage_type", None)
self.in_start_block = layer.attrs.get("in_start_block", False)
self.in_end_block = layer.attrs.get("in_end_block", False)
self.in_slowest_block = layer.attrs.get("in_slowest_block", False)
for i, input in enumerate(layer.inputs):
if input is None:
self._inputs.append(None)
self.op_data[f'input{i}'] = None
continue
input = input.copy()
input.attrs["broadcast_dims"] = []
self._inputs.append(input)
self.op_data[f'input{i}'] = input
self.global_to_local_op_name[input.name] = f'input{i}'
for i, output in enumerate(layer.outputs):
output = output.copy()
output.attrs["broadcast_dims"] = []
self._outputs.append(output)
self.op_data[f'output{i}'] = output
self.global_to_local_op_name[output.name] = f'output{i}'
self.sharding_weight = 1.0
self.resharding_weight = 1.0
self.pipeline_weight = 0
self.node_name = layer.name
self.node_type = 'normal_node'
self.num_inputs = layer.num_inputs
self.num_outputs = layer.num_outputs
self.dtype = layer.as_trt().precision
self.strategies_vector = []
self.node_runtime_profiler = None
def post_init(self, graph):
for input in self.inputs:
if input is None:
self.predecessor_nodes.append(None)
continue
if input.producer is None:
predecessor_node = graph.get_node(input.name)
self.predecessor_nodes.append(predecessor_node)
self.predecessor_nodes_out_index[predecessor_node] = 0
predecessor_node.successor_nodes.append(self)
else:
predecessor_node = graph.get_node(input.producer.name)
self.predecessor_nodes.append(predecessor_node)
self.predecessor_nodes_out_index[
predecessor_node] = input.output_index
predecessor_node.successor_nodes.append(self)
@property
def layer(self):
return self._layer
def get_input(self, index):
return self._inputs[index]
@property
def inputs(self):
return self._inputs
def get_output(self, index):
return self._outputs[index]
@property
def outputs(self):
return self._outputs
def collect_strategies(self, device_mesh):
strategies_vector = self._collect_strategies(device_mesh)
strategies_vector = self._post_process(strategies_vector)
self._update_sharding_cost(strategies_vector, device_mesh)
self.strategies_vector = strategies_vector
return self.strategies_vector
def _set_strategy(self, strategy, device_mesh):
strategies_vector = StrategiesVector(self)
if strategy is None:
dim_partition_dict_mapping = {}
for i in range(self.num_inputs):
dim_partition_dict_mapping[f'input{i}'] = {}
for i in range(self.num_outputs):
dim_partition_dict_mapping[f'output{i}'] = {}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
assert 0 != len(
sharding_spec_mapping
), f'failed to set default(all Replicate) strategy for node {self.node_name}'
name = 'RRs'
sharding_strategy = self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
else:
sharding_specs_map = strategy.sharding_specs
comm_specs_map = strategy.communication_actions
dim_partition_dict_mapping = {}
for op_name, sharding_spec in sharding_specs_map.items():
dim_partition_dict_mapping[
op_name] = sharding_spec.dim_partition_dict
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
assert 0 != len(
sharding_spec_mapping
), f'failed to set strategy for node {self.node_name}'
comm_specs_mapping = {}
if len(comm_specs_map) > 0:
for op_name, comm_spec in comm_specs_map.items():
comm_specs_mapping[op_name] = CommSpec(
comm_pattern=comm_spec.comm_pattern,
sharding_spec=sharding_spec_mapping[op_name],
logical_process_axis=comm_spec.logical_process_axis,
)
strategies_vector.append(
self._get_sharding_strategy(
name=strategy.name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=comm_specs_mapping))
return strategies_vector
def set_strategy(self, strategy, device_mesh):
strategies_vector = self._set_strategy(strategy, device_mesh)
strategies_vector = self._post_process(strategies_vector)
self._update_sharding_cost(strategies_vector, device_mesh)
self.strategies_vector = strategies_vector
return self.strategies_vector
def update_resharding_cost(self):
self._update_resharding_cost(self.strategies_vector)
return self.strategies_vector
def _to_sharding_spec_mapping(self, dim_partition_dict_mapping,
device_mesh):
results = {}
for op_data_name, dim_partition_dict in dim_partition_dict_mapping.items(
):
if op_data_name in self.op_data:
op_data = self.op_data[op_data_name]
def _to_sharding_spec(op_data, dim_partition_dict):
sharding_spec = ShardingSpec(
device_mesh,
op_data.dtype_str_size, [*op_data.shape],
[*op_data.max_shape], [*op_data.raw_shape],
dim_partition_dict=dim_partition_dict)
if sharding_spec.sanity_check():
return sharding_spec
else:
return None
sharding_spec = _to_sharding_spec(op_data, dim_partition_dict)
if sharding_spec:
results[op_data_name] = sharding_spec
else:
return {}
return results
def _get_sharding_strategy(self, name, sharding_spec_mapping,
communication_action_mapping):
return ShardingStrategy(
name=name,
sharding_specs=sharding_spec_mapping,
communication_actions=communication_action_mapping,
)
def _remove_duplicated_strategy(self, strategies_vector):
name_checklist = []
remove_list = []
for strategy in strategies_vector:
if strategy.name not in name_checklist:
name_checklist.append(strategy.name)
else:
remove_list.append(strategy)
for strategy in remove_list:
strategies_vector.remove(strategy)
def _post_process(self, strategies_vector):
# TODO: deal with the transpose and size-1 dimension problems from ColossalAI, which have already been processed before this point
for i in range(len(strategies_vector) - 1, -1, -1):
if strategies_vector[i] is None:
strategies_vector.pop(i)
self._remove_duplicated_strategy(strategies_vector)
return strategies_vector
def _profile_sharding_cost(self, strategy, device_mesh: LogicalDeviceMesh):
elapsed_time = self.node_runtime_profiler.runtime_profile(
self.layer, {}, {}, strategy, device_mesh)
return elapsed_time
def _model_sharding_cost_from_s_curve(self, strategy,
device_mesh: LogicalDeviceMesh):
'''
[ToDo] preprofile the s_curve
'''
sharding_cost = 0.0
return sharding_cost
# this method might be overwritten by some Ops
def _get_math_time(self, strategy, device_mesh: LogicalDeviceMesh):
return 0.0
# this method might be overwritten by some Ops
def _get_memory_time(self, strategy, device_mesh: LogicalDeviceMesh):
memory_time = (strategy.inout_memory_footprint /
device_mesh.cluster_info.memory_bw * 1e-3 *
device_mesh.cluster_info.memory_efficiency)
return memory_time
def _model_sharding_cost_from_alpha_beta(self, strategy,
device_mesh: LogicalDeviceMesh):
math_time = self._get_math_time(strategy, device_mesh)
mem_time = self._get_memory_time(strategy, device_mesh)
return max(math_time, mem_time)
def _get_communication_cost(self, strategy):
total_comm_cost = 0.0
for op_data_name, comm_spec in strategy.communication_actions.items():
comm_cost = comm_spec.get_comm_cost()
total_comm_cost = total_comm_cost + comm_cost
return total_comm_cost
def _update_sharding_cost(self, strategies, device_mesh: LogicalDeviceMesh):
self._update_memory_cost(strategies)
if device_mesh.config.sharding_cost_model == CostModel.ALPHA_BETA:
for strategy in strategies:
strategy.sharding_cost = self._model_sharding_cost_from_alpha_beta(
strategy, device_mesh)
elif device_mesh.config.sharding_cost_model == CostModel.S_CURVE:
for strategy in strategies:
strategy.sharding_cost = self._model_sharding_cost_from_s_curve(
strategy, device_mesh)
elif device_mesh.config.sharding_cost_model == CostModel.PROFILE:
for strategy in strategies:
strategy.alpha_beta_cost = self._model_sharding_cost_from_alpha_beta(
strategy, device_mesh)
if self.is_shape_io:
strategy.sharding_cost = strategy.alpha_beta_cost
else:
strategy.sharding_cost = self._profile_sharding_cost(
strategy, device_mesh)
elif device_mesh.config.sharding_cost_model == CostModel.ZERO:
for strategy in strategies:
strategy.sharding_cost = 0.0
else:
assert False, 'unsupported sharding cost model option: {}'.format(
device_mesh.config.sharding_cost_model)
for strategy in strategies:
strategy.communication_cost = self._get_communication_cost(strategy)
def _compute_resharding_cost(self, pre_sharding_spec, cur_sharding_spec,
op_data):
transform_path, comm_action_sequence, resharding_cost = cur_sharding_spec.device_mesh.shape_consistency_manager.shape_consistency(
pre_sharding_spec, cur_sharding_spec)
return (transform_path, comm_action_sequence, resharding_cost)
def _update_resharding_cost(self, strategies):
for strategy in strategies:
resharding_costs = {}
for pre_node, out_index in self.predecessor_nodes_out_index.items():
if pre_node is None:
continue
pre_node_out_data_name = pre_node.get_output(out_index).name
pre_node_out_data_lname = pre_node.global_to_local_op_name[
pre_node_out_data_name]
if pre_node_out_data_name not in self.global_to_local_op_name:
print(f"pre_node_out_data_name = {pre_node_out_data_name}")
continue
cur_node_inp_data_lname = self.global_to_local_op_name[
pre_node_out_data_name]
cur_sharding_spec = strategy.sharding_specs[
cur_node_inp_data_lname]
pre_node_out_sharding_specs = []
for pre_strategy in pre_node.strategies_vector:
pre_node_out_sharding_specs.append(
pre_strategy.sharding_specs[pre_node_out_data_lname])
if pre_node not in resharding_costs:
resharding_costs[pre_node.node_name] = []
for prev_sharding_spec in pre_node_out_sharding_specs:
resharding_cost = self._compute_resharding_cost(
prev_sharding_spec, cur_sharding_spec,
self.op_data[cur_node_inp_data_lname])
resharding_costs[pre_node.node_name].append(resharding_cost)
strategy.resharding_costs = resharding_costs
def _enumerate_all_possible_1d_sharding(self, mesh_dim, dim_size):
dim_partition_list = []
for i in range(dim_size):
dim_partition_list.append({i: mesh_dim})
return dim_partition_list
def _enumerate_all_possible_2d_sharding(self, mesh_dim0, mesh_dim1,
dim_size):
dim_partition_list = []
for i in range(dim_size):
for j in range(dim_size):
if i != j:
dim_partition_list.append({i: mesh_dim0, j: mesh_dim1})
return dim_partition_list
def _update_memory_cost(self, strategies):
for strategy in strategies:
inout_memory_footprint, max_inout_memory_footprint = 0.0, 0.0
for spec in strategy.sharding_specs.values():
inout_memory_footprint += spec.get_sharded_size_per_device()
max_inout_memory_footprint += spec.get_max_sharded_size_per_device(
)
# communication buffers allocated while the communication happens
comm_buffer_footprint, max_comm_buffer_footprint = 0.0, 0.0
for comm_spec in strategy.communication_actions.values():
comm_buffer_footprint += comm_spec.get_mem_cost()
max_comm_buffer_footprint += comm_spec.get_max_mem_cost()
# When performing the output0 comm action, the input buffer should already have been released;
# this footprint is used to estimate the memory time rather than the memory usage.
strategy.inout_memory_footprint = inout_memory_footprint
strategy.comm_buff_memory_footprint = comm_buffer_footprint
strategy.peak_memory_footprint = max(max_inout_memory_footprint,
max_comm_buffer_footprint)
# The const memory (weight) is recorded in constant layers and should be accumulated
strategy.const_memory_footprint = 0.0
def _generate_bcast_dims(self, batch_dims, out_data_shape):
for output in self.outputs:
if output.broadcast_across_batch:
for bs in batch_dims:
if output.shape[bs] == 1 and output.shape[
bs] != out_data_shape[bs]:
output.attrs["broadcast_dims"].append(bs)
def _recover_bcast_partition_dict(self, partition_dict, op_data):
ret = {}
for data_dim, mesh_dim in partition_dict.items():
if data_dim not in op_data.attrs[
"broadcast_dims"] and data_dim + len(
op_data.shape) not in op_data.attrs[
"broadcast_dims"] and op_data.shape[data_dim] != 1:
ret[data_dim] = mesh_dim
return ret
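The candidate dim-partition dictionaries produced by the enumeration helpers above are easy to reproduce standalone; a sketch mirroring what _enumerate_all_possible_1d_sharding and _enumerate_all_possible_2d_sharding return for an illustrative rank-3 tensor on a 2-D logical mesh:
dim_size = 3                                   # rank of the tensor
one_d = [{i: [0]} for i in range(dim_size)]    # 1-D sharding across mesh axis 0
two_d = [{i: [0], j: [1]} for i in range(dim_size) for j in range(dim_size) if i != j]
print(one_d)    # [{0: [0]}, {1: [0]}, {2: [0]}]
print(two_d)    # [{0: [0], 1: [1]}, {0: [0], 2: [1]}, {1: [0], 0: [1]}, ...]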

View File

@ -1,60 +0,0 @@
from .node import Node
from .sharding_strategy import StrategiesVector
class Normalization(Node):
def __init__(self, layer):
super().__init__(layer)
layer.to_subclass()
self.axes = layer.as_trt().axes
self.weight_bias_dim_base = 0
layer.to_base_class()
def _collect_strategies(self, device_mesh):
dim_partition_list = []
dim_size = len(self.op_data['input0'].shape)
dim_partition_list.append({})
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
strategies_vector = StrategiesVector(self)
for dim_partition_dict in dim_partition_list:
shard_reduction_axes = False
for dim in range(len(self.get_input(0).shape)):
if (self.axes & (1 << dim)) and dim in dim_partition_dict:
shard_reduction_axes = True
break
if shard_reduction_axes:
continue
dim_partition_dict_mapping = {
"input0": dim_partition_dict,
"output0": dim_partition_dict,
}
if self.num_inputs >= 2:
dim_partition_dict_mapping['input1'] = {}
if self.num_inputs >= 3:
dim_partition_dict_mapping['input2'] = {}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
continue
name = '{} = {} <normalization op> scale {}, bias {}'.format(
sharding_spec_mapping['output0'].sharding_sequence,
sharding_spec_mapping['input0'].sharding_sequence,
sharding_spec_mapping['input1'].sharding_sequence
if self.num_inputs >= 2 else 'None',
sharding_spec_mapping['input2'].sharding_sequence
if self.num_inputs >= 3 else 'None',
)
sharding_strategy = self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
return strategies_vector
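A small sketch of the reduction-axis check used above, assuming a rank-3 input normalized over its last dimension and a single candidate partition (illustrative values only):
axes = 1 << 2                      # TensorRT-style axes bitmask: only dim 2 is normalized
dim_partition_dict = {0: [0]}      # candidate: shard dim 0 across mesh axis 0
shard_reduction_axes = any((axes & (1 << dim)) and dim in dim_partition_dict for dim in range(3))
print(shard_reduction_axes)        # False, so this candidate is kept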

View File

@ -1,79 +0,0 @@
from .node import Node
from .sharding_strategy import StrategiesVector
class OuputNode(Node):
def _update_memory_cost(self, strategies):
for strategy in strategies:
if not self.no_memory_footprint:
strategy.const_memory_footprint = strategy.sharding_specs[
'input0'].get_max_sharded_size_per_device()
def __init__(self, tensor):
self._layer = None
self.is_shape_io = False
self._inputs = []
self._outputs = []
self.predecessor_nodes = []
self.predecessor_nodes_out_index = {}
self.successor_nodes = []
self.op_data = {}
self.global_to_local_op_name = {}
self.is_replicated = tensor.attrs.get("is_replicated", False)
self.same_spec_id = tensor.attrs.get("same_spec_id", -1)
self.no_memory_footprint = tensor.attrs.get("no_memory_footprint",
False)
self.building_block_id = -1
self.cost_level = -1
self.stage_type = None
self.in_start_block = None
self.in_end_block = None
self.in_slowest_block = None
input = tensor.copy()
self._inputs.append(input)
self.op_data['input0'] = input
self.global_to_local_op_name[input.name] = 'input0'
self.sharding_weight = 1.0
self.resharding_weight = 1.0
self.pipeline_weight = 0
self.node_name = tensor.name
self.node_type = 'output_node'
self.num_inputs = 0
self.num_outputs = 1
self.dtype = tensor.dtype
self.strategies_vector = []
self.node_runtime_profiler = None
def _collect_strategies(self, device_mesh):
dim_partition_list = []
dim_size = len(self.op_data['input0'].shape)
dim_partition_list.append({})
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
strategies_vector = StrategiesVector(self)
for dim_partition_dict in dim_partition_list:
dim_partition_dict_mapping = {'input0': dim_partition_dict}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
continue
sharding_seq = sharding_spec_mapping['input0'].sharding_sequence
sharding_strategy = self._get_sharding_strategy(
name=f'output-op {sharding_seq}',
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
return strategies_vector
def _profile_sharding_cost(self, strategy, device_mesh):
return 0.0

View File

@ -1,67 +0,0 @@
import copy
from enum import Enum
from .comm_spec import CommSpec
from .identity_node import Identity
from .sharding_strategy import StrategiesVector
class P2PType(Enum):
CROSS_DEVICE = 0
CROSS_HOST = 1
class P2PNode(Identity):
def __init__(self, layer):
super().__init__(layer)
self.p2p_type = layer.attrs["p2p_type"]
self.is_fake = True
def _collect_strategies(self, device_mesh):
# the P2P node has a single input
predecessor = self.predecessor_nodes[0]
strategies_vector = StrategiesVector(self)
for idx, strategy in enumerate(predecessor.strategies_vector):
# current node's local name input0 -> global name xxx
global_input_name = self.op_data['input0'].name
# global name xxx -> pre node local output name
prenode_local_name = predecessor.global_to_local_op_name[
global_input_name]
dim_partition_dict = copy.deepcopy(
strategy.sharding_specs[prenode_local_name].dim_partition_dict)
in0_partition_dict = dim_partition_dict
out_partition_dict = copy.deepcopy(dim_partition_dict)
dim_partition_dict_mapping = {
"input0": in0_partition_dict,
"output0": out_partition_dict,
}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
continue
logical_process_axis = [
['p2p_cross_device']
] if self.p2p_type == P2PType.CROSS_DEVICE else [['p2p_cross_host']]
# get communication action mapping
communication_action_mapping = {}
output0_comm_action = CommSpec(
comm_pattern='peer_to_peer',
sharding_spec=sharding_spec_mapping['output0'],
logical_process_axis=logical_process_axis,
)
communication_action_mapping['output0'] = output0_comm_action
name = '{} = <P2P op> {}'.format(
sharding_spec_mapping['output0'].sharding_sequence,
sharding_spec_mapping['input0'].sharding_sequence)
sharding_strategy = self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategies_vector.append(sharding_strategy)
return strategies_vector
def _profile_sharding_cost(self, strategy, device_mesh):
return 0.0

View File

@ -1,40 +0,0 @@
from tensorrt_llm.network import PluginInfo, get_plugin_info
from .node import Node
from .sharding_strategy import StrategiesVector
class PluginNode(Node):
def __init__(self, layer):
super().__init__(layer)
layer.to_subclass()
self.plugin = layer.as_trt().plugin
self.plugin_type: str = self.plugin.plugin_type
self.plugin_info: PluginInfo = get_plugin_info(layer.graph.as_trt(),
layer.name)
layer.to_base_class()
def _collect_strategies(self, device_mesh):
raise NotImplementedError(
f"Auto parallel does not support {self.plugin_type} plugin right now."
)
def _default_strategy(self, device_mesh):
strategies_vector = StrategiesVector(self)
dim_partition_dict_mapping = {}
for idx in range(self.num_inputs):
dim_partition_dict_mapping[f'input{idx}'] = {}
for idx in range(self.num_outputs):
dim_partition_dict_mapping[f'output{idx}'] = {}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
return strategies_vector
name = '{}_all_replicate'.format(self.plugin_type)
sharding_strategy = self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
return strategies_vector

View File

@ -1,27 +0,0 @@
import tensorrt as trt
from tensorrt_llm._utils import trt_dtype_to_str
from ..matmul_node import MatrixMultiply
from ..plugin_node import PluginNode
class GemmPlugin(MatrixMultiply, PluginNode):
def __init__(self, layer):
PluginNode.__init__(self, layer)
batch_dims = [i for i in range(len(self.get_output(0).shape))][:-2]
self._generate_bcast_dims(batch_dims, self.get_output(0).shape)
pfc_as_list = self.plugin_info.pfc_as_list
self.op0_transpose = (pfc_as_list['transa'][0] == 1)
self.op1_transpose = (pfc_as_list['transb'][0] == 1)
self.num_out_dims = len(self.get_output(0).shape)
self.dtype = trt_dtype_to_str(trt.DataType(pfc_as_list['type_id'][0]))
def _collect_strategies(self, device_mesh):
strategies_vector = MatrixMultiply._collect_strategies(
self, device_mesh)
return strategies_vector
def _get_math_time(self, strategy, device_mesh):
return MatrixMultiply._get_math_time(self, strategy, device_mesh)
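A sketch of how the transpose flags are read from the plugin fields above, with hypothetical pfc_as_list contents (not taken from a real engine):
pfc_as_list = {'transa': [0], 'transb': [1]}       # hypothetical plugin field values
op0_transpose = (pfc_as_list['transa'][0] == 1)    # False: operand A is not transposed
op1_transpose = (pfc_as_list['transb'][0] == 1)    # True: operand B is transposed
print(op0_transpose, op1_transpose)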

View File

@ -1,437 +0,0 @@
from enum import Enum, auto
import numpy as np
import torch
from tensorrt_llm.functional import AttentionMaskType, PositionEmbeddingType
from tensorrt_llm.quantization import QuantMode
from ..plugin_node import PluginNode
from ..sharding_strategy import StrategiesVector
# WARNING: Must be kept in sync with IdxEntry in cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h
class IdxEntry(Enum):
QKV_TENSOR = auto()
K_TENSOR = auto()
V_TENSOR = auto()
CONTEXT_FMHA_CUSTOM_MASK = auto()
SEQUENCE_LENGTH = auto()
HOST_PAST_KEY_VALUE_LENGTHS = auto()
HOST_MAX_ATTENTION_WINDOW = auto()
HOST_SINK_TOKEN_LENGTH = auto()
CONTEXT_LENGTHS = auto()
CACHE_INDIR = auto()
REQUEST_TYPES = auto()
KV_CACHE_BLOCK_OFFSETS = auto()
HOST_KV_CACHE_BLOCK_OFFSETS = auto()
HOST_KV_CACHE_POOL_POINTERS = auto()
HOST_KV_CACHE_POOL_MAPPING = auto()
PAST_KEY_VALUE = auto()
KV_CACHE_QUANTIZATION_SCALE = auto()
KV_CACHE_DEQUANTIZATION_SCALE = auto()
ATTENTION_OUTPUT_QUANTIZATION_SCALE = auto()
ATTENTION_OUTPUT_SF_SCALE = auto()
ROTARY_INV_FREQ = auto()
ROTARY_COS_SIN = auto()
ALIBI_SLOPES = auto()
RELATIVE_ATTENTION_BIAS = auto()
CROSS_KV = auto()
CROSS_KV_LENGTH = auto()
ENCODER_INPUT_LENGTH = auto()
HOST_CONTEXT_LENGTH = auto()
QKV_BIAS_TENSOR = auto()
SPEC_DECODING_GENERATION_LENGTHS = auto()
SPEC_DECODING_PACKED_MASK = auto()
SPEC_DECODING_POSITION_OFFSETS = auto()
MROPE_ROTARY_COS_SIN = auto()
MROPE_POSITION_DELTAS = auto()
HOST_RUNTIME_PERF_KNOBS = auto()
HOST_CONTEXT_PROGRESS = auto()
MLA_FUSED_Q_PROJ_TENSOR = auto()
MLA_Q_B_PROJ_TENSOR = auto()
MLA_KV_B_PROJ_TENSOR = auto()
LOGN_SCALING = auto()
class IdxEntryParser:
def __init__(self, plugin_info):
self.num_kv_heads = plugin_info.pfc_as_list['num_kv_heads'][0]
self.unfuse_qkv_gemm = bool(
plugin_info.pfc_as_list['unfuse_qkv_gemm'][0])
self.use_fp8_context_fmha = bool(
plugin_info.pfc_as_list['use_fp8_context_fmha'][0])
self.fuse_fp4_quant = bool(plugin_info.pfc_as_list['fuse_fp4_quant'][0])
self.mask_type = AttentionMaskType(
plugin_info.pfc_as_list['mask_type'][0])
self.use_cache = bool(plugin_info.pfc_as_list['use_cache'][0])
self.paged_kv_cache = bool(plugin_info.pfc_as_list['paged_kv_cache'][0])
self.do_cross_attention = bool(
plugin_info.pfc_as_list['do_cross_attention'][0])
self.remove_input_padding = bool(
plugin_info.pfc_as_list['remove_input_padding'][0])
self.qkv_bias_enabled = bool(
plugin_info.pfc_as_list['qkv_bias_enabled'][0])
self.kv_cache_quant_mode = QuantMode(
plugin_info.pfc_as_list['kv_cache_quant_mode'][0])
self.position_embedding_type = PositionEmbeddingType(
plugin_info.pfc_as_list['position_embedding_type'][0])
self.is_spec_decoding_enabled = bool(
plugin_info.pfc_as_list['is_spec_decoding_enabled'][0])
self.is_mla_enabled = bool(plugin_info.pfc_as_list['is_mla_enabled'][0])
self.use_logn_scaling = bool(
plugin_info.pfc_as_list['use_logn_scaling'][0])
self.init_entry_to_index()
# WARNING: Must be kept in sync with GPTAttentionPlugin::isEntryUsed in cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp
def is_entry_used(self, entry: IdxEntry) -> bool:
if entry == IdxEntry.QKV_TENSOR:
return True
elif entry == IdxEntry.K_TENSOR:
return self.unfuse_qkv_gemm
elif entry == IdxEntry.V_TENSOR:
return self.unfuse_qkv_gemm
elif entry == IdxEntry.CONTEXT_FMHA_CUSTOM_MASK:
return self.mask_type == AttentionMaskType.custom_mask
elif entry == IdxEntry.SEQUENCE_LENGTH:
return self.use_cache
elif entry == IdxEntry.HOST_PAST_KEY_VALUE_LENGTHS:
return self.use_cache
elif entry == IdxEntry.HOST_MAX_ATTENTION_WINDOW:
return True
elif entry == IdxEntry.HOST_SINK_TOKEN_LENGTH:
return True
elif entry == IdxEntry.CONTEXT_LENGTHS:
return True
elif entry == IdxEntry.CACHE_INDIR:
return self.use_cache
elif entry == IdxEntry.REQUEST_TYPES:
return True
elif entry == IdxEntry.KV_CACHE_BLOCK_OFFSETS:
return self.use_cache and self.paged_kv_cache
elif entry == IdxEntry.HOST_KV_CACHE_BLOCK_OFFSETS:
return self.use_cache and self.paged_kv_cache
elif entry == IdxEntry.HOST_KV_CACHE_POOL_POINTERS:
return self.use_cache and self.paged_kv_cache
elif entry == IdxEntry.HOST_KV_CACHE_POOL_MAPPING:
return self.use_cache and self.paged_kv_cache
elif entry == IdxEntry.PAST_KEY_VALUE:
return self.use_cache and not self.paged_kv_cache
elif entry == IdxEntry.KV_CACHE_QUANTIZATION_SCALE:
return self.use_cache and self.kv_cache_quant_mode.has_kv_cache_quant(
)
elif entry == IdxEntry.KV_CACHE_DEQUANTIZATION_SCALE:
return self.use_cache and self.kv_cache_quant_mode.has_kv_cache_quant(
)
elif entry == IdxEntry.ATTENTION_OUTPUT_QUANTIZATION_SCALE:
return self.use_fp8_context_fmha and self.kv_cache_quant_mode.has_fp8_qdq(
)
elif entry == IdxEntry.ATTENTION_OUTPUT_SF_SCALE:
return self.fuse_fp4_quant
elif entry == IdxEntry.ROTARY_INV_FREQ:
return self.position_embedding_type.is_rope()
elif entry == IdxEntry.ROTARY_COS_SIN:
return self.position_embedding_type.is_rope()
elif entry == IdxEntry.ALIBI_SLOPES:
return self.position_embedding_type.is_alibi()
elif entry == IdxEntry.RELATIVE_ATTENTION_BIAS:
return self.position_embedding_type == PositionEmbeddingType.relative
elif entry == IdxEntry.CROSS_KV:
return self.do_cross_attention
elif entry == IdxEntry.CROSS_KV_LENGTH:
return self.do_cross_attention
elif entry == IdxEntry.ENCODER_INPUT_LENGTH:
return self.do_cross_attention
elif entry == IdxEntry.HOST_CONTEXT_LENGTH:
return self.remove_input_padding
elif entry == IdxEntry.QKV_BIAS_TENSOR:
return self.qkv_bias_enabled
elif entry == IdxEntry.SPEC_DECODING_PACKED_MASK:
return self.is_spec_decoding_enabled
elif entry == IdxEntry.SPEC_DECODING_POSITION_OFFSETS:
return self.is_spec_decoding_enabled
elif entry == IdxEntry.SPEC_DECODING_GENERATION_LENGTHS:
return self.is_spec_decoding_enabled
elif entry == IdxEntry.MROPE_ROTARY_COS_SIN:
return self.position_embedding_type.is_mrope()
elif entry == IdxEntry.MROPE_POSITION_DELTAS:
return self.position_embedding_type.is_mrope()
elif entry == IdxEntry.HOST_RUNTIME_PERF_KNOBS:
return True
elif entry == IdxEntry.HOST_CONTEXT_PROGRESS:
return True
elif entry == IdxEntry.MLA_FUSED_Q_PROJ_TENSOR:
return self.is_mla_enabled
elif entry == IdxEntry.MLA_Q_B_PROJ_TENSOR:
return self.is_mla_enabled
elif entry == IdxEntry.MLA_KV_B_PROJ_TENSOR:
return self.is_mla_enabled
elif entry == IdxEntry.LOGN_SCALING:
return self.use_logn_scaling
else:
return False
def init_entry_to_index(self):
self.entry_to_index = {}
index = 0
for entry in IdxEntry:
if self.is_entry_used(entry):
self.entry_to_index[entry] = index
index += 1
def get_index(self, entry: IdxEntry) -> int:
if entry not in self.entry_to_index:
raise Exception(
f"Entry {entry} is not existed in gpt attention plugin layer {self.layer.name}"
)
return self.entry_to_index[entry]
def get_partition(device_dim, device_ids):
if device_dim == [0]:
partition = device_ids.shape[0]
elif device_dim == [1]:
partition = device_ids.shape[1]
else:
assert device_dim == [0, 1] or device_dim == [1, 0]
partition = device_ids.size
return partition
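A hypothetical usage of the get_partition helper above, assuming a 2 x 4 physical device mesh:
import numpy as np
device_ids = np.arange(8).reshape(2, 4)      # hypothetical phy_ids of the mesh
print(get_partition([0], device_ids))        # 2  (split across mesh axis 0)
print(get_partition([1], device_ids))        # 4  (split across mesh axis 1)
print(get_partition([0, 1], device_ids))     # 8  (split across both axes)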
class GPTAttentionPlugin(PluginNode):
def __init__(self, layer):
super().__init__(layer)
self.parser = IdxEntryParser(self.plugin_info)
assert self.num_inputs == len(
self.parser.entry_to_index
), f'the number of plugin inputs ({self.num_inputs}) is invalid'
assert self.num_outputs == (
2 if self.parser.is_entry_used(IdxEntry.PAST_KEY_VALUE) else 1
), f'the number of plugin outputs ({self.num_outputs}) has been changed'
def _tp_strategy(self, device_mesh):
strategies_vector = StrategiesVector(self)
head_dim = 1 if self.parser.remove_input_padding else 2
# TODO: allow mesh_dim = [0] or [1]
# for mesh_dim in ([0], [1], [0, 1]):
for mesh_dim in ([0, 1], ):
if self.parser.num_kv_heads != 1:
# MHA or GQA
# TODO: allow to duplicate kv when #kv_head < #partition
q_pdict = {
head_dim: mesh_dim
} # split in heads/hidden dimension
k_pdict = {
head_dim: mesh_dim
} # split in heads/hidden dimension
v_pdict = {
head_dim: mesh_dim
} # split in heads/hidden dimension
pastkv_pdict = {2: mesh_dim} # split in heads dimension
present_kv_pdict = {2: mesh_dim} # split in heads dimension
else:
# MQA
q_pdict = {
head_dim: mesh_dim
} # split in heads/hidden dimension
k_pdict = {} # RR
v_pdict = {} # RR
pastkv_pdict = {} # RR
present_kv_pdict = {} # RR
out0_pdict = {head_dim: mesh_dim}
dim_partition_dict_mapping = {
f'input{self.parser.get_index(IdxEntry.QKV_TENSOR)}': q_pdict,
f'input{self.parser.get_index(IdxEntry.K_TENSOR)}': k_pdict,
f'input{self.parser.get_index(IdxEntry.V_TENSOR)}': v_pdict,
'output0': out0_pdict,
}
if self.parser.is_entry_used(IdxEntry.PAST_KEY_VALUE):
dim_partition_dict_mapping[
f'input{self.parser.get_index(IdxEntry.PAST_KEY_VALUE)}'] = pastkv_pdict
dim_partition_dict_mapping['output1'] = present_kv_pdict
for i in range(self.num_inputs):
if f'input{i}' not in dim_partition_dict_mapping:
dim_partition_dict_mapping[f'input{i}'] = {}
for i in range(self.num_outputs):
if f'output{i}' not in dim_partition_dict_mapping:
dim_partition_dict_mapping[f'output{i}'] = {}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
continue
name = 'gptAttentionPlugin_tp_strategy'
sharding_strategy = self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
return strategies_vector
def _dp_strategy(self, device_mesh):
strategies_vector = StrategiesVector(self)
for mesh_dim in ([0], [1], [0, 1]):
dim_partition_dict_mapping = {}
for i in range(self.num_inputs):
dim_partition_dict_mapping[f'input{i}'] = {0: mesh_dim}
for i in range(self.num_outputs):
dim_partition_dict_mapping[f'output{i}'] = {0: mesh_dim}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
continue
name = 'gptAttentionPlugin_dp_strategy'
sharding_strategy = self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
return strategies_vector
def _collect_strategies(self, device_mesh):
if device_mesh.size == 1:
default_strategies = self._default_strategy(device_mesh)
else:
# Avoid using the all-replicate strategy for mesh size > 1,
# since the C++ runtime does not support it for the GPT attention plugin.
default_strategies = StrategiesVector(self)
for idx, strategy in enumerate(default_strategies):
strategy.name = 'gptAttentionPlugin_' + strategy.name + f'{idx}'
if self.parser.unfuse_qkv_gemm:
tp_strategies = self._tp_strategy(device_mesh)
default_strategies.extend(tp_strategies)
# If the batch dim is not split, the default strategies apply;
# if the batch dim is split, the dp strategies apply.
# This is how the two kinds of strategy are distinguished.
if not self.parser.remove_input_padding:
dp_strategies = self._dp_strategy(device_mesh)
default_strategies.extend(dp_strategies)
return default_strategies
@staticmethod
def parameter_generator(sharding_specs, plugin_info):
def get_shape(entry):
return sharding_specs[
f'input{parser.get_index(entry)}'].get_sharded_shape_per_device(
)
parser = IdxEntryParser(plugin_info)
updated_input_values = {}
batch_size = get_shape(IdxEntry.CONTEXT_LENGTHS)[0]
if parser.use_cache:
beams_width = get_shape(IdxEntry.CACHE_INDIR)[1]
max_seq_length = get_shape(IdxEntry.CACHE_INDIR)[2]
elif not parser.remove_input_padding:
max_seq_length = get_shape(IdxEntry.QKV_BIAS_TENSOR)[1]
else:
max_seq_length = 1
host_request_types = torch.full(
(batch_size, ),
1,
dtype=torch.int32,
device='cpu',
)
updated_input_values[parser.get_index(
IdxEntry.REQUEST_TYPES)] = host_request_types
context_lengths = torch.full(
(batch_size, ),
max_seq_length - 1,
dtype=torch.int32,
device=torch.cuda.current_device(),
)
updated_input_values[parser.get_index(
IdxEntry.CONTEXT_LENGTHS)] = context_lengths
host_max_attention_window_sizes = torch.tensor(
[max_seq_length],
dtype=torch.int32,
device='cpu',
)
updated_input_values[parser.get_index(
IdxEntry.HOST_MAX_ATTENTION_WINDOW
)] = host_max_attention_window_sizes
host_sink_token_length = torch.tensor(
[0],
dtype=torch.int32,
device='cpu',
)
updated_input_values[parser.get_index(
IdxEntry.HOST_SINK_TOKEN_LENGTH)] = host_sink_token_length
if parser.use_cache:
sequence_length = torch.full((batch_size, ),
max_seq_length,
dtype=torch.int32,
device=torch.cuda.current_device())
updated_input_values[parser.get_index(
IdxEntry.SEQUENCE_LENGTH)] = sequence_length
host_past_key_value_length = torch.full((batch_size, ),
max_seq_length - 1,
dtype=torch.int32,
device='cpu')
updated_input_values[parser.get_index(
IdxEntry.HOST_PAST_KEY_VALUE_LENGTHS
)] = host_past_key_value_length
cache_indirections = torch.full(
(batch_size, beams_width, max_seq_length),
0,
dtype=torch.int32,
device=torch.cuda.current_device())
updated_input_values[parser.get_index(
IdxEntry.CACHE_INDIR)] = cache_indirections
if parser.remove_input_padding:
host_context_lengths = torch.full(get_shape(
IdxEntry.HOST_CONTEXT_LENGTH),
max_seq_length - 1,
dtype=torch.int32,
device='cpu')
updated_input_values[parser.get_index(
IdxEntry.HOST_CONTEXT_LENGTH)] = host_context_lengths
return updated_input_values
def _profile_sharding_cost(self, strategy, device_mesh):
sharding_spec = strategy.sharding_specs[
f"input{self.parser.get_index(IdxEntry.QKV_TENSOR)}"]
shard_dims = sharding_spec.dim_partition_dict
device_ids = device_mesh.phy_ids
if 2 in shard_dims:
device_dim = shard_dims[2]
partition = get_partition(device_dim, device_ids)
else:
partition = 1
if self.parser.is_entry_used(IdxEntry.K_TENSOR):
kv_sharding_spec = strategy.sharding_specs[
f"input{self.parser.get_index(IdxEntry.K_TENSOR)}"]
kv_shard_dims = kv_sharding_spec.dim_partition_dict
if 2 in kv_shard_dims:
kv_device_dim = kv_shard_dims[2]
kv_partition = get_partition(kv_device_dim, device_ids)
else:
kv_partition = 1
else:
kv_partition = 1
num_heads = self.plugin_info.pfc_as_ndarray["num_heads"].copy()
num_kv_heads = self.plugin_info.pfc_as_ndarray["num_kv_heads"].copy()
tp_size = self.plugin_info.pfc_as_ndarray["tp_size"].copy()
tp_rank = self.plugin_info.pfc_as_ndarray["tp_rank"].copy()
num_kv_heads = np.maximum(num_kv_heads // kv_partition, 1)
num_heads = np.maximum(num_heads // partition, 1)
tp_size[0] = partition
tp_rank[0] = 0
updated_layer_attrs = {
'tp_size': tp_size,
'tp_rank': tp_rank,
'num_heads': num_heads,
'num_kv_heads': num_kv_heads
}
updated_input_values = self.parameter_generator(strategy.sharding_specs,
self.plugin_info)
elapsed_time = self.node_runtime_profiler.runtime_profile(
self.layer, updated_layer_attrs, updated_input_values, strategy,
device_mesh)
return elapsed_time
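Illustrative arithmetic for the head partitioning performed above, with hypothetical head counts and partition sizes (not values from a real model):
import numpy as np
num_heads = np.array([32])
num_kv_heads = np.array([8])
partition, kv_partition = 4, 4
print(np.maximum(num_heads // partition, 1))        # [8] query heads per rank
print(np.maximum(num_kv_heads // kv_partition, 1))  # [2] kv heads per rank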

View File

@ -1,11 +0,0 @@
from ..identity_node import Identity
from ..plugin_node import PluginNode
class IdentityPlugin(Identity, PluginNode):
def __init__(self, layer):
PluginNode.__init__(self, layer)
def _collect_strategies(self, device_mesh):
return Identity._collect_strategies(self, device_mesh)

View File

@ -1,19 +0,0 @@
import tensorrt as trt
from ..gather_node import Gather
from ..plugin_node import PluginNode
class LookupPlugin(Gather, PluginNode):
def __init__(self, layer):
PluginNode.__init__(self, layer)
self.mode = trt.GatherMode.DEFAULT
self.axis = 0
self.num_elementwise_dims = 0
self.input_id = 1
self.indice_id = 0
self.support_vocab_tp = True
def _collect_strategies(self, device_mesh):
return Gather._collect_strategies(self, device_mesh)

View File

@ -1,28 +0,0 @@
from ..normalization_node import Normalization
from ..plugin_node import PluginNode
class LayernormPlugin(Normalization, PluginNode):
def __init__(self, layer):
PluginNode.__init__(self, layer)
# This is only true for LLM models, because layer norm only affects the hidden dim.
hidden_dim = len(self.op_data['input0'].shape) - 1
self.axes = 1 << hidden_dim
self.weight_bias_dim_base = hidden_dim
def _collect_strategies(self, device_mesh):
return Normalization._collect_strategies(self, device_mesh)
class RMSnormPlugin(Normalization, PluginNode):
def __init__(self, layer):
PluginNode.__init__(self, layer)
# This is only true for LLM models, because RMS norm only affects the hidden dim.
hidden_dim = len(self.op_data['input0'].shape) - 1
self.axes = 1 << hidden_dim
self.weight_bias_dim_base = hidden_dim
def _collect_strategies(self, device_mesh):
return Normalization._collect_strategies(self, device_mesh)
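A sketch of the axes bitmask computed above, assuming a [batch, seq, hidden] input:
input_rank = 3                     # e.g. [batch, seq, hidden]
hidden_dim = input_rank - 1        # normalization acts on the hidden dim only
axes = 1 << hidden_dim
print(bin(axes))                   # 0b100: only dimension 2 is normalized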

View File

@ -1,73 +0,0 @@
from tensorrt_llm._utils import trt_axes_to_dim
from .node import Node
from .sharding_strategy import StrategiesVector
class Reduce(Node):
def __init__(self, layer):
super().__init__(layer)
layer.to_subclass()
self.reduce_dims = trt_axes_to_dim(layer.as_trt().axes)
self.sum_mapping_dict = {}
num_input_dims = len(self.get_input(0).shape)
if layer.as_trt().keep_dims:
for i in range(num_input_dims):
self.sum_mapping_dict[i] = i
else:
output_index = 0
for i in range(num_input_dims):
if i not in self.reduce_dims:
self.sum_mapping_dict[i] = output_index
output_index += 1
assert output_index == len(self.get_output(0).shape)
layer.to_base_class()
def _collect_strategies(self, device_mesh):
dim_partition_list = []
dim_size = len(self.op_data['input0'].shape)
dim_partition_list.append({})
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
strategies_vector = StrategiesVector(self)
for dim_partition_dict in dim_partition_list:
recover_dims = []
out_partition_dict = {}
for dim in dim_partition_dict.keys():
if dim in self.reduce_dims:
recover_dims.append(dim)
elif dim in self.sum_mapping_dict:
out_partition_dict[
self.sum_mapping_dict[dim]] = dim_partition_dict[dim]
else:
assert 0, f'dim {dim} is not in sum_dims or sum_mapping_dict'
for dim in recover_dims:
dim_partition_dict.pop(dim)
in0_partition_dict = dim_partition_dict
dim_partition_dict_mapping = {
"input0": in0_partition_dict,
"output0": out_partition_dict,
}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
continue
name = '{} = <reduce along dim {}> {}'.format(
sharding_spec_mapping['output0'].sharding_sequence,
self.reduce_dims,
sharding_spec_mapping['input0'].sharding_sequence)
sharding_strategy = self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
return strategies_vector
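A standalone sketch of the input-to-output dimension map built above when keep_dims is false, using an illustrative rank-4 reduction over dims (1, 2):
reduce_dims, num_input_dims = (1, 2), 4
kept_dims = [i for i in range(num_input_dims) if i not in reduce_dims]
sum_mapping_dict = {dim: out for out, dim in enumerate(kept_dims)}
print(sum_mapping_dict)            # {0: 0, 3: 1}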

View File

@ -1,56 +0,0 @@
from .node import Node
from .sharding_strategy import StrategiesVector
class Select(Node):
def __init__(self, layer):
super().__init__(layer)
batch_dims = [i for i in range(len(self.get_output(0).shape))]
self._generate_bcast_dims(batch_dims, self.get_output(0).shape)
def _collect_strategies(self, device_mesh):
dim_partition_list = []
dim_size = len(self.op_data['output0'].shape)
dim_partition_list.append({})
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
strategies_vector = StrategiesVector(self)
for dim_partition_dict in dim_partition_list:
# the three inputs are condition, true tensor and false tensor
in0_partition_dict = self._recover_bcast_partition_dict(
dim_partition_dict, self.op_data['input0'])
in1_partition_dict = self._recover_bcast_partition_dict(
dim_partition_dict, self.op_data['input1'])
in2_partition_dict = self._recover_bcast_partition_dict(
dim_partition_dict, self.op_data['input2'])
out_partition_dict = dim_partition_dict
dim_partition_dict_mapping = {
"input0": in0_partition_dict,
"input1": in1_partition_dict,
"input2": in2_partition_dict,
"output0": out_partition_dict,
}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
continue
name = '{} = <select op {}> {} {}'.format(
sharding_spec_mapping['output0'].sharding_sequence,
sharding_spec_mapping['input0'].sharding_sequence,
sharding_spec_mapping['input1'].sharding_sequence,
sharding_spec_mapping['input2'].sharding_sequence)
sharding_strategy = self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
return strategies_vector

View File

@ -1,832 +0,0 @@
import math
import operator
from functools import reduce
from typing import List, Tuple
import pandas as pd
from .comm_spec import CommSpec
from .sharding_spec import ShardingSpec
class ShapeConsistencyManager(object):
def __init__(self):
self.forward_only = True
self.cached_spec_pairs_transform_path = {}
self.cache_hit = 0
self.cache_miss = 0
def all_gather_simulator(self, target_pair):
_, shard_list = target_pair
new_shard_list = []
return new_shard_list
def all_to_all_simulator(self, f_target_pair, b_target_pair):
'''
Simulate the all-to-all operation: analyze the communication cost
and the influence on the DimSpec.
All representations whose shard_list is in decreasing order (such as S10) are
BANNED, so all-to-all(S0, S1) -> RS01 is NOT allowed.
If the back shard_list is not empty, it is appended to the front shard_list, e.g.:
all-to-all(S0, S1) -> [S01, R]
all-to-all(R, S1) -> [S1, R]
Otherwise, the front shard_list is moved to the back, e.g.:
all-to-all(S0, R) -> [R, S0]
Argument:
f_target_pair(Tuple[int, List[int]]): The first element is the (front) tensor dimension to be sharded,
and the second element describes which logical mesh axes shard that dimension.
b_target_pair(Tuple[int, List[int]]): The same pair for the back tensor dimension.
'''
_, f_shard_list = f_target_pair
_, b_shard_list = b_target_pair
if not len(b_shard_list):
b_shard_list.extend(f_shard_list)
f_shard_list = []
else:
f_shard_list.extend(b_shard_list)
b_shard_list = []
return f_shard_list, b_shard_list
def shard_simulator(self, target_pair, legal_sharding_dims):
'''
Simulate the shard operation: analyze the communication cost (always ZERO)
and the influence on the DimSpec.
Non-contiguous layouts are not allowed, so shard(S0) -> S02 is NOT allowed.
In addition, all representations whose shard_list is in decreasing order (such as S10)
are BANNED, so shard(S0) -> S10 is NOT allowed.
Therefore, for an R dimension, any legal sharding dim can simply be appended to it.
e.g.:
shard(R) -> S0
For an S dimension, the shard_list after sharding must still be in increasing order.
e.g.:
shard(S0) -> S01
Argument:
target_pair(Tuple[int, List[int]]): The first element is the dimension of the tensor to be sharded,
and the second element describes which logical mesh axis will shard that dimension.
'''
_, shard_list = target_pair
shard_list_list, logical_process_axis = [], []
for dim in legal_sharding_dims:
if len(shard_list) != 0 and dim <= shard_list[-1]:
continue
new_shard_list = shard_list + [dim]
shard_list_list.append(new_shard_list)
logical_process_axis.append([dim])
# we support sorted 2D mesh here
if len(legal_sharding_dims) == 2 and len(shard_list) == 0:
shard_list_list.append(legal_sharding_dims)
logical_process_axis.append(legal_sharding_dims)
return shard_list_list, logical_process_axis
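A standalone sketch of the shard rules described above, assuming a 2-D mesh with both axes still legal and a currently unsharded (R) tensor dimension:
legal_sharding_dims = [0, 1]
candidates = [[d] for d in legal_sharding_dims]   # extend R with any single legal mesh axis
candidates.append(legal_sharding_dims)            # plus the fused S01 case for an R dimension
print(candidates)                                 # [[0], [1], [0, 1]]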
def mix_gather_simulator(self, f_target_pair, b_target_pair):
'''
Assume index of f and b target pairs are 'f' and 'b'
S0S1 => Input: (f, [0]), (b, [1]) Output: [f, b], [[0], [1]]
S1S0 => Input: (f, [1]), (b, [0]) Output: [f, b], [[1], [0]]
S01R => Input: (f, [0, 1]), (b, []) Output: [f], [[0, 1]]
RS01 => Input: (f, []), (b, [0, 1]) Output: [b], [[0, 1]]
'''
if f_target_pair[1] and b_target_pair[1]:
return [f_target_pair[0],
b_target_pair[0]], [f_target_pair[1], b_target_pair[1]]
if f_target_pair[1]:
return [f_target_pair[0]], [f_target_pair[1]]
if b_target_pair[1]:
return [b_target_pair[0]], [b_target_pair[1]]
def get_all_all_gather_spec(self, source_spec, orig_cost):
'''
Get all valid sharding specs from source_spec with single all-gather operation, and
accumulate communication cost on origin cost which will finally be used in auto sharding solver.
For the all-gather operation, we just care about the S dimension.
Argument:
source_spec(ShardingSpec): the ShardingSpec of the source_spec.
orig_cost(Dict[str, float]): the original communication cost before this operation.
Return:
valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-gather operation.
Example:
dim_partition_dict = {0: [0], 1: [1]}
# DistSpec:
# shard_sequence: S0,S1,R
# device_mesh_shape: (4, 4)
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
shape_consistency_manager = ShapeConsistencyManager()
rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
print(rst_dict)
Output:
{DistSpec:
shard_sequence: R,S1,R
device_mesh_shape: (4, 4): 0, DistSpec:
shard_sequence: S0,R,R
device_mesh_shape: (4, 4): 0}
'''
valid_spec_dict = {}
comm_pattern = 'all_gather'
for target_pair in source_spec.dim_partition_dict.items():
shard_list = self.all_gather_simulator(target_pair)
index = target_pair[0]
new_dim_partition_dict = source_spec.dim_partition_dict.copy()
# We won't add empty list into dim_partition_dict
# The key will be popped if the related shard_list is empty
if shard_list:
new_dim_partition_dict[index] = shard_list
else:
new_dim_partition_dict.pop(index)
# generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
gather_dim = index
logical_process_axis = target_pair[1]
comm_spec = CommSpec(comm_pattern,
sharding_spec=source_spec,
gather_dim=[gather_dim],
logical_process_axis=[logical_process_axis],
forward_only=self.forward_only)
# compute the communication cost with CommSpec
# generate new sharding spec
new_sharding_spec = ShardingSpec(
source_spec.device_mesh,
source_spec.data_type_size,
source_spec.entire_shape,
source_spec.max_entire_shape,
source_spec.raw_shape,
dim_partition_dict=new_dim_partition_dict)
if not new_sharding_spec.sanity_check():
continue
cost = comm_spec.get_comm_cost()
valid_spec_dict[new_sharding_spec] = (comm_spec, orig_cost + cost)
return valid_spec_dict
def get_all_all_to_all_spec(self, source_spec, orig_cost):
'''
Get all valid sharding specs from source_spec with single all-to-all operation, and
accumulate communication cost on origin cost which will finally be used in auto sharding solver.
For the all-to-all operation, we just care about the pairs containing S dimension.
Argument:
source_spec(ShardingSpec): the ShardingSpec of the source_spec.
orig_cost(Dict[str, float]): the original communication cost before this operation.
Return:
valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation.
Example:
dim_partition_dict = {0: [0], 1: [1]}
# DistSpec:
# shard_sequence: S0,S1,R
# device_mesh_shape: (4, 4)
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
shape_consistency_manager = ShapeConsistencyManager()
rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
print(rst_dict)
Output:
{DistSpec:
shard_sequence: S01,R,R
device_mesh_shape: (4, 4): 0, DistSpec:
shard_sequence: R,S1,S0
device_mesh_shape: (4, 4): 0, DistSpec:
shard_sequence: S0,R,S1
device_mesh_shape: (4, 4): 0}
'''
valid_spec_dict = {}
comm_pattern = 'all_to_all'
tensor_dims = len(source_spec.entire_shape)
for f_index in range(tensor_dims - 1):
for b_index in range(f_index + 1, tensor_dims):
# skip (R, R) cases
if f_index not in source_spec.dim_partition_dict and b_index not in source_spec.dim_partition_dict:
continue
else:
if f_index in source_spec.dim_partition_dict:
'''
# skip (S01, R) -> (R, S01) is NOT allowed
if len(source_spec.dim_partition_dict[f_index]) >= 2:
continue
'''
f_target_pair = (f_index, [
*source_spec.dim_partition_dict[f_index]
])
else:
f_target_pair = (f_index, [])
if b_index in source_spec.dim_partition_dict:
'''
# skip (R, S01) -> (S01, R) is NOT allowed
if len(source_spec.dim_partition_dict[b_index]) >= 2:
continue
'''
b_target_pair = (b_index, [
*source_spec.dim_partition_dict[b_index]
])
else:
b_target_pair = (b_index, [])
# skip (S1, S0) -> S10
if f_target_pair[1] and b_target_pair[1] and f_target_pair[1][
0] >= b_target_pair[1][0]:
continue
f_shard_list, b_shard_list = self.all_to_all_simulator(
f_target_pair, b_target_pair)
f_index = f_target_pair[0]
b_index = b_target_pair[0]
# generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
if len(f_shard_list) < len(f_target_pair[1]):
gather_dim = f_index
shard_dim = b_index
logical_process_axis = f_target_pair[1]
else:
gather_dim = b_index
shard_dim = f_index
logical_process_axis = b_target_pair[1]
comm_spec = CommSpec(
comm_pattern,
sharding_spec=source_spec,
gather_dim=[gather_dim],
shard_dim=[shard_dim],
logical_process_axis=[logical_process_axis],
forward_only=self.forward_only)
# compute the communication cost with CommSpec
new_dim_partition_dict = source_spec.dim_partition_dict.copy()
# We won't add empty list into dim_partition_dict
# The key will be popped if the related shard_list is empty
if f_shard_list:
new_dim_partition_dict[f_index] = f_shard_list
else:
new_dim_partition_dict.pop(f_index)
if b_shard_list:
new_dim_partition_dict[b_index] = b_shard_list
else:
new_dim_partition_dict.pop(b_index)
# generate new sharding spec
new_sharding_spec = ShardingSpec(
source_spec.device_mesh,
source_spec.data_type_size,
source_spec.entire_shape,
source_spec.max_entire_shape,
source_spec.raw_shape,
dim_partition_dict=new_dim_partition_dict)
if not new_sharding_spec.sanity_check():
continue
cost = comm_spec.get_comm_cost()
valid_spec_dict[new_sharding_spec] = (comm_spec,
cost + orig_cost)
return valid_spec_dict
def get_all_shard_spec(self, source_spec, orig_cost):
'''
Get all valid sharding specs from source_spec with single shard operation, and
accumulate the communication cost onto the original cost, which will finally be used in the auto sharding solver.
For the sharding operation, we just care about legal sharding dimensions.
Argument:
source_spec(ShardingSpec): the ShardingSpec of the source_spec.
orig_cost(float): the original communication cost before this operation.
Return:
valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with a single shard operation.
Example:
dim_partition_dict = {0: [0]}
# DistSpec:
# shard_sequence: S0,R,R
# device_mesh_shape: (4, 4)
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
shape_consistency_manager = ShapeConsistencyManager()
rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0})
print(rst_dict)
Output:
{DistSpec:
shard_sequence: S01,R,R
device_mesh_shape: (4, 4): 0, DistSpec:
shard_sequence: S0,S1,R
device_mesh_shape: (4, 4): 0, DistSpec:
shard_sequence: S0,R,S1
device_mesh_shape: (4, 4): 0}
'''
valid_spec_dict = {}
comm_pattern = 'split'
# legal sharding dims are the mesh axes that are still available to use.
legal_sharding_dims = [
i for i in range(len(source_spec.device_mesh.mesh_shape))
]
for dim, shard_list in source_spec.dim_partition_dict.items():
for element in shard_list:
legal_sharding_dims.remove(element)
if len(legal_sharding_dims) == 0:
return valid_spec_dict
tensor_dims = len(source_spec.entire_shape)
for index in range(tensor_dims):
if index not in source_spec.dim_partition_dict:
shard_list_list, logical_process_axes = self.shard_simulator(
(index, []), legal_sharding_dims)
else:
shard_list_list, logical_process_axes = self.shard_simulator(
(index, source_spec.dim_partition_dict[index]),
legal_sharding_dims)
if not shard_list_list:
continue
for shard_list, logical_process_axis in zip(shard_list_list,
logical_process_axes):
new_dim_partition_dict = source_spec.dim_partition_dict.copy()
new_dim_partition_dict[index] = shard_list
# generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
comm_spec = CommSpec(
comm_pattern,
sharding_spec=source_spec,
shard_dim=[index],
logical_process_axis=[logical_process_axis],
forward_only=self.forward_only)
# generate new sharding spec
new_sharding_spec = ShardingSpec(
source_spec.device_mesh,
source_spec.data_type_size,
source_spec.entire_shape,
source_spec.max_entire_shape,
source_spec.raw_shape,
dim_partition_dict=new_dim_partition_dict)
if not new_sharding_spec.sanity_check():
continue
# compute the communication cost with CommSpec
cost = comm_spec.get_comm_cost()
valid_spec_dict[new_sharding_spec] = (comm_spec,
cost + orig_cost)
return valid_spec_dict
def get_all_mixed_shard_spec(self, source_spec, orig_cost):
'''
Get all valid sharding specs from source_spec with single shard operation, and
accumulate the communication cost onto the original cost, which will finally be used in the auto sharding solver.
For the sharding operation, we just care about legal sharding dimensions.
'''
valid_spec_dict = {}
comm_pattern = 'split'
# legal sharding dims are the mesh axes that are still available to use.
legal_sharding_dims = [
i for i in range(len(source_spec.device_mesh.mesh_shape))
]
for dim, shard_list in source_spec.dim_partition_dict.items():
for element in shard_list:
legal_sharding_dims.remove(element)
if len(legal_sharding_dims) != 2:
return valid_spec_dict
tensor_dims = len(source_spec.entire_shape)
for f_index in range(tensor_dims):
for b_index in range(tensor_dims):
if f_index != b_index:
shard_dims = [f_index, b_index]
logical_process_axes = [[legal_sharding_dims[0]],
[legal_sharding_dims[1]]]
new_dim_partition_dict = source_spec.dim_partition_dict.copy(
)
new_dim_partition_dict[f_index] = [legal_sharding_dims[0]]
new_dim_partition_dict[b_index] = [legal_sharding_dims[1]]
comm_spec = CommSpec(
comm_pattern,
sharding_spec=source_spec,
shard_dim=shard_dims,
logical_process_axis=logical_process_axes,
forward_only=self.forward_only)
# generate new sharding spec
new_sharding_spec = ShardingSpec(
source_spec.device_mesh,
source_spec.data_type_size,
source_spec.entire_shape,
source_spec.max_entire_shape,
source_spec.raw_shape,
dim_partition_dict=new_dim_partition_dict)
if not new_sharding_spec.sanity_check():
continue
cost = comm_spec.get_comm_cost()
valid_spec_dict[new_sharding_spec] = (comm_spec,
cost + orig_cost)
return valid_spec_dict
def get_all_mix_gather_spec(self, source_spec, orig_cost):
'''
S0S1 -> RR
S1S0 -> RR
S01R -> RR
RS01 -> RR
'''
valid_spec_dict = {}
comm_pattern = 'all_gather'
tensor_dims = len(source_spec.entire_shape)
for f_index in range(tensor_dims - 1):
for b_index in range(f_index + 1, tensor_dims):
if (f_index not in source_spec.dim_partition_dict) and (
b_index not in source_spec.dim_partition_dict):
continue
else:
if f_index in source_spec.dim_partition_dict:
# skip (S10, R) -> (R, R)
'''
if len(
f_target_pair[1]
) == 2 and f_target_pair[1][0] >= f_target_pair[1][1]:
continue
'''
f_target_pair = (f_index, [
*source_spec.dim_partition_dict[f_index]
])
else:
f_target_pair = (f_index, [])
if b_index in source_spec.dim_partition_dict:
# skip (R, S10) -> (R, R)
'''
if len(
b_target_pair[1]
) == 2 and b_target_pair[1][0] >= b_target_pair[1][1]:
continue
'''
b_target_pair = (b_index, [
*source_spec.dim_partition_dict[b_index]
])
else:
b_target_pair = (b_index, [])
if len(f_target_pair[1]) + len(b_target_pair[1]) != 2:
continue
gather_dim, logical_process_axes = self.mix_gather_simulator(
f_target_pair, b_target_pair)
comm_spec = CommSpec(comm_pattern,
sharding_spec=source_spec,
gather_dim=gather_dim,
logical_process_axis=logical_process_axes,
forward_only=self.forward_only,
mix_gather=True)
new_dim_partition_dict = {}
# generate new sharding spec
new_sharding_spec = ShardingSpec(
source_spec.device_mesh,
source_spec.data_type_size,
source_spec.entire_shape,
source_spec.max_entire_shape,
source_spec.raw_shape,
dim_partition_dict=new_dim_partition_dict)
if not new_sharding_spec.sanity_check():
continue
cost = comm_spec.get_comm_cost()
valid_spec_dict[new_sharding_spec] = (comm_spec,
cost + orig_cost)
return valid_spec_dict
def get_all_one_step_transform_spec(self, source_spec, orig_cost):
'''
Get all valid sharding specs from source_spec with one step transform, and
accumulate the communication cost onto the original cost, which will finally be used in the auto sharding solver.
Note:
all-gather will eliminate a sharding dimension, all-to-all will keep sharding dimension same as before,
and shard will add a sharding dimension. Therefore, the result of above operations are mutual exclusive,
we could safely put them together.
Argument:
source_spec(ShardingSpec): the ShardingSpec of the source_spec.
orig_cost(float): the original communication cost before this operation.
Return:
valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with a single one-step transform.
'''
valid_spec_dict = {}
valid_spec_dict.update(
self.get_all_all_gather_spec(source_spec, orig_cost))
valid_spec_dict.update(
self.get_all_all_to_all_spec(source_spec, orig_cost))
valid_spec_dict.update(
self.get_all_mix_gather_spec(source_spec, orig_cost))
valid_spec_dict.update(
self.get_all_mixed_shard_spec(source_spec, orig_cost))
valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost))
return valid_spec_dict
def mem_cost(self, comm_action_sequence: List[CommSpec], mem_pattern='opt'):
"""memory cost of the communication action sequence
Args:
comm_action_sequence (List[CommSpec]): list of communication actions
Returns:
TrainCycleItem: memory (numel) cost of such comm_action_sequence
"""
def compute_shape(sharding_spec: ShardingSpec):
if 'opt' == mem_pattern:
return sharding_spec.get_sharded_shape_per_device()
elif 'max' == mem_pattern:
return sharding_spec.get_max_sharded_shape_per_device()
else:
return 0.0
def gather_analysis(comm_spec, peak_mem):
"""analyze all_gather memory footprint
all_gather will allocate memory for the output tensor, and there will be temp memory for
all_gather operation, which is twice the size of output tensor
Args:
comm_spec (CommSpec): input CommSpec
"""
input_shape = compute_shape(comm_spec.sharding_spec)
input_numel = reduce(operator.mul, input_shape, 1)
for axes in comm_spec.logical_process_axis:
for axis in axes:
output_numel = input_numel * comm_spec.device_mesh.mesh_shape[
axis]
alloc_mem = (input_numel +
output_numel * 2) * comm_spec.sharding_spec.dtype_size
peak_mem = max(peak_mem, alloc_mem)
return peak_mem
def reduce_scatter_analysis(comm_spec, peak_mem):
input_shape = compute_shape(comm_spec.sharding_spec)
input_numel = reduce(operator.mul, input_shape, 1)
output_numel = input_numel
for axes in comm_spec.logical_process_axis:
for axis in axes:
output_numel = output_numel / comm_spec.device_mesh.mesh_shape[
axis]
alloc_mem = (input_numel +
output_numel * 2) * comm_spec.sharding_spec.dtype_size
peak_mem = max(peak_mem, alloc_mem)
return peak_mem
def split_analysis(comm_spec: CommSpec, peak_mem: int):
"""analyze split memory footprint
split will allocate memory for the output tensor if we don't apply shard on the first dimension of
the input tensor. If we apply shard on the first dimension, the `torch.tensor.contiguous()` will not
generate new tensor in this case, so no memory will be allocated.
Args:
comm_spec (CommSpec): input CommSpec
discard_input (bool): whether to discard the input tensor
alloc_numel (int): current allocated numel
peak_numel (int): current peak numel
"""
shard_dim = comm_spec.shard_dim
if shard_dim != 0:
# if we don't shard the tensor on the first dimension, the split action will
# generate a new tensor
input_shape = compute_shape(comm_spec.sharding_spec)
input_numel = reduce(operator.mul, input_shape, 1)
output_numel = input_numel
for axes in comm_spec.logical_process_axis:
for axis in axes:
output_numel = output_numel / comm_spec.device_mesh.mesh_shape[
axis]
alloc_mem = (input_numel +
output_numel) * comm_spec.sharding_spec.dtype_size
peak_mem = max(peak_mem, alloc_mem)
else:
# if we shard the tensor on the first dimension, the split action will not generate
# a new tensor; it only keeps a reference to the input tensor, so no extra memory
# is allocated here
# NOTE: this special case might fail in some corner cases, e.g. if we have three split
# actions in the comm action sequence, where the first split operates on the second dimension,
# the second split operates on the first dimension, and the third split operates, again,
# on the second dimension. After the first two actions in the sequence, we will have allocated
# memory of the same size as the output of the first split action. The third split action should
# then discard the tensor generated by the first split action, so the current memory estimation
# framework will overestimate the memory usage. This case is unusual enough that we
# ignore it for now.
pass
return peak_mem
def reduce_analysis(comm_spec: CommSpec, peak_mem: int):
input_shape = compute_shape(comm_spec.sharding_spec)
input_numel = reduce(operator.mul, input_shape, 1)
output_numel = input_numel
alloc_mem = (input_numel +
output_numel) * comm_spec.sharding_spec.dtype_size
peak_mem = max(peak_mem, alloc_mem)
return peak_mem
def all2all_analysis(comm_spec: CommSpec, peak_mem: int):
input_shape = compute_shape(comm_spec.sharding_spec)
input_numel = reduce(operator.mul, input_shape, 1)
output_numel = input_numel
alloc_mem = (input_numel +
output_numel * 3) * comm_spec.sharding_spec.dtype_size
peak_mem = max(peak_mem, alloc_mem)
return peak_mem
def peer_to_peer_analysis(comm_spec: CommSpec, peak_mem: int):
input_shape = compute_shape(comm_spec.sharding_spec)
input_numel = reduce(operator.mul, input_shape, 1)
alloc_mem = (input_numel) * comm_spec.sharding_spec.dtype_size
peak_mem = max(peak_mem, alloc_mem)
return peak_mem
pattern_to_func_dict = {
'all_gather': gather_analysis,
'all_to_all': all2all_analysis,
'split': split_analysis,
'all_reduce': reduce_analysis,
'reduce_scatter': reduce_scatter_analysis,
'peer_to_peer': peer_to_peer_analysis
}
fwd_actions = []
# construct the forward comm actions sequence
for comm_spec in comm_action_sequence:
fwd_action = pattern_to_func_dict[comm_spec.comm_pattern]
fwd_actions.append(fwd_action)
# analyze memory footprint of forward comm actions sequence
fwd_peak_numel = 0
for idx, action_spec_pair in enumerate(
zip(fwd_actions, comm_action_sequence)):
# the first forward comm action will not discard input
fwd_action, comm_spec = action_spec_pair
fwd_peak_numel = fwd_action(comm_spec, fwd_peak_numel)
return fwd_peak_numel
def print_shape_consistency_result(self,
transform_path,
comm_action_sequence,
resharding_cost,
file=None):
for idx, tpath in enumerate(transform_path):
print(
f'sharding_info = [op_shape:{tpath.entire_shape}, sharding_spec:{tpath.sharding_sequence}, sharded_shape:{tpath.get_sharded_shape_per_device()}]',
end=" ",
file=file)
print('->', end=" ", file=file)
try:
commspec = comm_action_sequence[idx]
comm = [
commspec.comm_pattern, commspec.gather_dim,
commspec.shard_dim, commspec.logical_process_axis
]
except IndexError:
comm = ''
print(f'comm_info = {comm}', end=" ", file=file)
print('->', end=" ", file=file)
print(f'total_cost = {resharding_cost}', file=file)
def construct_transform_path_from_cache(self, src_spec, target_spec,
old_transform_path,
old_comm_action_sequence,
orig_cost):
new_transform_path = [src_spec]
new_comm_action_sequence = []
new_cost = orig_cost
new_src_spec = src_spec
for idx, old_comm_spec in enumerate(old_comm_action_sequence):
new_comm_spec = CommSpec(
old_comm_spec.comm_pattern,
sharding_spec=new_src_spec,
gather_dim=old_comm_spec.gather_dim,
shard_dim=old_comm_spec.shard_dim,
logical_process_axis=old_comm_spec.logical_process_axis,
forward_only=old_comm_spec.forward_only,
mix_gather=old_comm_spec.mix_gather)
new_comm_action_sequence.append(new_comm_spec)
new_cost += new_comm_spec.get_comm_cost()
old_target_spec = old_transform_path[idx + 1]
new_target_spec = ShardingSpec(new_src_spec.device_mesh,
new_src_spec.data_type_size,
new_src_spec.entire_shape,
new_src_spec.max_entire_shape,
new_src_spec.raw_shape,
old_target_spec.dim_partition_dict)
new_transform_path.append(new_target_spec)
new_src_spec = new_target_spec
assert new_transform_path[-1].get_sharded_shape_per_device(
) == target_spec.get_sharded_shape_per_device(
), 'failed to insert the cache transform path'
return new_transform_path, new_comm_action_sequence, new_cost
def shape_consistency(
self, source_spec: ShardingSpec, target_spec: ShardingSpec
) -> Tuple[List[ShardingSpec], List[CommSpec], float]:
'''
This method will find a path to transform source_spec to target_spec with
a greedy algorithm.
The basic idea is:
Step1:
Generate all one-step transform sequences from source_spec.
Step2:
Pick the 'best' sharding spec following the heuristic function.
Step3:
Repeat above steps until the source spec transform to target spec.
'''
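# Illustrative trace (assuming a 2x2 device mesh): resharding [S01, R, R] into
# [S0, S1, R] may proceed as [S01, R, R] -> [S0, R, R] (all_gather of one mesh
# axis) -> [S0, S1, R] (shard); once spec_difference reaches 0, the path, its
# CommSpecs, and the accumulated cost are cached under the (source, target)
# string pair and returned.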
MAX_TRANSFORM_STEPS = 20
total_cost = 0.0
total_steps = 0
transform_path = []
comm_action_sequence = []
# Do nothing if the source and target sharding specs are identical.
if source_spec.sharding_sequence_difference(target_spec) == 0:
return (transform_path, comm_action_sequence, total_cost)
spec_pairs = (str(source_spec.sharding_sequence),
str(target_spec.sharding_sequence))
if spec_pairs in self.cached_spec_pairs_transform_path:
transform_path, comm_action_sequence = self.cached_spec_pairs_transform_path[
spec_pairs]
new_transform_path, new_comm_action_sequence, new_total_cost = self.construct_transform_path_from_cache(
source_spec, target_spec, transform_path, comm_action_sequence,
total_cost)
self.cache_hit += 1
return (new_transform_path, new_comm_action_sequence,
new_total_cost)
else:
self.cache_miss += 1
temp_sharding_spec = source_spec
transform_path.append(temp_sharding_spec)
# To avoid an infinite loop, break after MAX_TRANSFORM_STEPS transforms
while total_steps <= MAX_TRANSFORM_STEPS:
valid_transform_spec_dict = self.get_all_one_step_transform_spec(
temp_sharding_spec, total_cost)
best_difference_score = math.inf
for sharding_spec, info_pairs in valid_transform_spec_dict.items():
comm_spec, cost = info_pairs
spec_difference = sharding_spec.sharding_sequence_difference(
target_spec)
if spec_difference == 0:
total_cost = cost
transform_path.append(sharding_spec)
comm_action_sequence.append(comm_spec)
self.cached_spec_pairs_transform_path[spec_pairs] = (
transform_path, comm_action_sequence)
return (transform_path, comm_action_sequence, total_cost)
if spec_difference < best_difference_score:
temp_sharding_spec = sharding_spec
temp_cost = cost
temp_comm_spec = comm_spec
best_difference_score = spec_difference
transform_path.append(temp_sharding_spec)
comm_action_sequence.append(temp_comm_spec)
total_cost = temp_cost
total_steps += 1
raise RuntimeError(
f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps."
)
def dump_transform_path_from_cache(self):
src_specs, tgt_specs, path_strs = [], [], []
for spec_pairs, trans_comm_path in self.cached_spec_pairs_transform_path.items(
):
src_specs.append(spec_pairs[0])
tgt_specs.append(spec_pairs[1])
trans_paths, comm_specs = trans_comm_path[0], trans_comm_path[1]
path_str = f'{spec_pairs[0]}->'
for idx in range(1, len(trans_paths)):
comm_spec = comm_specs[idx - 1]
comm_str = f'{comm_spec.comm_pattern}: gather_dim{comm_spec.gather_dim}, shard_dim{comm_spec.shard_dim}, mesh_axis{comm_spec.logical_process_axis}->'
path_str += comm_str
path_str += f'{trans_paths[idx].sharding_sequence}->'
path_strs.append(path_str)
ret_dict = {
'src_spec': src_specs,
'dst_specs': tgt_specs,
'trans_path': path_strs
}
ret_df = pd.DataFrame.from_dict(ret_dict)
return ret_df

View File

@ -1,41 +0,0 @@
import copy
from .node import Node
from .sharding_strategy import StrategiesVector
class Shape(Node):
def _update_memory_cost(self, strategies):
pass
def _collect_strategies(self, device_mesh):
# one input for the shape node
predecessor = self.predecessor_nodes[0]
strategies_vector = StrategiesVector(self)
for idx, strategy in enumerate(predecessor.strategies_vector):
# current node's local name input0 -> global name xxx
global_input_name = self.op_data['input0'].name
# global name xxx -> pre node local output name
prenode_local_name = predecessor.global_to_local_op_name[
global_input_name]
dim_partition_dict = copy.deepcopy(
strategy.sharding_specs[prenode_local_name].dim_partition_dict)
in0_partition_dict = dim_partition_dict
dim_partition_dict_mapping = {
"input0": in0_partition_dict,
"output0": {},
}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
return strategies_vector
name = '{} = <shape> {}'.format(
sharding_spec_mapping['output0'].sharding_sequence,
sharding_spec_mapping['input0'].sharding_sequence)
sharding_strategy = self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
return strategies_vector

View File

@ -1,418 +0,0 @@
import operator
from functools import reduce
import tensorrt as trt
from tensorrt_llm.logger import logger
ALLGATHER_COST = 20
SHARD_COST = 5
STEP_PENALTY = 6
NAN = 'nan'
def _convert_str_to_shard_list(str_spec):
'''
Convert str_spec into shard_list.
Argument:
str_spec(str): dim spec in str type.
'''
if str_spec == 'R':
return []
if str_spec == 'S0':
return [0]
if str_spec == 'S1':
return [1]
if str_spec == 'S01':
return [0, 1]
def _build_difference_2d_dict():
'''
Build a difference mapping for 2D device mesh case. It will be used to
compute the difference between DimSpec pairs.
'''
source_spec_list = ['R', 'S0', 'S1', 'S01']
target_spec_list = ['R', 'S0', 'S1', 'S01']
difference_dict = {}
for source_spec in source_spec_list:
for target_spec in target_spec_list:
spec_pair = (source_spec, target_spec)
source_shard_list = _convert_str_to_shard_list(source_spec)
target_shard_list = _convert_str_to_shard_list(target_spec)
# source same as target
if source_shard_list == target_shard_list:
difference = 0
# all_gather(source) -> target
elif len(source_shard_list) == len(
target_shard_list
) + 1 and source_shard_list[:-1] == target_shard_list:
difference = ALLGATHER_COST
# shard(source) -> target
elif len(source_shard_list) == len(
target_shard_list
) - 1 and source_shard_list == target_shard_list[:-1] and target_shard_list[
-1] not in source_shard_list:
difference = SHARD_COST
# S1 -> S0 or S0 -> S1
elif len(source_shard_list) == len(target_shard_list):
# source -> R -> target
difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST
# R -> S01
elif len(source_shard_list) == len(target_shard_list) - 2:
difference = SHARD_COST + STEP_PENALTY + SHARD_COST
# S01 -> R
elif len(source_shard_list) == len(target_shard_list) + 2:
difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST
# S1 -> S01
elif len(source_shard_list) == len(target_shard_list) - 1:
difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST + STEP_PENALTY + SHARD_COST
# S01 -> S1
elif len(source_shard_list) == len(target_shard_list) + 1:
difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + STEP_PENALTY + SHARD_COST
else:
difference = NAN
difference_dict[spec_pair] = difference
return difference_dict
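# Example lookups in the resulting dict: ('S0', 'R') maps to ALLGATHER_COST = 20,
# ('R', 'S0') to SHARD_COST = 5, and ('S0', 'S1') to
# ALLGATHER_COST + STEP_PENALTY + SHARD_COST = 31, matching the S0 <-> S1 branch
# above that reshards through an intermediate R.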
_difference_dict = _build_difference_2d_dict()
class DimSpec:
'''
Sharding spec for a single dimension of a sharded tensor. It describes which dimensions of the
logical device mesh the tensor dimension is sharded over and provides a method to compute the
difference between two DimSpecs.
Argument:
shard_list(List[int]): if shard_list is empty, the dim spec is of 'R' (replicated) type.
Otherwise, each element in shard_list is a logical device mesh dimension over which the data is sharded.
'''
def __init__(self, shard_list):
self.is_replica = len(shard_list) == 0
self.shard_list = shard_list
def __eq__(self, other):
return str(self) == str(other)
def __repr__(self):
if self.is_replica:
return 'R'
target = 'S'
for dim in self.shard_list:
target += str(dim)
return target
def difference(self, other):
'''
The difference between two DimSpec.
Argument:
other(DimSpec): the dim spec to compare with.
Return:
difference(int): the difference between two DimSpec.
Example:
dim_spec = DimSpec([0])
other_dim_spec = DimSpec([0, 1])
print(dim_spec.difference(other_dim_spec))
Output:
5
'''
difference = _difference_dict[(str(self), str(other))]
return difference
def get_sharding_sequence(num_dims, dims, device_dims):
sharding_sequence = [DimSpec([])] * num_dims
for dim, shard_list in zip(dims, device_dims):
sharding_sequence[dim] = DimSpec(shard_list)
return sharding_sequence
class ShardingSpec:
'''
Sharding spec for a tensor. It contains info about the logical device mesh this tensor belongs
to, the entire shape of the tensor before sharding, and a sharding sequence that looks like
[R, R, S0, S1].
Argument:
device_mesh: A logical view of a physical mesh.
entire_shape: The entire shape of the tensor before sharding.
dim_partition_dict: The key is the dimension of the tensor to be sharded,
and the value describes which logical mesh axes that dimension is sharded over.
sharding_sequence(List[DimSpec], optional): A flat view of the ShardingSpec, e.g. [R, R, S0, S1].
'''
def __init__(self,
device_mesh,
data_type_size,
data_shape,
max_data_shape,
raw_data_shape,
dim_partition_dict=None,
sharding_sequence=None):
self.device_mesh = device_mesh
self.data_type_size = data_type_size
self.dtype = data_type_size[0]
self.dtype_size = data_type_size[1]
self.entire_shape = data_shape
self.max_entire_shape = max_data_shape
self.raw_shape = raw_data_shape
self.dim_partition_dict = dim_partition_dict
self.sharding_sequence = sharding_sequence
self.enable_shard_unbalanced_shape = device_mesh.config.enable_shard_unbalanced_shape
self.enable_shard_dynamic_shape = device_mesh.config.enable_shard_dynamic_shape
if self.sharding_sequence is None:
self.dim_partition_dict = self._merge_same_dim_mesh_list(
len(self.entire_shape), self.dim_partition_dict)
self.dim_partition_dict = self._convert_dim_partition_dict(
len(self.entire_shape), self.dim_partition_dict)
assert self.dim_partition_dict is not None, 'dim_partition_dict should not be None if sharding_sequence is None.'
self.convert_dict_to_shard_sequence()
elif self.dim_partition_dict is None:
assert self.sharding_sequence is not None, 'sharding_sequence should not be None if dim_partition_dict is None.'
self.convert_shard_sequence_to_dict()
self.dim_partition_dict = self._merge_same_dim_mesh_list(
len(self.entire_shape), self.dim_partition_dict)
self.dim_partition_dict = self._convert_dim_partition_dict(
len(self.entire_shape), self.dim_partition_dict)
self.sharded_shape, self.max_sharded_shape = [*self.entire_shape], [
*self.max_entire_shape
]
for dim, shard_list in self.dim_partition_dict.items():
mesh_list = [
self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list
]
shard_partitions = reduce(operator.mul, mesh_list, 1)
self.sharded_shape[dim] = (self.sharded_shape[dim] +
shard_partitions - 1) // shard_partitions
self.max_sharded_shape[dim] = (self.max_sharded_shape[dim] +
shard_partitions -
1) // shard_partitions
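# Worked example (hypothetical shapes): on a device mesh with mesh_shape (2, 2),
# dim_partition_dict = {0: [0, 1]} shards dim 0 over 4 devices, so an
# entire_shape of (13, 8) gives a per-device sharded_shape of
# ((13 + 3) // 4, 8) = (4, 8); max_sharded_shape is rounded up the same way.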
def print_spec(self, file=None):
print(
f"sharding_sequence = {self.sharding_sequence}, shape = {self.get_sharded_shape_per_device()}",
file=file,
)
def _merge_same_dim_mesh_list(self, dim_size, dim_partition_dict):
'''
This method merges different keys that point to the same physical tensor dimension.
For example:
for a 2D tensor, dim_partition_dict {1: [0], -1: [1]} has keys 1 and -1 pointing to the same physical dimension.
This method converts the above dim_partition_dict to {1: [0, 1]}.
'''
converted_dim_partition_dict = {}
for dim, mesh_list in dim_partition_dict.items():
if dim < 0:
dim = dim_size + dim
if dim not in converted_dim_partition_dict:
converted_dim_partition_dict[dim] = mesh_list
else:
converted_dim_partition_dict[dim].extend(mesh_list)
converted_dim_partition_dict[dim].sort()
return converted_dim_partition_dict
def _convert_dim_partition_dict(self, dim_size, dim_partition_dict):
dims_to_convert = []
for dim, mesh_list in dim_partition_dict.items():
if dim < 0:
dims_to_convert.append(dim)
for dim in dims_to_convert:
dim_partition_dict.pop(dim)
dim_partition_dict[dim_size + dim] = mesh_list
return dim_partition_dict
def _remove_mesh_dim_one(self, dim_partition_dict):
dims_to_remove = []
for dim, mesh_list in dim_partition_dict.items():
new_mesh_list = []
for mesh_dim in mesh_list:
if self.device_mesh.mesh_shape[mesh_dim] != 1:
new_mesh_list.append(mesh_dim)
if 0 != len(new_mesh_list):
dim_partition_dict[dim] = new_mesh_list
else:
dims_to_remove.append(dim)
for dim in dims_to_remove:
dim_partition_dict.pop(dim)
return dim_partition_dict
def __repr__(self):
res = "DistSpec("
res += f"shard_sequence={self.sharding_sequence},"
res += f"shape={self.device_mesh.mesh_shape}"
res += ")"
return res
def sanity_check(self):
# make sure all axes in logical device mesh only be used once
dim_check_list = [*range(len(self.device_mesh.mesh_shape))]
for dim, shard_list in self.dim_partition_dict.items():
for element in shard_list:
if element in dim_check_list:
dim_check_list.remove(element)
else:
logger.warning(
f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}. dim_partition_dict={self.dim_partition_dict}"
)
return False
# make sure that the dimension is not out of index
for dim in self.dim_partition_dict.keys():
# negative dims have already been converted to positive above; if a dim is still negative or >= the number of tensor dimensions, it is out of range
if dim >= len(self.entire_shape) or dim < 0:
print(
f"The dim_partition_dict specifies to shard dimension {dim} but the entire_shape only has {len(self.entire_shape)} dimensions"
)
return False
if not self.enable_shard_dynamic_shape:
# make sure not to shard on a dynamic shape
for dim, shard_list in self.dim_partition_dict.items():
if len(shard_list) == 0:
continue
if len(self.raw_shape) == 0:
continue
if -1 == self.raw_shape[dim]:
return False
# make sure that the size of a sharded dimension is divisible by the number of devices
for dim, shard_list in self.dim_partition_dict.items():
if len(shard_list) == 0:
continue
tensor_dim_size = self.entire_shape[dim]
num_devices = 1
for element in shard_list:
num_devices *= self.device_mesh.mesh_shape[element]
if num_devices == 1:
# we only support RR when the device is 1
return False
if not self.enable_shard_unbalanced_shape:
if tensor_dim_size % num_devices != 0 or tensor_dim_size == 1:
'''
print(
f'The size of static dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.'
)
'''
return False
else:
if tensor_dim_size == 1:
return False
'''
if self.get_sharded_size_per_device() > (2**31 - 1):
print(
f'memory footprint per device {self.get_sharded_size_per_device()} is larger than 2**31 - 1'
)
return False
'''
return True
def convert_dict_to_shard_sequence(self):
'''
Convert dim_partition_dict into list of DimSpec, and assign it to sharding_sequence.
'''
sharding_sequence = [DimSpec([])] * len(self.entire_shape)
for dim, shard_list in self.dim_partition_dict.items():
sharding_sequence[dim] = DimSpec(shard_list)
self.sharding_sequence = sharding_sequence
def convert_shard_sequence_to_dict(self):
'''
Convert sharding_sequence into dim_partition_dict.
'''
new_dim_partition_dict = {}
for index, dim_spec in enumerate(self.sharding_sequence):
if not dim_spec.is_replica:
if index not in new_dim_partition_dict:
new_dim_partition_dict[index] = []
new_dim_partition_dict[index].extend(dim_spec.shard_list)
self.dim_partition_dict = new_dim_partition_dict
def sharding_sequence_difference(self, other):
'''
This function is a naive version of the difference computation. It simply accumulates the
per-dimension differences between the pair of sharding sequences.
Example:
dim_partition_dict = {0: [0, 1]}
# DistSpec:
# shard_sequence: S01,R,R
# device_mesh_shape: (4, 4)
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
dim_partition_dict_to_compare = {0: [0], 1: [1]}
# DistSpec:
# shard_sequence: S0,S1,R
# device_mesh_shape: (4, 4)
sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare)
print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare))
Output:
25
Argument:
other(ShardingSpec): The ShardingSpec to compare with.
Return:
difference(int): Difference between two ShardingSpec.
'''
assert len(self.sharding_sequence) == len(
other.sharding_sequence
), 'Cannot compare the difference between two sharding specs of different lengths.'
difference = 0
for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence,
other.sharding_sequence):
difference += orig_dim_spec.difference(other_dim_spec)
return difference
def get_sharded_shape_per_device(self, ):
return self.sharded_shape
def get_sharded_element_per_device(self, ):
sharded_shape = self.get_sharded_shape_per_device()
if len(sharded_shape) == 0:
num_elements = 1
else:
num_elements = trt.volume(sharded_shape)
return num_elements
def get_sharded_size_per_device(self, ):
num_elements = self.get_sharded_element_per_device()
return num_elements * self.dtype_size
def get_max_sharded_shape_per_device(self, ):
return self.max_sharded_shape
def get_max_sharded_element_per_device(self, ):
max_sharded_shape = self.get_max_sharded_shape_per_device()
if len(max_sharded_shape) == 0:
num_elements = 1
else:
num_elements = trt.volume(max_sharded_shape)
return num_elements
def get_max_sharded_size_per_device(self, ):
num_elements = self.get_max_sharded_element_per_device()
return num_elements * self.dtype_size

View File

@ -1,77 +0,0 @@
class ShardingStrategy(object):
def __init__(self,
name=None,
sharding_specs=None,
communication_actions=None):
self.name = name or ""
self.sharding_specs = sharding_specs or {}
self.communication_actions = communication_actions
self.sharding_cost = 0
self.communication_cost = 0
self.resharding_costs = {}
self.best_resharding_cost = {}
self.node_names = {}
self.comm_buff_memory_footprint = 0
self.inout_memory_footprint = 0
self.const_memory_footprint = 0
self.peak_memory_footprint = 0
self.computation_macs = 0
self.alpha_beta_cost = 0
def print_strategy(self, best_resharding_cost_only=False, file=None):
def print_resharding_costs(resharding_cost):
for prenode_node_name, rcosts in resharding_cost.items():
if isinstance(prenode_node_name, int):
idx = prenode_node_name
prenode_node_name = self.node_names[idx]
print(f' pre_node = {idx} {prenode_node_name}',
file=file)
else:
print(f' pre_node = {prenode_node_name}', file=file)
for idx, rcost in enumerate(rcosts):
transpaths, commspecs, cost = rcost
print(f' {idx}: ', end=' ', file=file)
device_mesh.shape_consistency_manager.print_shape_consistency_result(
transpaths, commspecs, cost, file)
print(f'name = {self.name}', file=file)
print(f'sharding_cost = {self.sharding_cost}', file=file)
print(
f'communication_buffer_memory_footprint = {self.comm_buff_memory_footprint}, communication_cost = {self.communication_cost}',
file=file)
print(f'inout_memory_footprint = {self.inout_memory_footprint}',
file=file)
print(f'peak_memory_footprint = {self.peak_memory_footprint}',
file=file)
print(f'const_memory_footprint = {self.const_memory_footprint}',
file=file)
print('sharding_specs:', file=file)
device_mesh = None
for specname, spec in self.sharding_specs.items():
print(specname + ', ', end=' ', file=file)
spec.print_spec(file)
device_mesh = spec.device_mesh
if best_resharding_cost_only and self.best_resharding_cost:
print('best_resharding_costs:', file=file)
print_resharding_costs(self.best_resharding_cost)
else:
print('resharding costs:', file=file)
print_resharding_costs(self.resharding_costs)
class StrategiesVector(list):
'''
Each node in the graph has a corresponding StrategiesVector, which stores all the possible
strategies of the node.
Argument:
node (Node): node for which the list of sharding strategies is generated.
'''
def __init__(self, node):
super().__init__()
self.node = node

View File

@ -1,238 +0,0 @@
import copy
from .node import Node
from .sharding_strategy import StrategiesVector
class Shuffle(Node):
def __init__(self, layer):
super().__init__(layer)
layer.to_subclass()
self.first_transpose_dims = layer.as_trt().first_transpose
self.second_transpose_dims = layer.as_trt().second_transpose
self.zero_is_placeholder = layer.as_trt().zero_is_placeholder
self.is_first_transpose_identity = (sorted(
self.first_transpose_dims) == list(self.first_transpose_dims))
self.input_shape = self.get_input(0).shape
self.is_second_transpose_identity = (sorted(
self.second_transpose_dims) == list(self.second_transpose_dims))
output_shape = list(self.get_output(0).shape)
self.reshape_dims = copy.deepcopy(output_shape)
if not self.is_second_transepose_identity:
for i in self.second_transpose_dims:
if self.second_transpose_dims[i] != i:
self.reshape_dims[
self.second_transpose_dims[i]] = output_shape[i]
self.is_reshape_identity = (list(self.reshape_dims) == list(
self.input_shape))
layer.to_base_class()
def _collect_transpose_strategies(self, device_mesh, transpose_dims):
dim_partition_list = []
dim_size = len(self.op_data['input0'].shape)
dim_partition_list.append({})
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
strategies_vector = StrategiesVector(self)
# dim_partition_dict can be the same as previous node if solver's time is a problem
for dim_partition_dict in dim_partition_list:
in0_partition_dict = dim_partition_dict
out_partition_dict = {}
for split_dim, mesh_dim in in0_partition_dict.items():
trans_dim = transpose_dims[split_dim]
out_partition_dict[trans_dim] = mesh_dim
dim_partition_dict_mapping = {
"input0": in0_partition_dict,
"output0": out_partition_dict,
}
if self.num_inputs == 2:
dim_partition_dict_mapping["input1"] = {}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
continue
name = '{} = <shuffle_transpose_only op> {}'.format(
sharding_spec_mapping['output0'].sharding_sequence,
sharding_spec_mapping['input0'].sharding_sequence)
sharding_strategy = self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
return strategies_vector
def _find_reshape_partitions(self, input_shape, output_shape,
input_partition_dict):
len_input_shape, len_output_shape = len(input_shape), len(output_shape)
output_partition_dict = {}
i, j = 0, 0
while i < len_input_shape or j < len_output_shape:
if i < len_input_shape and input_shape[i] == 1:
i = i + 1
continue
if j < len_output_shape and output_shape[j] == 1:
j = j + 1
continue
if input_shape[i] == output_shape[j]:
if i in input_partition_dict:
output_partition_dict[j] = input_partition_dict[i]
# the dimension is kept, so we need to keep its partition dims
i, j = i + 1, j + 1
elif input_shape[i] < output_shape[j]:
# we detect if the input dims are merged in the reshape dim
value = input_shape[i]
for ii in range(i + 1, len_input_shape):
value = value * input_shape[ii]
if value == output_shape[j]:
# it is merged, we set the output's merged dim partition as all inputs' dims
mesh_dim = []
for in_dim in range(i, ii + 1):
if in_dim in input_partition_dict:
mesh_dim = mesh_dim + input_partition_dict[
in_dim]
if len(mesh_dim) > 0:
output_partition_dict[j] = sorted(mesh_dim)
i, j = ii + 1, j + 1
break
else:
# we cannot find the merged dimensions; the difference may come from an arbitrary reshape, which we don't support for now
return {}, {}
else:
# we detect if the input dim is split into reshape dims
value = output_shape[j]
for jj in range(j + 1, len_output_shape):
value = value * output_shape[jj]
if value == input_shape[i]:
# it is split pattern
if i in input_partition_dict:
output_partition_dict[j] = input_partition_dict[i]
i, j = i + 1, jj + 1
break
else:
# we cannot find the split dimensions; the difference may come from an arbitrary reshape
return {}, {}
return input_partition_dict, output_partition_dict
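# Illustrative trace (hypothetical shapes): reshaping (2, 3, 4) into (6, 4) with
# input_partition_dict = {0: [0]} detects that input dims 0 and 1 are merged into
# output dim 0, so the partition propagates as {0: [0]} on the output; for a
# split such as (6, 4) -> (2, 3, 4), the partition of input dim 0 is carried by
# the first of the split output dims.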
def _collect_reshape_strategies(self, device_mesh):
dim_partition_list = []
dim_size = len(self.op_data['input0'].shape)
dim_partition_list.append({})
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
strategies_vector = StrategiesVector(self)
# dim_partition_dict can be the same as previous node if solver's time is a problem
for dim_partition_dict in dim_partition_list:
in0_partition_dict = dim_partition_dict
in0_partition_dict, out_partition_dict = self._find_reshape_partitions(
self.input_shape, self.reshape_dims, in0_partition_dict)
dim_partition_dict_mapping = {
"input0": in0_partition_dict,
"output0": out_partition_dict,
}
if self.num_inputs == 2:
dim_partition_dict_mapping["input1"] = {}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
continue
name = '{} = <shuffle_reshape op> {}'.format(
sharding_spec_mapping['output0'].sharding_sequence,
sharding_spec_mapping['input0'].sharding_sequence)
sharding_strategy = self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
return strategies_vector
def _collect_identity_strategies(self, device_mesh):
dim_partition_list = []
dim_size = len(self.op_data['input0'].shape)
dim_partition_list.append({})
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
strategies_vector = StrategiesVector(self)
# dim_partition_dict can be the same as previous node if solver's time is a problem
for dim_partition_dict in dim_partition_list:
in0_partition_dict = dim_partition_dict
out_partition_dict = copy.deepcopy(dim_partition_dict)
dim_partition_dict_mapping = {
"input0": in0_partition_dict,
"output0": out_partition_dict,
}
if self.num_inputs == 2:
dim_partition_dict_mapping["input1"] = {}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
continue
name = '{} = <shuffle_identity op> {}'.format(
sharding_spec_mapping['output0'].sharding_sequence,
sharding_spec_mapping['input0'].sharding_sequence)
sharding_strategy = self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
return strategies_vector
def _collect_strategies(self, device_mesh):
is_identity_list = (self.is_first_transpose_identity,
self.is_reshape_identity,
self.is_second_transpose_identity)
if is_identity_list == (True, True, True):
return self._collect_identity_strategies(device_mesh)
elif is_identity_list == (True, True, False):
return self._collect_transpose_strategies(
device_mesh, self.second_transpose_dims)
elif is_identity_list == (False, True, True):
return self._collect_transpose_strategies(device_mesh,
self.first_transpose_dims)
elif is_identity_list == (True, False, True):
return self._collect_reshape_strategies(device_mesh)
else:
assert False, f"Unsupported shuffle pattern: {is_identity_list}"
def _profile_sharding_cost(self, strategy, device_mesh):
updated_layer_attrs = {}
updated_input_values = {}
output_shape = strategy.sharding_specs[
'output0'].get_sharded_shape_per_device()
self.layer.to_subclass()
second_transpose = self.layer.as_trt().second_transpose
self.layer.to_base_class()
reshape_dims = [*output_shape]
for i in range(len(output_shape)):
reshape_dims[second_transpose[i]] = output_shape[i]
if self.layer.num_inputs >= 2:
updated_input_values[1] = reshape_dims
else:
updated_layer_attrs['reshape_dims'] = reshape_dims
elapsed_time = self.node_runtime_profiler.runtime_profile(
self.layer, updated_layer_attrs, updated_input_values, strategy,
device_mesh)
return elapsed_time

View File

@ -1,100 +0,0 @@
import copy
from .node import Node
from .sharding_strategy import StrategiesVector
class Slice(Node):
def __init__(self, layer):
super().__init__(layer)
layer.to_subclass()
input_shape = self.get_input(0).shape
output_shape = self.get_output(0).shape
assert len(input_shape) == len(
output_shape
), f'dims of input shape {input_shape} != dims of output shape {output_shape}'
if layer.num_inputs >= 2 and layer.get_input(1) is not None:
start = layer.get_input(1).value
else:
start = layer.as_trt().start
if layer.num_inputs >= 4 and layer.get_input(3) is not None:
stride = layer.get_input(3).value
else:
stride = layer.as_trt().stride
self.keep_partition_dims = [(input_shape[i] == output_shape[i]
and start[i] == 0 and stride[i] == 1)
for i in range(len(input_shape))]
layer.to_base_class()
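# Worked example (hypothetical shapes): slicing an input of shape (8, 16) down to
# (8, 8) with start (0, 0) and stride (1, 1) yields
# keep_partition_dims = [True, False], so the strategies collected below keep
# candidate shardings on dim 0 only and drop any partition of the sliced dim 1.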
def _update_memory_cost(self, strategies):
for strategy in strategies:
# for the slice node, input0's read size equals output0's write size
inout_memory_footprint = strategy.sharding_specs[
'output0'].get_sharded_size_per_device() * 2
strategy.inout_memory_footprint = inout_memory_footprint
strategy.peak_memory_footprint = (
strategy.sharding_specs['input0'].
get_max_sharded_size_per_device() + strategy.
sharding_specs['output0'].get_max_sharded_size_per_device())
def _collect_strategies(self, device_mesh):
dim_partition_list = []
dim_size = len(self.op_data['input0'].shape)
dim_partition_list.append({})
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
strategies_vector = StrategiesVector(self)
# dim_partition_dict can be the same as previous node if solver's time is a problem
for dim_partition_dict in dim_partition_list:
for dim in range(len(self.keep_partition_dims)):
if (not self.keep_partition_dims[dim]
) and dim in dim_partition_dict:
dim_partition_dict.pop(dim)
in0_partition_dict = dim_partition_dict
out_partition_dict = copy.deepcopy(dim_partition_dict)
dim_partition_dict_mapping = {
"input0": in0_partition_dict,
"output0": out_partition_dict,
}
for i in range(1, self.num_inputs):
if self.predecessor_nodes[i]:
dim_partition_dict_mapping[f"input{i}"] = {}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
continue
name = '{} = {} <slice op> '.format(
sharding_spec_mapping['output0'].sharding_sequence,
sharding_spec_mapping['input0'].sharding_sequence)
for i in range(1, self.num_inputs):
if self.predecessor_nodes[i]:
name = name + str(
sharding_spec_mapping[f'input{i}'].sharding_sequence)
sharding_strategy = self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
return strategies_vector
def _profile_sharding_cost(self, strategy, device_mesh):
updated_layer_attrs = {}
updated_input_values = {}
shape = strategy.sharding_specs['output0'].get_sharded_shape_per_device(
)
if self.layer.num_inputs >= 3 and self.layer.get_input(2) is not None:
updated_input_values[2] = shape
else:
updated_layer_attrs['shape'] = shape
elapsed_time = self.node_runtime_profiler.runtime_profile(
self.layer, updated_layer_attrs, updated_input_values, strategy,
device_mesh)
return elapsed_time

View File

@ -1,54 +0,0 @@
import copy
from tensorrt_llm._utils import trt_axes_to_dim
from .node import Node
from .sharding_strategy import StrategiesVector
class SoftMax(Node):
def __init__(self, layer):
super().__init__(layer)
layer.to_subclass()
self.softmax_dim = trt_axes_to_dim(layer.as_trt().axes)[0]
layer.to_base_class()
def _collect_strategies(self, device_mesh):
dim_partition_list = []
dim_size = len(self.op_data['input0'].shape)
dim_partition_list.append({})
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
strategies_vector = StrategiesVector(self)
# dim_partition_dict can be the same as previous node if solver's time is a problem
for dim_partition_dict in dim_partition_list:
if self.softmax_dim in dim_partition_dict:
dim_partition_dict.pop(self.softmax_dim)
in0_partition_dict = dim_partition_dict
out_partition_dict = copy.deepcopy(dim_partition_dict)
dim_partition_dict_mapping = {
"input0": in0_partition_dict,
"output0": out_partition_dict,
}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
continue
name = '{} = <softmax along dim {}> {}'.format(
sharding_spec_mapping['output0'].sharding_sequence,
self.softmax_dim,
sharding_spec_mapping['input0'].sharding_sequence)
sharding_strategy = self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
return strategies_vector

View File

@ -1,42 +0,0 @@
import copy
from .node import Node
from .sharding_strategy import StrategiesVector
class Unary(Node):
def _collect_strategies(self, device_mesh):
dim_partition_list = []
dim_size = len(self.op_data['input0'].shape)
dim_partition_list.append({})
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
dim_partition_list.extend(
self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
strategies_vector = StrategiesVector(self)
# dim_partition_dict can be the same as previous node if solver's time is a problem
for dim_partition_dict in dim_partition_list:
in0_partition_dict = dim_partition_dict
out_partition_dict = copy.deepcopy(dim_partition_dict)
dim_partition_dict_mapping = {
"input0": in0_partition_dict,
"output0": out_partition_dict,
}
sharding_spec_mapping = self._to_sharding_spec_mapping(
dim_partition_dict_mapping, device_mesh)
if 0 == len(sharding_spec_mapping):
continue
name = '{} = <unary op> {}'.format(
sharding_spec_mapping['output0'].sharding_sequence,
sharding_spec_mapping['input0'].sharding_sequence)
sharding_strategy = self._get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
strategies_vector.append(sharding_strategy)
return strategies_vector

View File

@ -1,308 +0,0 @@
import contextlib
import threading
try:
from types import NoneType
except ImportError:
NoneType = type(None)
from typing import ByteString, Iterable, MutableMapping
import tensorrt as trt
import torch
from tensorrt_llm._utils import get_extra_attr, np_dtype_to_trt, set_extra_attr
from tensorrt_llm.logger import logger
from tensorrt_llm.network import PluginInfo, get_plugin_info
LAYER_TYPE_2_CLASS = {
trt.LayerType.ACTIVATION: trt.IActivationLayer,
trt.LayerType.CONCATENATION: trt.IConcatenationLayer,
trt.LayerType.CONSTANT: trt.IConstantLayer,
trt.LayerType.ELEMENTWISE: trt.IElementWiseLayer,
trt.LayerType.FILL: trt.IFillLayer,
trt.LayerType.GATHER: trt.IGatherLayer,
trt.LayerType.MATRIX_MULTIPLY: trt.IMatrixMultiplyLayer,
trt.LayerType.REDUCE: trt.IReduceLayer,
trt.LayerType.SELECT: trt.ISelectLayer,
trt.LayerType.SHUFFLE: trt.IShuffleLayer,
trt.LayerType.SLICE: trt.ISliceLayer,
trt.LayerType.SOFTMAX: trt.ISoftMaxLayer,
trt.LayerType.UNARY: trt.IUnaryLayer,
trt.LayerType.SHAPE: trt.IShapeLayer,
trt.LayerType.ASSERTION: trt.IAssertionLayer,
trt.LayerType.CAST: trt.ICastLayer,
trt.LayerType.NORMALIZATION: trt.INormalizationLayer,
trt.LayerType.IDENTITY: trt.IIdentityLayer,
trt.LayerType.PLUGIN_V2: trt.IPluginV2Layer,
}
def to_subclass_layer(trt_layer):
trt_layer.__class__ = LAYER_TYPE_2_CLASS[trt_layer.type]
def to_base_class_layer(trt_layer):
trt_layer.__class__ = trt.ILayer
def to_trt_weights(ndarray):
weight = trt.Weights(
np_dtype_to_trt(ndarray.dtype),
ndarray.ctypes.data,
ndarray.size,
)
# Keep the numpy array alive so it does not go out of scope before the weight
set_extra_attr(weight, "numpy", ndarray)
return weight
@contextlib.contextmanager
def silent_trt_logger():
min_severity = logger.trt_logger.min_severity
logger.trt_logger.min_severity = trt.Logger.ERROR
yield
logger.trt_logger.min_severity = min_severity
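# Usage sketch (hypothetical call site): wrap noisy TensorRT calls to temporarily
# raise the logger severity to ERROR, e.g.
#   with silent_trt_logger():
#       serialized = builder.build_serialized_network(network, config)
# the previous severity is restored when the block exits.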
def compare_tensor(trt_tensor, new_trt_tensor):
assert trt_tensor.name == new_trt_tensor.name
assert trt_tensor.dtype == new_trt_tensor.dtype
assert tuple(trt_tensor.shape) == tuple(new_trt_tensor.shape)
assert trt_tensor.broadcast_across_batch == new_trt_tensor.broadcast_across_batch
assert trt_tensor.location == new_trt_tensor.location
assert trt_tensor.is_network_input == new_trt_tensor.is_network_input
assert trt_tensor.is_network_output == new_trt_tensor.is_network_output
assert trt_tensor.dynamic_range == new_trt_tensor.dynamic_range
assert trt_tensor.is_shape_tensor == new_trt_tensor.is_shape_tensor
assert trt_tensor.is_execution_tensor == new_trt_tensor.is_execution_tensor
assert trt_tensor.allowed_formats == new_trt_tensor.allowed_formats
def compare_network(trt_network, new_trt_network):
assert trt_network.num_inputs == new_trt_network.num_inputs
for i in range(trt_network.num_inputs):
input = trt_network.get_input(i)
new_input = new_trt_network.get_input(i)
compare_tensor(input, new_input)
assert trt_network.num_outputs == new_trt_network.num_outputs
for i in range(trt_network.num_outputs):
output = trt_network.get_output(i)
new_output = new_trt_network.get_output(i)
compare_tensor(output, new_output)
assert trt_network.num_layers == new_trt_network.num_layers
for index, new_index in zip(get_sorted_layer_ids(trt_network),
get_sorted_layer_ids(new_trt_network)):
layer = trt_network.get_layer(index)
new_layer = new_trt_network.get_layer(new_index)
assert layer.name == new_layer.name
assert layer.type == new_layer.type
assert layer.precision_is_set == new_layer.precision_is_set
assert layer.precision == new_layer.precision
assert layer.num_inputs == new_layer.num_inputs
for j in range(layer.num_inputs):
input = layer.get_input(j)
new_input = new_layer.get_input(j)
if input is None:
assert new_input is None
else:
assert new_input is not None
compare_tensor(input, new_input)
assert layer.num_outputs == new_layer.num_outputs
for j in range(layer.num_outputs):
output = layer.get_output(j)
new_output = new_layer.get_output(j)
compare_tensor(output, new_output)
assert layer.output_type_is_set(j) == new_layer.output_type_is_set(
j)
if layer.output_type_is_set(j):
assert layer.get_output_type(j) == new_layer.get_output_type(j)
def get_sorted_layer_ids(trt_network):
inputs = set()
for i in range(trt_network.num_inputs):
inputs.add(trt_network.get_input(i).name)
layer_ids = [*range(trt_network.num_layers)]
sorted_layer_ids = []
walked_tensors = set(inputs)
while len(layer_ids) > 0:
layer_id = layer_ids.pop(0)
layer = trt_network.get_layer(layer_id)
no_dependencies = True
for j in range(layer.num_inputs):
input = layer.get_input(j)
if input is None:
continue
if input.name in walked_tensors:
continue
else:
no_dependencies = False
break
if no_dependencies:
sorted_layer_ids.append(layer_id)
for j in range(layer.num_outputs):
output = layer.get_output(j)
if output is None:
continue
walked_tensors.add(output.name)
else:
layer_ids.append(layer_id)
assert len(sorted_layer_ids) == trt_network.num_layers
return sorted_layer_ids
def to_tuple(values):
if isinstance(values, (int, float, str, bool, NoneType, ByteString)):
return values
elif isinstance(values, (trt.Dims, trt.Permutation)):
if values.__len__() < 0:
return None
else:
return tuple(values)
elif isinstance(values, Iterable):
return tuple(to_tuple(v) for v in values)
elif isinstance(values, MutableMapping):
return tuple((k, to_tuple(v)) for k, v in values.items())
else:
return values
_base_layer_attr_names = set(dir(trt.ILayer))
def get_cache_key(layer, shapes, values, dtypes=None, updated_attrs=None):
updated_attrs = updated_attrs or {}
layer_type = layer.type
to_subclass_layer(layer)
attr_names = set(dir(layer)) - _base_layer_attr_names
if layer_type == trt.LayerType.CONSTANT:
attr_names.remove("weights")
elif layer_type == trt.LayerType.SHUFFLE:
if layer.num_inputs >= 2:
attr_names.remove("reshape_dims")
elif layer_type == trt.LayerType.SLICE:
if layer.num_inputs >= 2 and layer.get_input(1) is not None:
attr_names.remove("start")
if layer.num_inputs >= 3 and layer.get_input(2) is not None:
attr_names.remove("shape")
if layer.num_inputs >= 4 and layer.get_input(3) is not None:
attr_names.remove("stride")
elif layer_type == trt.LayerType.FILL:
attr_names.remove("is_alpha_beta_int64")
if layer.num_inputs >= 1 and layer.get_input(0) is not None:
attr_names.remove("shape")
if layer.num_inputs >= 2 and layer.get_input(1) is not None:
attr_names.remove("alpha")
if layer.num_inputs >= 3 and layer.get_input(2) is not None:
attr_names.remove("beta")
if layer_type != trt.LayerType.PLUGIN_V2:
attr_key = tuple(
(name, to_tuple(updated_attrs.get(name) or getattr(layer, name)))
for name in sorted(attr_names))
else:
network = get_trt_network(layer)
plugin_info = get_plugin_info(network, layer.name)
assert plugin_info is not None, f"layer {layer.name} does not register plugin info"
attr_key = tuple(
(name, tuple(updated_attrs.get(name) or data))
for name, data in sorted(plugin_info.pfc_as_list.items()))
to_base_class_layer(layer)
shape_key = ()
value_key = ()
dtype_key = ()
for i in range(layer.num_inputs):
input = layer.get_input(i)
if input is not None:
shape_key += (tuple(shapes[input.name]), )
if input.name in values:
value = values[input.name]
# All torch tensors are derived from input shapes and pfc,
# thus we ignore them in cache key
if isinstance(value, torch.Tensor):
value = None
else:
value = tuple(value)
value_key += (value, )
else:
value_key += (None, )
if dtypes is not None:
dtype_key += (dtypes[input.name], )
else:
shape_key += (None, )
value_key += (None, )
dtype_key += (None, )
if dtypes is not None:
for i in range(layer.num_outputs):
output = layer.get_output(i)
dtype_key += (dtypes[output.name], )
cache_key = (layer.type, attr_key, shape_key, value_key)
if dtypes is not None:
cache_key += (dtype_key, )
return cache_key
def get_trt_network(layer: trt.ILayer):
network = get_extra_attr(layer, "network")
assert network is not None
return network
def set_trt_network(layer: trt.ILayer, network: trt.INetworkDefinition):
set_extra_attr(layer, "network", network)
def get_updated_plugin(plugin_info: PluginInfo, updated_attrs):
fields = []
for field in plugin_info.pfc:
name = field.name
if name in updated_attrs:
field = trt.PluginField(name, updated_attrs[name], field.type)
else:
field = trt.PluginField(name, plugin_info.pfc_as_ndarray[name],
field.type)
fields.append(field)
pfc = trt.PluginFieldCollection(fields)
plugin = plugin_info.plugin_creator.create_plugin(plugin_info.plugin_name,
pfc)
new_plugin_info = PluginInfo(plugin_info.plugin_creator,
plugin_info.plugin_name, pfc)
return plugin, new_plugin_info
_builder_flags = threading.local()
_strongly_typed = threading.local()
def get_builder_flags():
return getattr(_builder_flags, 'value', 0)
def get_strongly_typed():
return getattr(_strongly_typed, 'value', False)
@contextlib.contextmanager
def current_flags(builder_flags, strongly_typed):
previous_builder_flags = get_builder_flags()
_builder_flags.value = builder_flags
previous_strongly_typed = get_strongly_typed()
_strongly_typed.value = strongly_typed
yield
_builder_flags.value = previous_builder_flags
_strongly_typed.value = previous_strongly_typed
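# Usage sketch (hypothetical call site): pin the thread-local flags returned by
# get_builder_flags() and get_strongly_typed() for the duration of a block, e.g.
#   with current_flags(builder_flags=0, strongly_typed=True):
#       ...  # profile or build a single layer
# the previous values are restored afterwards.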
def get_engine_information(engine_file) -> str:
with open(engine_file, "rb") as f:
engine_buffer = f.read()
runtime = trt.Runtime(logger.trt_logger)
engine = runtime.deserialize_cuda_engine(engine_buffer)
inspector = engine.create_engine_inspector()
return inspector.get_engine_information(trt.LayerInformationFormat.JSON)
def print_engine_info(engine_file) -> dict:
with open(engine_file, "rb") as f:
engine_buffer = f.read()
from tensorrt_llm.runtime.session import Session
Session.from_serialized_engine(engine_buffer)._print_engine_info()

View File

@ -44,16 +44,8 @@ def get_settings_from_engine(
with open(config_path, "r") as config_json:
config = json.load(config_json)
engine_world_map = config["pretrained_config"]["mapping"]
mapping = config["pretrained_config"]["mapping"]
engine_build_cfg = config["build_config"]
engine_parallel_map = engine_build_cfg["auto_parallel_config"]
world_config = {
"pp_size": engine_world_map["pp_size"],
"tp_size": engine_world_map["tp_size"],
"world_size": engine_world_map["world_size"],
"gpus_per_node": engine_parallel_map["gpus_per_node"],
}
executor_settings = {
"max_batch_size": engine_build_cfg["max_batch_size"],
@ -64,7 +56,7 @@ def get_settings_from_engine(
"sw_version": config["version"],
"engine_dir": str(engine_path.absolute()),
"settings_config": executor_settings,
"world_config": world_config,
"mapping": mapping,
})
runtime_config["performance_options"] = {}
@ -104,12 +96,13 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
enable_chunked_prefill = llm_args_dict.get("enable_chunked_prefill",
enable_chunked_prefill)
world_config = {
mapping = {
"pp_size": params.get("pp"),
"tp_size": params.get("tp"),
"world_size": params.get("pp") * params.get("tp"),
"ep_size": params.get("ep"),
"cluster_size": params.get("cluster_size"),
"moe_ep_size": params.get("ep"),
"moe_cluster_size": params.get("cluster_size"),
"gpus_per_node": params.get("gpus_per_node"),
}
if params.get("max_batch_size") and params.get("max_num_tokens"):
@ -184,7 +177,7 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
"max_num_tokens": int(max_num_tokens),
"chunking": enable_chunked_prefill,
},
"world_config": world_config,
"mapping": mapping,
"backend": backend,
"decoding_config": {},
"performance_options": {

View File

@ -1,7 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass
from importlib.util import find_spec
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union
@ -29,7 +28,9 @@ class RuntimeConfig(BaseModel):
engine_dir: Optional[Path] = None
sw_version: str
settings_config: ExecutorSettingsConfig
world_config: ExecutorWorldConfig
# TODO: this is a dict corresponding to the Mapping class, the type should be
# changed to Mapping after the Mapping class is migrated to a Pydantic model.
mapping: Dict[str, Any]
decoding_config: Optional[DecodingConfig] = None
performance_options: PerformanceOptions
backend: Literal["pytorch", "_autodeploy", None] = None
@ -47,15 +48,15 @@ class RuntimeConfig(BaseModel):
"skip_tokenizer_init":
True,
"pipeline_parallel_size":
self.world_config.pp_size,
self.mapping["pp_size"],
"tensor_parallel_size":
self.world_config.tp_size,
self.mapping["tp_size"],
"gpus_per_node":
self.world_config.gpus_per_node,
self.mapping["gpus_per_node"],
"moe_expert_parallel_size":
self.world_config.ep_size,
self.mapping["moe_ep_size"],
"moe_cluster_parallel_size":
self.world_config.cluster_size,
self.mapping["moe_cluster_size"],
"trust_remote_code":
True,
"enable_chunked_prefill":
@ -164,57 +165,6 @@ class DecodingConfig(BaseModel):
return trtllm.DecodingConfig(**kwargs)
class ExecutorWorldConfig(BaseModel):
pp_size: int = 1
tp_size: int = 1
# None to make LLM-API deduce it with a rule.
gpus_per_node: Optional[int] = None
leader_mode: bool = False
ep_size: Optional[int] = None
cluster_size: Optional[int] = None
@model_validator(mode="after")
def validate_world_size(self) -> ExecutorWorldConfig:
if self.gpus_per_node is None:
return self
parallel_world = self.pp_size * self.tp_size
num_gpus = self.world_size * self.gpus_per_node
valid_world = bool(num_gpus >= parallel_world)
if not valid_world:
raise ValueError(
f"World configuration is invalid, TP * PP ({parallel_world})"
"does not equal the total number of available GPUs"
f"({num_gpus}).")
return self
@property
def world_size(self) -> int:
return self.pp_size * self.tp_size
def _get_tensorrt_llm_executor_worker_path(self) -> Path:
module_path = find_spec("tensorrt_llm").loader.get_filename()
exec_path = Path(module_path).parent / 'bin' / 'executorWorker'
return exec_path.absolute()
def get_parallel_config(self) -> trtllm.ParallelConfig:
if self.leader_mode:
comm_mode = trtllm.CommunicationMode.LEADER
orchestrator_config = None
else:
comm_mode = trtllm.CommunicationMode.ORCHESTRATOR
orchestrator_config = trtllm.OrchestratorConfig(
True, str(self._get_tensorrt_llm_executor_worker_path()))
return trtllm.ParallelConfig(
trtllm.CommunicationType.MPI,
comm_mode,
orchestrator_config=orchestrator_config,
)
class ExecutorSettingsConfig(BaseModel):
chunking: bool = True
scheduler_policy: CapacitySchedulerPolicy = Field(

View File

@ -336,10 +336,10 @@ class ReportUtility:
# World and runtime info
stats_dict["world_info"] = {
"tp_size": self.rt_cfg.world_config.tp_size,
"pp_size": self.rt_cfg.world_config.pp_size,
"ep_size": self.rt_cfg.world_config.ep_size,
"world_size": self.rt_cfg.world_config.world_size,
"tp_size": self.rt_cfg.mapping["tp_size"],
"pp_size": self.rt_cfg.mapping["pp_size"],
"ep_size": self.rt_cfg.mapping["moe_ep_size"],
"world_size": self.rt_cfg.mapping["world_size"],
"max_batch_size": self.rt_cfg.settings_config.max_batch_size,
"max_num_tokens": self.rt_cfg.settings_config.max_num_tokens,
"scheduling_policy": self.rt_cfg.settings_config.scheduler_policy,
@ -380,7 +380,7 @@ class ReportUtility:
self.per_user_output_throughput_tok_s,
# Output throughput per GPU (total throughput / world size)
"output_throughput_per_gpu_tok_s":
self.output_throughput_tok_s / self.rt_cfg.world_config.world_size,
self.output_throughput_tok_s / self.rt_cfg.mapping["world_size"],
# Request latency percentiles
"request_latency_percentiles_ms":
self.statistics.request_latency_percentiles.model_dump(

View File

@ -30,8 +30,6 @@ import tensorrt as trt
from ._common import _is_building, check_max_num_tokens, serialize_engine
from ._utils import (get_sm_version, np_bfloat16, np_float8, str_dtype_to_trt,
to_json_file, trt_gte)
from .auto_parallel import auto_parallel
from .auto_parallel.config import AutoParallelConfig
from .bindings import KVCacheType
from .functional import PositionEmbeddingType
from .graph_rewriting import optimize
@ -84,19 +82,12 @@ class BuilderConfig(object):
"plugin_config": {
# the network plugin_config (if any) attached to this BuilderConfig object
# inside the Builder.build_engine
},
"auto_parallel_config": {
# the network auto_parallel_config (if any) attached to this BuilderConfig object
# inside the Builder.build_engine
}
}
'''
config = {'builder_config': {}}
for k in self.__dict__.keys():
if k not in [
'_trt_builder_config', 'plugin_config',
'auto_parallel_config'
]:
if k not in ['_trt_builder_config', 'plugin_config']:
config['builder_config'][k] = self.__getattribute__(k)
if hasattr(self, 'plugin_config'):
assert isinstance(self.plugin_config, PluginConfig), \
@ -270,17 +261,6 @@ class Builder():
min_shape = [*shape_profile.min]
opt_shape = [*shape_profile.opt]
max_shape = [*shape_profile.max]
if network._auto_parallel_config is not None:
io_shards = network._auto_parallel_config["io_shards"]
if input_name in io_shards:
shards = io_shards[input_name]
for dim, shard_num in shards.items():
min_shape[dim] = int(
math.floor(min_shape[dim] / shard_num))
opt_shape[dim] = int(
round(opt_shape[dim] / shard_num))
max_shape[dim] = int(
math.ceil(max_shape[dim] / shard_num))
profile.set_shape(input_name, min_shape, opt_shape, max_shape)
logger.debug(
f'{input_name}, min: {min_shape}, opt: {opt_shape}, max: {max_shape}, dimension names: {shape_profile.dimension_names}'
@ -389,7 +369,6 @@ class Builder():
'''
assert isinstance(network, Network)
builder_config.plugin_config = network.plugin_config
builder_config.auto_parallel_config = network.auto_parallel_config
if builder_config.trt_builder_config.num_optimization_profiles == 0:
self._add_optimization_profile(network, builder_config)
logger.info(
@ -506,7 +485,6 @@ class BuildConfig:
input_timing_cache (str, optional): Path to input timing cache file. If None, no input cache used. Defaults to None.
output_timing_cache (str): Path to output timing cache file. Defaults to 'model.cache'.
lora_config (LoraConfig): Configuration for LoRA (Low-Rank Adaptation) fine-tuning. Defaults to default LoraConfig.
auto_parallel_config (AutoParallelConfig): Configuration for automatic parallelization. Defaults to default AutoParallelConfig.
weight_sparsity (bool): Whether to enable weight sparsity optimization. Defaults to False.
weight_streaming (bool): Whether to enable weight streaming for large models. Defaults to False.
plugin_config (PluginConfig): Configuration for TensorRT LLM plugins. Defaults to default PluginConfig.
@ -538,8 +516,6 @@ class BuildConfig:
input_timing_cache: str = None
output_timing_cache: str = 'model.cache'
lora_config: LoraConfig = field(default_factory=LoraConfig)
auto_parallel_config: AutoParallelConfig = field(
default_factory=AutoParallelConfig)
weight_sparsity: bool = False
weight_streaming: bool = False
plugin_config: PluginConfig = field(default_factory=PluginConfig)
@ -659,9 +635,7 @@ class BuildConfig:
defaults.get('input_timing_cache'))
output_timing_cache = config.pop('output_timing_cache',
defaults.get('output_timing_cache'))
lora_config = LoraConfig.from_dict(config.get('lora_config', {}))
auto_parallel_config = AutoParallelConfig.from_dict(
config.get('auto_parallel_config', {}))
lora_config = LoraConfig(**config.get('lora_config', {}))
max_encoder_input_len = config.pop(
'max_encoder_input_len', defaults.get('max_encoder_input_len'))
weight_streaming = config.pop('weight_streaming',
@ -704,7 +678,6 @@ class BuildConfig:
input_timing_cache=input_timing_cache,
output_timing_cache=output_timing_cache,
lora_config=lora_config,
auto_parallel_config=auto_parallel_config,
use_strip_plan=use_strip_plan,
max_encoder_input_len=max_encoder_input_len,
weight_sparsity=weight_sparsity,
@ -727,9 +700,7 @@ class BuildConfig:
if output.get('kv_cache_type', None) is not None:
output['kv_cache_type'] = str(output['kv_cache_type'].name)
output['plugin_config'] = output['plugin_config'].model_dump()
output['lora_config'] = output['lora_config'].to_dict()
output['auto_parallel_config'] = output['auto_parallel_config'].to_dict(
)
output['lora_config'] = output['lora_config'].model_dump()
return output
def update_from_dict(self, config: dict):
@ -892,7 +863,6 @@ def get_engine_version(engine_dir: str) -> Union[None, str]:
def optimize_model_with_config(model: PretrainedModel,
build_config: BuildConfig):
use_auto_parallel = build_config.auto_parallel_config.enabled
gemm_swiglu_plugin = build_config.plugin_config.gemm_swiglu_plugin
low_latency_gemm_swiglu_plugin = build_config.plugin_config.low_latency_gemm_swiglu_plugin
if gemm_swiglu_plugin or low_latency_gemm_swiglu_plugin:
@ -922,12 +892,11 @@ def optimize_model_with_config(model: PretrainedModel,
use_ootb_moe=build_config.plugin_config.moe_plugin is None,
use_fused_mlp=(build_config.plugin_config.use_fused_mlp
and not is_enc_dec
and not (is_recurrent_gemma and is_fp8)
and not use_auto_parallel),
and not (is_recurrent_gemma and is_fp8)),
gemm_swiglu_plugin_dtype=gemm_swiglu_plugin,
low_latency_gemm_swiglu_plugin_dtype=low_latency_gemm_swiglu_plugin,
use_fused_rg_lru=is_recurrent_gemma,
use_unfused_qkv_gemm=use_auto_parallel,
use_unfused_qkv_gemm=False,
use_prompt_tuning=(build_config.max_prompt_embedding_table_size > 0),
use_lora=build_config.plugin_config.lora_plugin is not None,
max_lora_rank=build_config.lora_config.max_lora_rank,
@ -1281,7 +1250,6 @@ def build(model: PretrainedModel, build_config: BuildConfig) -> Engine:
network = builder.create_network()
network.plugin_config = build_config.plugin_config
use_auto_parallel = build_config.auto_parallel_config.enabled
use_weight_only = model.config.quant_mode.is_weight_only()
per_group = model.config.quant_mode.has_per_group_scaling()
use_smooth_quant = model.config.quant_mode.has_act_and_weight_quant()
@ -1391,15 +1359,6 @@ def build(model: PretrainedModel, build_config: BuildConfig) -> Engine:
if model.config.architecture != "DecoderModel":
optimize(network)
if use_auto_parallel:
config = build_config.auto_parallel_config
config.builder_flags = builder_config.trt_builder_config.flags
sharded_networks = auto_parallel(network, config)
network = sharded_networks[model.config.mapping.rank]
if not build_config.auto_parallel_config.debug_mode:
mapping = network.auto_parallel_config["mapping"]
model.config.mapping = mapping
if build_config.visualize_network is not None:
with net_guard(network):
network.to_onnx(build_config.visualize_network)
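Since lora_config is now serialized with model_dump() and rebuilt by keyword construction in from_dict, a short round-trip sketch may be useful; it assumes the tensorrt_llm package built from this commit (with the Pydantic LoraConfig defined later in this commit, tensorrt_llm.lora_helper) is importable.

from tensorrt_llm.lora_helper import LoraConfig

original = LoraConfig(lora_ckpt_source="hf", max_lora_rank=32)
payload = original.model_dump()   # plain, JSON-friendly dict (what to_dict stores)
restored = LoraConfig(**payload)  # what BuildConfig.from_dict now does
assert restored == original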

View File

@ -26,8 +26,6 @@ import torch
from tensorrt_llm._utils import (local_mpi_rank, local_mpi_size, mpi_barrier,
mpi_comm, mpi_rank, mpi_world_size)
from tensorrt_llm.auto_parallel import infer_cluster_config
from tensorrt_llm.auto_parallel.cluster_info import cluster_infos
from tensorrt_llm.bindings import KVCacheType
from tensorrt_llm.builder import BuildConfig, Engine, build
from tensorrt_llm.logger import logger, severity_map
@ -305,29 +303,6 @@ def parse_arguments():
"Maximum lengths of draft tokens for speculative decoding target model."
)
autopp_parser = parser.add_argument_group("Auto parallel arguments")
autopp_parser.add_argument('--auto_parallel',
type=int,
default=1,
help="MPI world size for auto parallel.")
autopp_parser.add_argument(
'--gpus_per_node',
type=int,
default=8,
help=
"Number of GPUs each node has in a multi-node setup. This is a cluster spec and can be greater/smaller than world size. "
"This option is only used for auto parallel specified with ``--auto_parallel``."
)
autopp_parser.add_argument(
'--cluster_key',
type=str,
default=None,
choices=cluster_infos.keys(),
help=
"Unique name for target GPU type. Inferred from current GPU type if not specified. "
"This option is only used for auto parallel specified with ``--auto_parallel``."
)
plugin_config_parser = parser.add_argument_group("Plugin config arguments")
add_plugin_argument(plugin_config_parser)
return parser
@ -355,16 +330,7 @@ def build_model(
assert not build_config.plugin_config.pp_reduce_scatter or architecture == "MixtralForCausalLM", \
"PP reduce scatter is only supported in the mixtral model."
model_config.mapping.gpus_per_node = build_config.auto_parallel_config.gpus_per_node
if build_config.auto_parallel_config.enabled:
assert rank < build_config.auto_parallel_config.world_size
assert model_config.mapping.pp_size == 1 and model_config.mapping.tp_size == 1, \
"You must convert to full model with TP=1&&PP=1 to use auto parallel planner"
model_config.mapping.auto_parallel = True
model_config.mapping.world_size = build_config.auto_parallel_config.world_size
model_config.mapping.rank = rank
else:
assert rank < model_config.mapping.world_size
assert rank < model_config.mapping.world_size
rank_config = copy.deepcopy(model_config)
rank_config.set_rank(rank)
@ -419,17 +385,7 @@ def parallel_build(model_config: PretrainedConfig,
model_cls=None,
**kwargs):
if build_config.auto_parallel_config.enabled:
if model_config.mapping.world_size > 1:
raise RuntimeError(
"manually TP and PP are not supported in auto parallel mode.")
if build_config.auto_parallel_config.debug_mode:
world_size = 1
else:
world_size = build_config.auto_parallel_config.world_size
else:
world_size = model_config.mapping.world_size
world_size = model_config.mapping.world_size
use_mpi = mpi_world_size() > 1
if not use_mpi and workers == 1:
@ -550,10 +506,6 @@ def main():
raise RuntimeError(
"multiple_profiles is enabled, while opt_num_tokens is set. "
"They are not supposed to be working in the same time for now.")
if args.cluster_key is not None:
cluster_config = dict(cluster_key=args.cluster_key)
else:
cluster_config = infer_cluster_config()
# This should only be used for debugging.
# The env var BUILDER_FORCE_NUM_PROFILES should override the number of
@ -605,20 +557,6 @@ def main():
args.input_timing_cache,
'output_timing_cache':
args.output_timing_cache,
'auto_parallel_config': {
'world_size':
args.auto_parallel,
'gpus_per_node':
args.gpus_per_node,
'sharded_io_allowlist': [
'past_key_value_\\d+',
'present_key_value_\\d*',
],
'same_buffer_io': {
'past_key_value_(\\d+)': 'present_key_value_\\1',
},
**cluster_config,
},
'dry_run':
args.dry_run,
'visualize_network':

View File

@ -140,8 +140,7 @@ class BuildCache:
@staticmethod
def prune_build_config_for_cache_key(build_config: dict) -> dict:
# The BuildCache will be disabled once auto_pp is enabled, so 'auto_parallel_config' should be removed
black_list = ['auto_parallel_config', 'dry_run']
black_list = ['dry_run']
dic = build_config.copy()
for key in black_list:
if key in dic:

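A tiny sketch of the pruning above, under the assumption that 'dry_run' is now the only blacklisted key: toggling it should not change the cache key.

def prune_build_config_for_cache_key(build_config: dict) -> dict:
    # Same idea as above: drop keys that must not affect the cache key.
    pruned = build_config.copy()
    pruned.pop("dry_run", None)
    return pruned


a = prune_build_config_for_cache_key({"max_batch_size": 8, "dry_run": True})
b = prune_build_config_for_cache_key({"max_batch_size": 8, "dry_run": False})
assert a == b == {"max_batch_size": 8}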
View File

@ -6,7 +6,7 @@ import math
import os
import types
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from dataclasses import dataclass
from enum import Enum, EnumMeta
from pathlib import Path
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional,
@ -25,7 +25,6 @@ from tensorrt_llm.lora_helper import (LoraConfig,
get_default_trtllm_modules_to_hf_modules)
from .._utils import mpi_rank
from ..auto_parallel import AutoParallelConfig, infer_cluster_config
if TYPE_CHECKING:
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
@ -277,23 +276,21 @@ class AttentionDpConfig(StrictBaseModel):
return cls(**data)
@dataclass
class _ParallelConfig:
''' The model distribution configs for LLM. '''
class _ParallelConfig(StrictBaseModel):
"""The model distribution configs for LLM."""
tp_size: int = 1
pp_size: int = 1
cp_size: int = 1
gpus_per_node: int = 8
moe_cluster_size: int = 1
moe_tp_size: int = 1
moe_ep_size: int = 1
cp_config: dict = field(default_factory=dict)
# Set default for MoE fields to -1 to trigger auto-calculation in Mapping
moe_cluster_size: int = -1
moe_tp_size: int = -1
moe_ep_size: int = -1
cp_config: dict = Field(default_factory=dict)
enable_attention_dp: bool = False
enable_lm_head_tp_in_adp: bool = False
auto_parallel: bool = False
_world_size: int = field(default=1, init=False)
_devices: Optional[List[int]] = field(default=None, init=False)
_devices: Optional[List[int]] = PrivateAttr(default=None)
@property
def devices(self) -> List[int]:
@ -310,18 +307,7 @@ class _ParallelConfig:
self._devices = devices
@property
def world_size(self) -> bool:
if self.auto_parallel:
if self.tp_size > 1 or self.pp_size > 1 or self.cp_size > 1:
raise RuntimeError(
"manually TP and PP are not supported in auto parallel mode."
)
return self._world_size
if self._world_size > 1:
raise RuntimeError(
"world_size > 1 is only supported in auto parallel mode.")
def world_size(self) -> int:
return self.tp_size * self.pp_size * self.cp_size
@property
@ -332,12 +318,9 @@ class _ParallelConfig:
@world_size.setter
def world_size(self, world_size: int):
if self.auto_parallel:
self._world_size = world_size
elif (not self.auto_parallel
) and world_size != self.tp_size * self.pp_size * self.cp_size:
if world_size != self.tp_size * self.pp_size * self.cp_size:
raise ValueError(
f"world_size {world_size} should be equal to tp_size * pp_size {self.tp_size * self.pp_size * self.cp_size} "
f"world_size {world_size} should be equal to tp_size * pp_size * cp_size {self.tp_size * self.pp_size * self.cp_size} "
)
@property
@ -356,8 +339,7 @@ class _ParallelConfig:
enable_lm_head_tp_in_adp=self.enable_lm_head_tp_in_adp,
moe_cluster_size=self.moe_cluster_size,
moe_tp_size=self.moe_tp_size,
moe_ep_size=self.moe_ep_size,
auto_parallel=self.auto_parallel)
moe_ep_size=self.moe_ep_size)
class CalibConfig(StrictBaseModel):
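A dependency-free stand-in (hypothetical MiniParallelConfig; the real _ParallelConfig above is a Pydantic StrictBaseModel) illustrating the new contract: world_size is always derived from tp_size * pp_size * cp_size, and the setter only validates consistency instead of storing a separate value.

class MiniParallelConfig:
    def __init__(self, tp_size: int = 1, pp_size: int = 1, cp_size: int = 1):
        self.tp_size = tp_size
        self.pp_size = pp_size
        self.cp_size = cp_size

    @property
    def world_size(self) -> int:
        # Derived, never stored independently.
        return self.tp_size * self.pp_size * self.cp_size

    @world_size.setter
    def world_size(self, value: int) -> None:
        if value != self.tp_size * self.pp_size * self.cp_size:
            raise ValueError(
                f"world_size {value} should be equal to tp_size * pp_size * "
                f"cp_size {self.tp_size * self.pp_size * self.cp_size}")


cfg = MiniParallelConfig(tp_size=2, pp_size=2)
assert cfg.world_size == 4
cfg.world_size = 4       # consistent: accepted (no-op)
try:
    cfg.world_size = 8   # inconsistent: rejected
except ValueError:
    pass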
@ -1704,7 +1686,7 @@ class BaseLlmArgs(StrictBaseModel):
status="prototype",
)
_parallel_config: Optional[object] = PrivateAttr(default=None)
_parallel_config: Optional[_ParallelConfig] = PrivateAttr(default=None)
_model_format: Optional[_ModelFormatKind] = PrivateAttr(default=None)
_speculative_model: Optional[str] = PrivateAttr(default=None)
_speculative_model_format: Optional[_ModelFormatKind] = PrivateAttr(
@ -1979,7 +1961,7 @@ class BaseLlmArgs(StrictBaseModel):
is QuantAlgo.FP8):
self._update_plugin_config("manage_weights", True)
if self.parallel_config._world_size == 1 and self.build_config:
if self.parallel_config.world_size == 1 and self.build_config:
self.build_config.plugin_config.nccl_plugin = None
if self.enable_lora and self.backend != 'pytorch':
@ -2182,7 +2164,6 @@ class BaseLlmArgs(StrictBaseModel):
moe_cluster_size = pretrained_config.mapping.moe_cluster_size
moe_tp_size = pretrained_config.mapping.moe_tp_size
moe_ep_size = pretrained_config.mapping.moe_ep_size
world_size = pretrained_config.mapping.world_size
gpus_per_node = pretrained_config.mapping.gpus_per_node
# load parallel_config
if self.parallel_config.tp_size != 1 and self.parallel_config.tp_size != tp_size:
@ -2197,20 +2178,14 @@ class BaseLlmArgs(StrictBaseModel):
raise ValueError(
f"cp_size {self.parallel_config.cp_size} is not consistent with the checkpoint's cp_size {cp_size}"
)
if (self.parallel_config.auto_parallel
and self.parallel_config.world_size != 1 and world_size != 1):
raise ValueError(
f"auto parallel with world_size {self.parallel_config.world_size} does not support checkpoint with "
"world_size {world_size} > 1")
if not self.parallel_config.auto_parallel:
self._parallel_config = _ParallelConfig(
tp_size=tp_size,
pp_size=pp_size,
cp_size=cp_size,
gpus_per_node=gpus_per_node,
moe_cluster_size=moe_cluster_size,
moe_tp_size=moe_tp_size,
moe_ep_size=moe_ep_size)
self._parallel_config = _ParallelConfig(
tp_size=tp_size,
pp_size=pp_size,
cp_size=cp_size,
gpus_per_node=gpus_per_node,
moe_cluster_size=moe_cluster_size,
moe_tp_size=moe_tp_size,
moe_ep_size=moe_ep_size)
def get_runtime_sizes(self, ) -> Tuple[int, int, int, int]:
return (
@ -2222,21 +2197,6 @@ class BaseLlmArgs(StrictBaseModel):
class TrtLlmArgs(BaseLlmArgs):
auto_parallel: bool = Field(
default=False,
description="Enable auto parallel mode.",
deprecated=
"Use tensor_parallel_size/pipeline_parallel_size/xxx_parallel_size instead.",
)
auto_parallel_world_size: Optional[int] = Field(
default=None,
description="The world size for auto parallel mode.",
deprecated=
"Use tensor_parallel_size/pipeline_parallel_size/xxx_parallel_size instead.",
)
enable_tqdm: bool = Field(default=False,
description="Enable tqdm for progress bar.")
@ -2288,16 +2248,10 @@ class TrtLlmArgs(BaseLlmArgs):
default=False, description="Normalize log probabilities.")
# Private attributes
_auto_parallel_config: Optional[AutoParallelConfig] = PrivateAttr(
default=None)
# This is used to hold the options for convert_checkpoint
_convert_checkpoint_options: Dict[str,
Any] = PrivateAttr(default_factory=dict)
@property
def auto_parallel_config(self) -> AutoParallelConfig:
return self._auto_parallel_config
@field_validator('calib_config', mode='before')
@classmethod
def init_calib_config(cls, v):
@ -2325,26 +2279,6 @@ class TrtLlmArgs(BaseLlmArgs):
# No else clause needed since validation already happened
return self
@model_validator(mode="after")
def validate_auto_parallel(self):
self._auto_parallel_config = AutoParallelConfig(
sharded_io_allowlist=[
"past_key_value_\\d+",
"present_key_value_\\d*",
],
same_buffer_io={
"past_key_value_(\\d+)": "present_key_value_\\1",
},
**infer_cluster_config(),
)
self.parallel_config.auto_parallel = self.auto_parallel
if self.parallel_config.auto_parallel:
self.parallel_config.world_size = self.auto_parallel_world_size
return self
@model_validator(mode="after")
def validate_enable_build_cache(self):
if not self.enable_build_cache:

View File

@ -5,17 +5,17 @@ import shutil
import tempfile
import time
import weakref
from dataclasses import asdict, dataclass, field
from dataclasses import asdict, dataclass, field, is_dataclass
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union
import torch
import transformers
from pydantic import BaseModel
from tqdm import tqdm
from .._utils import (global_mpi_rank, local_mpi_rank, mpi_barrier,
mpi_broadcast, mpi_rank, release_gc)
from ..auto_parallel import AutoParallelConfig
# yapf: disable
from ..bindings.executor import (BatchingType, CapacitySchedulerPolicy,
ContextChunkingPolicy, ExecutorConfig,
@ -134,18 +134,6 @@ class ModelLoader:
assert self.llm_args.build_config
self.build_config = self.llm_args.build_config
self.auto_parallel_config = AutoParallelConfig(
world_size=llm_args.parallel_config.world_size if llm_args.
parallel_config.auto_parallel else 1)
default_config = self.llm_args.auto_parallel_config
self.auto_parallel_config.set_defaults(
cluster_key=default_config.cluster_key,
cluster_info=default_config.cluster_info,
same_buffer_io=default_config.same_buffer_io,
sharded_io_allowlist=default_config.sharded_io_allowlist,
)
self._gather_build_steps()
def _gather_build_steps(self):
@ -545,11 +533,7 @@ class ModelLoader:
# deep copy so the original build_config is not modified (avoid side effects)
copied_build_config = copy.deepcopy(self.build_config)
copied_build_config.update(
auto_parallel_config=self.auto_parallel_config)
copied_build_config.update_kv_cache_type(self._model_info.architecture)
if self.auto_parallel_config.enabled:
self.model.config.mapping.rank = self.rank
assert self.model is not None, "model is not loaded yet."
self._engine = build(self.model, copied_build_config)
@ -732,9 +716,9 @@ class CachedModelLoader:
def build_cache_enabled(self) -> bool:
_enable_build_cache, _ = get_build_cache_config_from_env()
return (self.llm_args.enable_build_cache or _enable_build_cache) and (
self.llm_args.model_format is _ModelFormatKind.HF
) and not self.llm_args.parallel_config.auto_parallel
return (self.llm_args.enable_build_cache
or _enable_build_cache) and (self.llm_args.model_format
is _ModelFormatKind.HF)
def _get_engine_cache_stage(self) -> CachedStage:
''' Get the cache stage for engine building. '''
@ -743,15 +727,19 @@ class CachedModelLoader:
assert self._hf_model_dir is not None, "HF model dir is required for cache key."
def serialize(d) -> str:
dic = asdict(d) if not isinstance(
d, PretrainedConfig) else d.to_dict()
if hasattr(d, "to_dict"):
dic = d.to_dict()
elif is_dataclass(d):
dic = asdict(d)
elif isinstance(d, BaseModel):
dic = d.model_dump(mode="json")
else:
raise ValueError(f"Could not serialize type: {type(d)}")
return json.dumps(dic, sort_keys=True)
parallel_config = self.llm_args.parallel_config
force_rebuild = False
if parallel_config.auto_parallel:
force_rebuild = True
if self.llm_args.model_format is not _ModelFormatKind.HF:
force_rebuild = True
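The serialize() helper above now dispatches on to_dict(), dataclasses, and Pydantic models when building the cache key. Below is a standalone sketch with hypothetical stand-ins (FakeBuildOptions, FakeLoraOptions; the real call sites pass BuildConfig, PretrainedConfig, and similar objects).

import json
from dataclasses import asdict, dataclass, is_dataclass

from pydantic import BaseModel


def serialize(d) -> str:
    if hasattr(d, "to_dict"):
        dic = d.to_dict()
    elif is_dataclass(d):
        dic = asdict(d)
    elif isinstance(d, BaseModel):
        dic = d.model_dump(mode="json")
    else:
        raise ValueError(f"Could not serialize type: {type(d)}")
    return json.dumps(dic, sort_keys=True)


@dataclass
class FakeBuildOptions:            # hypothetical stand-in
    max_batch_size: int = 8


class FakeLoraOptions(BaseModel):  # hypothetical stand-in
    max_lora_rank: int = 64


print(serialize(FakeBuildOptions()))   # {"max_batch_size": 8}
print(serialize(FakeLoraOptions()))    # {"max_lora_rank": 64}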

View File

@ -13,10 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from typing import Dict, List, Literal, Optional
from ._utils import DictConversion
from pydantic import BaseModel, Field
def get_missing_qkv_modules_from_lora_modules(
@ -80,23 +79,16 @@ def use_lora(
f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}")
@dataclass
class LoraConfig(DictConversion):
lora_dir: List[str] = field(default_factory=list)
lora_ckpt_source: str = "hf"
class LoraConfig(BaseModel):
lora_dir: List[str] = Field(default_factory=list)
lora_ckpt_source: Literal["hf", "nemo"] = "hf"
max_lora_rank: int = 64
lora_target_modules: List[str] = field(default_factory=list)
trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
lora_target_modules: List[str] = Field(default_factory=list)
trtllm_modules_to_hf_modules: Dict[str, str] = Field(default_factory=dict)
max_loras: Optional[int] = None
max_cpu_loras: Optional[int] = None
swap_gate_up_proj_lora_b_weight: bool = True
def __post_init__(self):
assert self.lora_ckpt_source in [
"hf", "nemo"
], (f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}"
)
@property
def missing_qkv_modules(self) -> List[str]:
return get_missing_qkv_modules_from_lora_modules(

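Because the __post_init__ assert is replaced by a Literal["hf", "nemo"] field, invalid checkpoint sources are now rejected by Pydantic at construction time. A quick check, assuming the tensorrt_llm package from this commit; the lora_dir value is a made-up example path.

from pydantic import ValidationError

from tensorrt_llm.lora_helper import LoraConfig

cfg = LoraConfig(lora_dir=["./my_lora"], lora_ckpt_source="nemo", max_lora_rank=32)
assert cfg.lora_ckpt_source == "nemo"

try:
    LoraConfig(lora_ckpt_source="pytorch")  # not one of "hf" / "nemo"
except ValidationError as err:
    print(err.errors()[0]["type"])          # expected to be a literal_error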
View File

@ -54,7 +54,6 @@ class MappingBase:
moe_ep_size=-1, # -1 means no moe
attn_tp_size=-1,
attn_cp_size=-1,
auto_parallel=False,
enable_attention_dp=False,
enable_lm_head_tp_in_adp=False):
# set default values for non-moe cases
@ -97,17 +96,10 @@ class MappingBase:
f"attn_cp_size must be 1 for now for ulysses, but got {attn_tp_size}, {attn_cp_size}."
)
if auto_parallel:
if tp_size != 1 or pp_size != 1 or cp_size != 1:
raise ValueError(
"When auto parallel is enabled, tp_size, pp_size, cp_size must be 1, "
f"but got {tp_size}, {pp_size}, {cp_size}.")
else:
if tp_size * pp_size * cp_size != world_size:
raise ValueError(
"world_size must equal to tp_size * pp_size * cp_size, "
f"but got {world_size} != {tp_size} * {pp_size} * {cp_size}."
)
if tp_size * pp_size * cp_size != world_size:
raise ValueError(
"world_size must equal to tp_size * pp_size * cp_size, "
f"but got {world_size} != {tp_size} * {pp_size} * {cp_size}.")
moe_tp_ep_size = moe_tp_size * moe_ep_size
self.moe_tp_cluster_ep_size = moe_tp_ep_size * moe_cluster_size
@ -139,7 +131,6 @@ class MappingBase:
self.moe_cluster_size = moe_cluster_size
self.attn_tp_size = attn_tp_size
self.attn_cp_size = attn_cp_size
self.auto_parallel = auto_parallel
self.world_size = world_size
self.enable_attention_dp = enable_attention_dp
if enable_lm_head_tp_in_adp:
@ -169,8 +160,7 @@ class MappingBase:
and self.moe_ep_size == other.moe_ep_size
and self.attn_tp_size == other.attn_tp_size
and self.attn_cp_size == other.attn_cp_size
and self.cp_config == other.cp_config
and self.auto_parallel == other.auto_parallel)
and self.cp_config == other.cp_config)
def __hash__(self):
return hash((
@ -187,7 +177,6 @@ class MappingBase:
self.attn_cp_size,
# note: we do not allow updating cp_config after initialization
tuple(sorted(self.cp_config.items())),
self.auto_parallel,
))
@property
@ -339,7 +328,6 @@ class MappingBase:
'attn_tp_size': self.attn_tp_size,
'attn_cp_size': self.attn_cp_size,
'cp_config': self.cp_config,
'auto_parallel': self.auto_parallel,
'enable_attention_dp': self.enable_attention_dp,
'enable_lm_head_tp_in_adp': self.enable_lm_head_tp_in_adp,
}
@ -463,7 +451,6 @@ class Mapping(MappingBase):
moe_ep_size=-1, # -1 means no moe
attn_tp_size=-1,
attn_cp_size=-1,
auto_parallel=False,
enable_attention_dp=False,
enable_lm_head_tp_in_adp=False):
super().__init__(world_size=world_size,
@ -478,7 +465,6 @@ class Mapping(MappingBase):
moe_ep_size=moe_ep_size,
attn_tp_size=attn_tp_size,
attn_cp_size=attn_cp_size,
auto_parallel=auto_parallel,
enable_attention_dp=enable_attention_dp,
enable_lm_head_tp_in_adp=enable_lm_head_tp_in_adp)
@ -516,17 +502,15 @@ class MpiTopology(Mapping):
@property
def tp_rank(self) -> int:
return 0 if self.auto_parallel else self.rank % self.tp_size
return self.rank % self.tp_size
@property
def pp_rank(self) -> int:
return 0 if self.auto_parallel else self.rank // (self.tp_size *
self.cp_size)
return self.rank // (self.tp_size * self.cp_size)
@property
def cp_rank(self) -> int:
return 0 if self.auto_parallel else self.rank % (
self.tp_size * self.cp_size) // self.tp_size
return self.rank % (self.tp_size * self.cp_size) // self.tp_size
@property
def tp_group(self) -> List[int]:

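A worked example of the rank decomposition above, now that the auto_parallel short-circuit is gone; the loop simply re-applies the same formulas for a hypothetical tp_size=2, cp_size=2, pp_size=2 layout.

tp_size, cp_size, pp_size = 2, 2, 2

for rank in range(tp_size * cp_size * pp_size):
    tp_rank = rank % tp_size
    cp_rank = rank % (tp_size * cp_size) // tp_size
    pp_rank = rank // (tp_size * cp_size)
    print(f"rank={rank} tp_rank={tp_rank} cp_rank={cp_rank} pp_rank={pp_rank}")
# e.g. rank 5 -> tp_rank 1, cp_rank 0, pp_rank 1: ranks enumerate TP fastest,
# then CP, then PP.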
View File

@ -739,9 +739,7 @@ class PretrainedModel(Module,
config.set_rank(rank)
rank = config.mapping.rank
if config.mapping.auto_parallel:
rank = 0
elif config.mapping.cp_size > 1:
if config.mapping.cp_size > 1:
# tp_cp_pp rank -> tp_pp rank: because different cp ranks share the same ckpt
tp_size = config.mapping.tp_size
cp_size = config.mapping.cp_size
@ -1307,7 +1305,7 @@ def unfuse_qkv_gemm(model: PretrainedModel) -> PretrainedModel:
for name, layer in model.named_modules():
if isinstance(layer, Attention) and not layer.cross_attention:
assert layer.tp_size == 1, "please disable manual tp when enable auto parallel"
assert layer.tp_size == 1, "unfuse_qkv_gemm requires tp_size == 1"
if layer.qkv is None:
continue
qkv_params = get_init_params(layer.qkv, ColumnLinear)

View File

@ -151,7 +151,6 @@ class Network(object):
self._strongly_typed = trt.INetworkDefinition.get_flag(
self._trt_network, trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
self._unfilled_weights: Dict[str, Tuple[np.array, np.array]] = {}
self._auto_parallel_config: Dict[str, Any] = None
return self
@ -209,10 +208,6 @@ class Network(object):
def strongly_typed(self) -> bool:
return self._strongly_typed
@property
def auto_parallel_config(self) -> Dict[str, Any]:
return self._auto_parallel_config
def _add_input(self,
tensor,
name,

View File

@ -27,13 +27,12 @@ def read_config(config_path: Path):
plugin_config = builder_config['plugin_config']
pretrained_config = config['pretrained_config']
lora_config = builder_config['lora_config']
auto_parallel_config = builder_config['auto_parallel_config']
use_gpt_attention_plugin = plugin_config["gpt_attention_plugin"]
remove_input_padding = plugin_config["remove_input_padding"]
use_lora_plugin = plugin_config["lora_plugin"]
tp_size = pretrained_config['mapping']['tp_size']
pp_size = pretrained_config['mapping']['pp_size']
gpus_per_node = auto_parallel_config['gpus_per_node']
gpus_per_node = pretrained_config['mapping']['gpus_per_node']
world_size = tp_size * pp_size
assert world_size == mpi_world_size(), \
f'Engine world size ({world_size}) != Runtime world size ({mpi_world_size()})'

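With auto_parallel_config gone from the engine config, gpus_per_node is read from pretrained_config['mapping'] like the other parallel sizes. A small sketch of that read path, assuming a config.json laid out as in the hunk above (the real helper also compares world_size against mpi_world_size()):

import json
from pathlib import Path


def read_parallel_info(config_path: Path):
    config = json.loads(Path(config_path).read_text())
    mapping = config["pretrained_config"]["mapping"]
    tp_size, pp_size = mapping["tp_size"], mapping["pp_size"]
    gpus_per_node = mapping["gpus_per_node"]
    return tp_size, pp_size, gpus_per_node, tp_size * pp_size  # world_size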
View File

@ -50,17 +50,12 @@ BASE_EXAMPLE_CLASSES = {
"tensorrt_llm._torch.models.modeling_siglip": ["SiglipVisionModel"],
"tensorrt_llm._torch.models.modeling_vila": ["VilaModel"],
"tensorrt_llm._torch.models.modeling_gpt_oss": ["GptOssForCausalLM"],
### ending import of torch models classes
"tensorrt_llm._torch.pyexecutor.config": ["PyTorchConfig", "LoadFormat"],
"tensorrt_llm._torch.pyexecutor.llm_request":
["LogitsStorage", "PyResult", "LlmResult", "LlmResponse", "LogProbStorage"],
"tensorrt_llm._torch.speculative.mtp": ["MTPConfig"],
"tensorrt_llm._torch.speculative.interface": ["SpeculativeDecodingMode"],
"tensorrt_llm._torch.pyexecutor.config": ["PyTorchConfig", "LoadFormat"],
"tensorrt_llm.auto_parallel.config": ["AutoParallelConfig", "CostModel"],
"tensorrt_llm.auto_parallel.cluster_info":
["ClusterInfo", "MathThroughput"],
"tensorrt_llm._torch.pyexecutor.config": ["PyTorchConfig", "LoadFormat"],
### ending import of torch models classes
"tensorrt_llm.bindings.executor": [
"BatchingType", "CacheTransceiverConfig", "CapacitySchedulerPolicy",
"ContextPhaseParams", "ContextChunkingPolicy", "DynamicBatchConfig",
@ -70,8 +65,6 @@ BASE_EXAMPLE_CLASSES = {
"KvCacheRetentionConfig.TokenRangeRetentionConfig", "PeftCacheConfig",
"SchedulerConfig"
],
"tensorrt_llm._torch.pyexecutor.config": ["PyTorchConfig"],
"tensorrt_llm._torch.model_config": ["MoeLoadBalancerConfig"],
"tensorrt_llm.builder": ["BuildConfig"],
"tensorrt_llm.disaggregated_params": ["DisaggregatedParams"],
"tensorrt_llm.inputs.multimodal": ["MultimodalInput"],

View File

@ -633,11 +633,10 @@
"examples/test_llama.py::test_llm_llama_lookahead_xqa_fp8_1gpu[llama-3.1-8b]": 121.70169674104545,
"examples/test_llama.py::test_llm_llama_lookahead_xqa_fp8_1gpu[llama-3.2-1b]": 79.68221819098108,
"examples/test_llama.py::test_llm_llama_v1_1gpu_kv_cache_reuse_with_prompt_table[llama-7b]": 167.92376559507102,
"examples/test_llama.py::test_llm_llama_v1_2gpu_summary[llama-7b-nb:4-enable_auto_parallel]": 317.7816583644599,
"examples/test_llama.py::test_llm_llama_v1_2gpu_summary[llama-7b-nb:4]": 317.7816583644599,
"examples/test_llama.py::test_llm_llama_v1_4gpu_paged_kv_cache[llama-3.1-8b]": 122.89023815206019,
"examples/test_llama.py::test_llm_llama_v1_multiple_lora_1gpu[luotuo_japan-llama-7b-lora_fp16-base_fp16]": 119.51703953905962,
"examples/test_llama.py::test_llm_llama_v1_multiple_lora_1gpu[luotuo_japan-llama-7b-lora_fp16-base_fp8]": 176.65850483701797,
"examples/test_llama.py::test_llm_llama_v2_1gpu_auto_parallel[llama-v2-7b-hf]": 535.973838724196,
"examples/test_llama.py::test_llm_llama_v2_lora_1gpu[chinese-llama-2-lora-13b-llama-v2-13b-hf-lora_fp16-base_awq]": 420.1779588930076,
"examples/test_llama.py::test_llm_llama_v2_lora_1gpu[chinese-llama-2-lora-13b-llama-v2-13b-hf-lora_fp16-base_fp16]": 895.7611340929288,
"examples/test_llama.py::test_llm_llama_v2_lora_1gpu[chinese-llama-2-lora-13b-llama-v2-13b-hf-lora_fp16-base_fp8]": 314.3205590210273,

View File

@ -894,32 +894,6 @@ def test_llm_llama_v2_gather_logits_2gpu_pp2(llama_example_root,
summary_cmd)
@skip_post_blackwell
@pytest.mark.parametrize("llama_model_root", ['llama-v2-7b-hf'], indirect=True)
def test_llm_llama_v2_1gpu_auto_parallel(llama_example_root, llama_model_root,
llm_venv, cmodel_dir, engine_dir):
model_name = 'llama_v2'
data_type = 'float16'
model_dir = convert_weights(llm_venv=llm_venv,
example_root=llama_example_root,
cmodel_dir=cmodel_dir,
model=model_name,
model_path=llama_model_root,
data_type=data_type)
print("Build engines...")
build_cmd = [
"trtllm-build",
f"--checkpoint_dir={model_dir}",
f"--max_batch_size={1}",
f"--max_input_len={1024}",
f"--output_dir={engine_dir}",
"--auto_parallel=8",
]
check_call(" ".join(build_cmd), shell=True, env=llm_venv._new_env)
@skip_post_blackwell
@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize("num_beams", [1, 4],
@ -1078,24 +1052,21 @@ def test_llm_llama_v3_1_autoq_2gpu_mmlu(llama_example_root, llama_model_root,
@skip_post_blackwell
@pytest.mark.skip_less_device(2)
@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.parametrize("use_auto_parallel", [True, False],
ids=["enable_auto_parallel", "disable_auto_parallel"])
@pytest.mark.parametrize("num_beams", [4],
ids=lambda num_beams: f'nb:{num_beams}')
@pytest.mark.parametrize("llama_model_root", ['llama-7b', 'llama-30b'],
indirect=True)
def test_llm_llama_v1_2gpu_summary(llama_example_root, llama_model_root,
llm_datasets_root, llm_rouge_root, llm_venv,
cmodel_dir, engine_dir, num_beams,
use_auto_parallel):
cmodel_dir, engine_dir, num_beams):
model_name = 'llama_v1_2gpu'
model_dir = convert_weights(llm_venv=llm_venv,
example_root=llama_example_root,
cmodel_dir=cmodel_dir,
model=model_name,
model_path=llama_model_root,
gpus=1 if use_auto_parallel else 2,
tp_size=1 if use_auto_parallel else 2,
gpus=2,
tp_size=2,
pp_size=1)
print("Build engines...")
@ -1108,8 +1079,6 @@ def test_llm_llama_v1_2gpu_summary(llama_example_root, llama_model_root,
"--remove_input_padding=enable",
f"--max_beam_width={num_beams}",
]
if use_auto_parallel:
build_cmd += ["--auto_parallel=2"]
check_call(" ".join(build_cmd), shell=True, env=llm_venv._new_env)

View File

@ -102,11 +102,10 @@ examples/test_llama.py::test_llm_llama_lookahead_single_gpu_summary[llama-3.1-8b
examples/test_llama.py::test_llm_llama_lookahead_xqa_fp8_1gpu[llama-3.1-8b]
examples/test_llama.py::test_llm_llama_lookahead_xqa_fp8_1gpu[llama-3.2-1b]
examples/test_llama.py::test_llm_api_lookahead_decoding_1gpu[Llama-3.1-8B-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct]
examples/test_llama.py::test_llm_llama_v1_2gpu_summary[llama-7b-nb:4-enable_auto_parallel]
examples/test_llama.py::test_llm_llama_v1_2gpu_summary[llama-7b-nb:4]
examples/test_llama.py::test_llm_llama_v1_4gpu_paged_kv_cache[llama-3.1-8b]
examples/test_llama.py::test_llm_llama_v1_multiple_lora_1gpu[luotuo_japan-llama-7b-lora_fp16-base_fp16]
examples/test_llama.py::test_llm_llama_v1_multiple_lora_1gpu[luotuo_japan-llama-7b-lora_fp16-base_fp8]
examples/test_llama.py::test_llm_llama_v2_1gpu_auto_parallel[llama-v2-7b-hf]
examples/test_llama.py::test_llm_llama_v2_lora_1gpu[chinese-llama-2-lora-13b-llama-v2-13b-hf-lora_fp16-base_awq]
examples/test_llama.py::test_llm_llama_v2_lora_1gpu[chinese-llama-2-lora-13b-llama-v2-13b-hf-lora_fp16-base_fp16]
examples/test_llama.py::test_llm_llama_v2_lora_1gpu[chinese-llama-2-lora-13b-llama-v2-13b-hf-lora_fp16-base_fp8]

View File

@ -157,7 +157,6 @@ l0_a30:
- examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-bart-large-cnn-bfloat16-enable_gemm_plugin-enable_attention_plugin-disable_paged_kv_cache-tp:1-pp:1-nb:1-disable_fp8] # 1 mins
- accuracy/test_cli_flow.py::TestLlama7B::test_manage_weights # 2 mins
- accuracy/test_llm_api.py::TestQwen2_7BInstruct::test_weight_only
- examples/test_llama.py::test_llm_llama_v2_1gpu_auto_parallel[llama-v2-7b-hf] # 15 mins
- accuracy/test_cli_flow.py::TestMistral7B::test_beam_search # 5 mins
- examples/test_mistral.py::test_llm_mistral_v1_1gpu[mistral-7b-v0.1-float16-max_attention_window_size_4096-summarization_long]
- examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-t5-small-float32-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:2-disable_fp8]

View File

@ -253,7 +253,6 @@ examples/test_qwen2audio.py::test_llm_qwen2audio_single_gpu[qwen2_audio_7b_instr
examples/test_nemotron_nas.py::test_nemotron_nas_summary_2gpu[DeciLM-7B] SKIP (https://nvbugs/5444636)
examples/test_multimodal.py::test_llm_fp8_multimodal_general[fp8-fp8-cnn_dailymail-Qwen2-VL-7B-Instruct-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False] SKIP (https://nvbugs/5453709)
examples/test_multimodal.py::test_llm_multimodal_general[VILA1.5-3b-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5453709)
examples/test_llama.py::test_llm_llama_v2_1gpu_auto_parallel[llama-v2-7b-hf] SKIP (https://nvbugs/5453742)
triton_server/test_triton.py::test_gpt_ib[gpt-ib] SKIP (https://nvbugs/5431116)
accuracy/test_llm_api.py::TestMistralNemo12B::test_fp8 SKIP (https://nvbugs/5413197)
triton_server/test_triton.py::test_gpt_ib_streaming[gpt-ib-streaming] SKIP (https://nvbugs/5371349)
@ -267,7 +266,6 @@ examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-3-mini-128k-instruct] SKIP
examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-3-small-128k-instruct] SKIP (https://nvbugs/5465143)
examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-3.5-mini-instruct] SKIP (https://nvbugs/5465143)
examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-4-mini-instruct] SKIP (https://nvbugs/5465143)
examples/test_llama.py::test_llm_llama_v1_2gpu_summary[llama-7b-nb:4-enable_auto_parallel] SKIP (https://nvbugs/5453742)
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen1.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_vl_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5359696)

View File

@ -26,9 +26,8 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
# check expected parallel config
world_size = expected_ad_config.world_size
expected_parallel_config = _ParallelConfig(
auto_parallel=True, gpus_per_node=expected_llm_args.gpus_per_node
tp_size=world_size, gpus_per_node=expected_llm_args.gpus_per_node
)
expected_parallel_config.world_size = world_size
assert llm_args._parallel_config == expected_parallel_config, (
f"Expected parallel_config {expected_parallel_config}, got {llm_args._parallel_config}"
)

View File

@ -53,7 +53,6 @@ def build_engine(model_name, request):
llm = LLM(model_path,
tensor_parallel_size=tp_size,
auto_parallel_world_size=tp_size,
build_config=build_config)
engine_dir = TemporaryDirectory(suffix="-engine_dir")

View File

@ -48,7 +48,6 @@ def engine_from_fp8_quantization(model_name):
llm = LLM(model_path,
tensor_parallel_size=tp_size,
auto_parallel_world_size=tp_size,
quant_config=quant_config,
calib_config=calib_config,
build_config=build_config)

View File

@ -459,21 +459,12 @@ def test_llm_generate_async():
def _test_llm_generate_async(model_name=default_model_name,
tp_size: int = 1,
use_auto_parallel: bool = False,
tokenizer=None):
if "Mixtral" in model_name and use_auto_parallel:
pytest.skip("Auto parallel is not supported for Mixtral models")
tp_size = tp_size if not use_auto_parallel else 1
world_size = tp_size if use_auto_parallel else None
llm = LLM(
model=get_model_path(model_name),
tokenizer=tokenizer,
kv_cache_config=global_kvcache_config,
tensor_parallel_size=tp_size,
auto_parallel=use_auto_parallel,
auto_parallel_world_size=world_size,
fast_build=True,
)

View File

@ -6,7 +6,6 @@ import yaml
import tensorrt_llm.bindings.executor as tle
from tensorrt_llm import LLM as TorchLLM
from tensorrt_llm import AutoParallelConfig
from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm.builder import LoraConfig
from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
@ -334,10 +333,6 @@ def test_update_llm_args_with_extra_dict_with_nested_dict():
plugin_config = PluginConfig(dtype='float16', nccl_plugin=None)
build_config = BuildConfig(max_input_len=1024,
lora_config=LoraConfig(lora_ckpt_source='hf'),
auto_parallel_config=AutoParallelConfig(
world_size=1,
same_buffer_io={},
debug_outputs=[]),
plugin_config=plugin_config)
extra_llm_args_dict = {
"build_config": build_config.to_dict(),

View File

@ -135,25 +135,17 @@ def test_llm_return_logprobs_tp2(prompt_logprobs: Optional[int],
tp_size=2)
@pytest.mark.parametrize("use_auto_parallel", [True, False],
ids=["enable_auto_parallel", "disable_auto_parallel"])
@pytest.mark.parametrize("from_ckpt", [True, False],
ids=["from_ckpt", "from_hf"])
@pytest.mark.gpu2
@pytest.mark.part0
def test_llm_generate_async_tp2(
engine_from_checkpoint: tempfile.TemporaryDirectory, from_ckpt: bool,
use_auto_parallel: bool):
if use_auto_parallel and from_ckpt:
pytest.skip("Skip auto parallel for TP2 checkpoint")
engine_from_checkpoint: tempfile.TemporaryDirectory, from_ckpt: bool):
model_dir = engine_from_checkpoint.name if from_ckpt else get_model_path(
llama_model_path)
tokenizer_dir = get_model_path(llama_model_path)
tokenizer = TransformersTokenizer.from_pretrained(tokenizer_dir)
_test_llm_generate_async(model_dir,
tp_size=2,
use_auto_parallel=use_auto_parallel,
tokenizer=tokenizer)
_test_llm_generate_async(model_dir, tp_size=2, tokenizer=tokenizer)
@skip_gpu_memory_less_than(70 * 1024**3)
@ -201,7 +193,6 @@ def test_llm_pp2():
prompts, ["D E F G H I J K"],
sampling_params=SamplingParams(max_tokens=8),
pipeline_parallel_size=2,
auto_parallel=False,
kv_cache_config=global_kv_cache_config)

View File

@ -1,12 +1,6 @@
import os
import tempfile
import torch
from tensorrt_llm import serialization
from tensorrt_llm.auto_parallel.config import AutoParallelConfig
from tensorrt_llm.auto_parallel.parallelization import ParallelConfig
from tensorrt_llm.auto_parallel.simplifier import GraphConfig, StageType
class TestClass:
@ -83,38 +77,5 @@ def test_serialization_complex_object_disallowed_class():
excep) == "Import torch._utils | _rebuild_tensor_v2 is not allowed"
def test_parallel_config_serialization():
with tempfile.TemporaryDirectory() as tmpdir:
# Create a ParallelConfig instance with some test data
config = ParallelConfig()
config.version = "test_version"
config.network_hash = "test_hash"
config.auto_parallel_config = AutoParallelConfig(
world_size=2, gpus_per_node=2, cluster_key="test_cluster")
config.graph_config = GraphConfig(num_micro_batches=2,
num_blocks=3,
num_stages=2)
config.cost = 1.5
config.stage_type = StageType.START
config_path = os.path.join(tmpdir, "parallel_config.pkl")
config.save(config_path)
loaded_config = ParallelConfig.from_file(config_path)
# Verify the loaded config matches the original
assert loaded_config.version == config.version
assert loaded_config.network_hash == config.network_hash
assert loaded_config.auto_parallel_config.world_size == config.auto_parallel_config.world_size
assert loaded_config.auto_parallel_config.gpus_per_node == config.auto_parallel_config.gpus_per_node
assert loaded_config.auto_parallel_config.cluster_key == config.auto_parallel_config.cluster_key
assert loaded_config.graph_config.num_micro_batches == config.graph_config.num_micro_batches
assert loaded_config.graph_config.num_blocks == config.graph_config.num_blocks
assert loaded_config.graph_config.num_stages == config.graph_config.num_stages
assert loaded_config.cost == config.cost
assert loaded_config.stage_type == config.stage_type
if __name__ == "__main__":
test_serialization_allowed_class()
test_parallel_config_serialization()

View File

@ -15,7 +15,6 @@ from utils.util import force_ampere
import tensorrt_llm
from tensorrt_llm import BuildConfig, Mapping, SamplingParams
from tensorrt_llm._utils import mpi_barrier
from tensorrt_llm.auto_parallel import AutoParallelConfig, infer_cluster_config
from tensorrt_llm.executor import GenerationExecutor
from tensorrt_llm.models import LLaMAForCausalLM
@ -56,7 +55,7 @@ def get_batch_output_text_expected(model_name):
# 76s on ipp1-1197, loading weights 18s (varies based on network speed), network/engine creation 27s
def build_and_run_tp2(rank, model_name, engine_dir, use_auto_parallel):
def build_and_run_tp2(rank, model_name, engine_dir):
'''Do not save the engine, all in one LLaMAForCausalLM object
'''
batch_output_text_expected = get_batch_output_text_expected(model_name)
@ -67,26 +66,10 @@ def build_and_run_tp2(rank, model_name, engine_dir, use_auto_parallel):
max_batch_size, max_isl, max_osl = 8, 256, 256
hf_model_dir = str(llm_models_root() / model_name)
mapping = Mapping(world_size=TP_SIZE, rank=rank, tp_size=TP_SIZE)
auto_parallel_config = AutoParallelConfig()
if use_auto_parallel:
mapping = Mapping()
mapping.rank = rank
auto_parallel_config = AutoParallelConfig(
world_size=TP_SIZE,
sharded_io_allowlist=[
"past_key_value_\\d+",
"present_key_value_\\d*",
],
same_buffer_io={
"past_key_value_(\\d+)": "present_key_value_\\1",
},
**infer_cluster_config(),
)
build_config = BuildConfig(max_batch_size=max_batch_size,
max_input_len=max_isl,
max_seq_len=max_osl + max_isl,
strongly_typed=True,
auto_parallel_config=auto_parallel_config)
strongly_typed=True)
# build and run by one llama object
llama = LLaMAForCausalLM.from_hugging_face(hf_model_dir, mapping=mapping)
engine = tensorrt_llm.build(llama, build_config)
@ -117,21 +100,17 @@ def build_and_run_tp2(rank, model_name, engine_dir, use_auto_parallel):
@force_ampere
@pytest.mark.parametrize("use_auto_parallel", [True, False],
ids=["enable_auto_parallel", "disable_auto_parallel"])
@pytest.mark.parametrize("model_name",
["llama-models/llama-7b-hf", "Mixtral-8x7B-v0.1"])
def test_multi_gpu(model_name, use_auto_parallel):
def test_multi_gpu(model_name):
if torch.cuda.device_count() < TP_SIZE:
print(f"The test needs at least ${TP_SIZE} GPUs, skipping")
return
if "Mixtral" in model_name and use_auto_parallel:
pytest.skip("Auto parallel is not supported for Mixtral models")
engine_dir = tempfile.TemporaryDirectory().name
with MPIPoolExecutor(max_workers=TP_SIZE) as executor:
results = executor.map(build_and_run_tp2, (0, 1), [model_name] * 2,
[engine_dir] * 2, [use_auto_parallel] * 2)
[engine_dir] * 2)
for r in results:
assert r is True