diff --git a/cpp/tests/resources/data/test_model_lora_config.json b/cpp/tests/resources/data/test_model_lora_config.json index 1e3286d8a4..17006d6258 100644 --- a/cpp/tests/resources/data/test_model_lora_config.json +++ b/cpp/tests/resources/data/test_model_lora_config.json @@ -93,36 +93,6 @@ ], "trtllm_modules_to_hf_modules": {} }, - "auto_parallel_config": { - "world_size": 1, - "gpus_per_node": 8, - "cluster_key": "A100-PCIe-80GB", - "cluster_info": null, - "sharding_cost_model": "alpha_beta", - "comm_cost_model": "alpha_beta", - "enable_pipeline_parallelism": false, - "enable_shard_unbalanced_shape": false, - "enable_shard_dynamic_shape": false, - "enable_reduce_scatter": true, - "builder_flags": null, - "debug_mode": false, - "infer_shape": true, - "validation_mode": false, - "same_buffer_io": { - "past_key_value_(\\d+)": "present_key_value_\\1" - }, - "same_spec_io": {}, - "sharded_io_allowlist": [ - "past_key_value_\\d+", - "present_key_value_\\d*" - ], - "fast_reduce": true, - "fill_weights": false, - "parallel_config_cache": null, - "profile_cache": null, - "dump_path": null, - "debug_outputs": [] - }, "weight_sparsity": false, "weight_streaming": false, "plugin_config": { diff --git a/examples/models/core/llama/README.md b/examples/models/core/llama/README.md index e0c6d0858f..d88d593fb1 100644 --- a/examples/models/core/llama/README.md +++ b/examples/models/core/llama/README.md @@ -132,16 +132,6 @@ trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_fp16_wq \ --output_dir ./tmp/llama/7B/trt_engines/weight_only/1-gpu/ \ --gemm_plugin auto -# Build LLaMA 7B using 2-way auto parallelism (deprecated). -python convert_checkpoint.py --model_dir ./tmp/llama/7B/ \ - --output_dir ./tllm_checkpoint_1gpu_fp16 \ - --dtype float16 - -trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_fp16 \ - --output_dir ./tmp/llama/7B/trt_engines/fp16/2-gpu/ \ - --gemm_plugin auto \ - --auto_parallel 2 - # Build LLaMA 7B using 2-way tensor parallelism. python convert_checkpoint.py --model_dir ./tmp/llama/7B/ \ --output_dir ./tllm_checkpoint_2gpu_tp2 \ diff --git a/tensorrt_llm/__init__.py b/tensorrt_llm/__init__.py index 9164077dc6..425140016b 100644 --- a/tensorrt_llm/__init__.py +++ b/tensorrt_llm/__init__.py @@ -77,7 +77,6 @@ from ._utils import (default_gpus_per_node, local_mpi_rank, local_mpi_size, mpi_barrier, mpi_comm, mpi_rank, mpi_world_size, set_mpi_comm, str_dtype_to_torch, str_dtype_to_trt, torch_dtype_to_trt) -from .auto_parallel import AutoParallelConfig, auto_parallel from .builder import BuildConfig, Builder, BuilderConfig, build from .disaggregated_params import DisaggregatedParams from .functional import Tensor, constant @@ -130,8 +129,6 @@ __all__ = [ 'Module', 'functional', 'models', - 'auto_parallel', - 'AutoParallelConfig', 'quantization', 'tools', 'LLM', diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py index 3a8a03fa70..1d8b95d4db 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm_args.py +++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py @@ -391,10 +391,13 @@ class LlmArgs(AutoDeployConfig, BaseLlmArgs, BaseSettings): rank to automatically shard the model. This is just to ensure that other objects in the runtime that may read parallel_config can do so. """ + + # Set tp_size = self.world_size so that _ParallelConfig.world_size will return the + # correct value (computed as tp_size * pp_size * cp_size). This does not necessarily + # mean that TP will actually be used. 
self._parallel_config = _ParallelConfig( - auto_parallel=True, gpus_per_node=self.gpus_per_node + tp_size=self.world_size, gpus_per_node=self.gpus_per_node ) - self._parallel_config.world_size = self.world_size return self @model_validator(mode="after") diff --git a/tensorrt_llm/_torch/device_mesh.py b/tensorrt_llm/_torch/device_mesh.py index 512ba955ea..ca8db83385 100644 --- a/tensorrt_llm/_torch/device_mesh.py +++ b/tensorrt_llm/_torch/device_mesh.py @@ -74,18 +74,15 @@ class DeviceMeshTopologyImpl(_MappingBaseForTypeCheck): # Access rank @property def tp_rank(self) -> int: - assert not self.auto_parallel, "Auto parallel is not currently supported in Ray mode." return self.tp_group_pg.rank() @property def pp_rank(self) -> int: - assert not self.auto_parallel, "Auto parallel is not currently supported in Ray mode." return self.pp_group_pg.rank() @property def cp_rank(self) -> int: # TODO: WIP - assert not self.auto_parallel, "Auto parallel is not currently supported in Ray mode." return self.cp_group_pg.rank() # Access group ranks diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py index 4c696511dc..7a98a93377 100644 --- a/tensorrt_llm/_utils.py +++ b/tensorrt_llm/_utils.py @@ -25,7 +25,6 @@ import tempfile import trace import weakref from contextlib import contextmanager -from dataclasses import asdict from enum import EnumMeta from functools import lru_cache, partial, wraps from pathlib import Path @@ -799,38 +798,6 @@ def trace_func(func): return wrapper -class DictConversion: - - @classmethod - def from_dict(cls, config: Dict[str, Any]): - obj = cls() - fields = obj.__dataclass_fields__ - for key, value in config.items(): - assert hasattr(obj, key), f"cannot find {key} in {obj}" - field_cls = fields[key].type - if (isinstance(field_cls, type) - and issubclass(field_cls, DictConversion) - and isinstance(value, dict)): - value = field_cls.from_dict(value) - setattr(obj, key, value) - return obj - - def to_dict(self): - return asdict(self) - - @classmethod - def from_json_file(cls, file): - with open(file) as f: - return cls.from_dict(json.load(f)) - - def set_defaults(self, **kwargs): - for key, default in kwargs.items(): - value = getattr(self, key) - if (value is None - or (isinstance(value, (list, dict)) and len(value) == 0)): - setattr(self, key, default) - - class BaseEnumMeta(EnumMeta): def __contains__(cls, item): diff --git a/tensorrt_llm/auto_parallel/__init__.py b/tensorrt_llm/auto_parallel/__init__.py deleted file mode 100644 index 52b246a7f8..0000000000 --- a/tensorrt_llm/auto_parallel/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .auto_parallel import auto_parallel -from .cluster_info import infer_cluster_config -from .config import AutoParallelConfig - -__all__ = [ - 'auto_parallel', - 'AutoParallelConfig', - 'infer_cluster_config', -] diff --git a/tensorrt_llm/auto_parallel/auto_parallel.py b/tensorrt_llm/auto_parallel/auto_parallel.py deleted file mode 100644 index adaea2a4f0..0000000000 --- a/tensorrt_llm/auto_parallel/auto_parallel.py +++ /dev/null @@ -1,266 +0,0 @@ -import gc -import os -from concurrent.futures import ThreadPoolExecutor -from pathlib import Path - -import tensorrt as trt -import torch -from filelock import FileLock - -from tensorrt_llm.functional import DimRange, Tensor -from tensorrt_llm.logger import logger -from tensorrt_llm.network import Network, net_guard - -from .config import AutoParallelConfig -from .device_mesh import LogicalDeviceMesh, PhysicalDeviceMesh -from .node_graph import NodeGraph -from .parallelization import 
ParallelConfig, parallelize -from .pipeline_graph import PipelineGraph -from .simplifier import GraphConfig, Simplifier, StageType -from .utils import current_flags - - -def to_network(graph: PipelineGraph, network: Network): - logger.debug("Converting graph to network") - trt_network = graph.as_trt() - trt_network.name = network.trt_network.name - new_network = Network() - new_network._init(trt_network) - new_network._dtype = network._dtype - new_network._plugin_config = network._plugin_config - new_network._unfilled_weights = graph._unfilled_weights - new_network._auto_parallel_config = graph._auto_parallel_config - with net_guard(network): - for i in range(trt_network.num_inputs): - input = trt_network.get_input(i) - tensor = Tensor(is_network_input=False) - if input.name in network._inputs: - profiles = network._inputs[input.name].profiles - elif len(network._inputs) == 0: - profiles = [] - else: - shape = input.shape - num_profiles = len(list(network._inputs.values())[0].profiles) - profile = DimRange(shape, [None] * len(shape)) - profiles = [profile] * num_profiles - tensor.profiles = profiles - tensor.trt_tensor = input - new_network._inputs[input.name] = tensor - return new_network - - -def find_solution( - node_graph: NodeGraph, - graph_config: GraphConfig, - lmesh: LogicalDeviceMesh, - memory_budget: int, - flags: list, - device: int, - dump_path: str, -) -> ParallelConfig: - torch.cuda.set_device(device) - with current_flags(*flags): - cost_graph = node_graph.get_cost_graph(lmesh) - num_stages = graph_config.num_stages - if num_stages == 1: - stage_types = [None] - elif num_stages == 2: - stage_types = [StageType.START, StageType.END] - else: - stage_types = [StageType.START, StageType.BLOCK, StageType.END] - - best_config, best_solution = None, None - for stage_type in stage_types: - if stage_type is not None: - node_graph.set_slowest_stage(stage_type, graph_config) - solution = node_graph.find_solution( - cost_graph, - memory_budget, - ) - cost = solution.total_cost - if best_config is None or cost < best_config.cost: - best_config = ParallelConfig() - best_config.graph_config = graph_config - best_config.lmesh = lmesh - best_config.cost = cost - best_config.graph_strategy = solution.node_best_strategy - best_config.stage_type = stage_type - best_solution = solution - if dump_path is not None: - lock = FileLock(f"{dump_path}/path.lock", thread_local=False) - vlz_name = f"{dump_path}/solution." - if graph_config.num_micro_batches != 1: - vlz_name += f"mbs{graph_config.num_micro_batches}." - if graph_config.num_stages != 1: - vlz_name += f"stages{graph_config.num_stages}." 
- vlz_name += lmesh.cluster_key - with lock: - node_graph.visualize_solution( - best_solution, - vlz_name, - ignore_shape_io=True, - ) - return best_config - - -def infer_builder_flags(network): - fp16_enabled = False - bf16_enabled = False - int8_enabled = False - fp8_enabled = False - - def check_dtype(tensor): - nonlocal fp16_enabled - nonlocal bf16_enabled - nonlocal int8_enabled - nonlocal fp8_enabled - if tensor.dtype == trt.DataType.HALF: - fp16_enabled = True - elif tensor.dtype == trt.DataType.BF16: - bf16_enabled = True - elif tensor.dtype == trt.DataType.INT8: - int8_enabled = True - elif tensor.dtype == trt.DataType.FP8: - fp8_enabled = True - - trt_network = network.trt_network - for i in range(trt_network.num_inputs): - input = trt_network.get_input(i) - check_dtype(input) - for i in range(trt_network.num_layers): - layer = trt_network.get_layer(i) - for j in range(layer.num_outputs): - output = layer.get_output(j) - check_dtype(output) - - builder_flags = 0 - if fp16_enabled: - builder_flags |= 1 << int(trt.BuilderFlag.FP16) - builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS) - if bf16_enabled: - builder_flags |= 1 << int(trt.BuilderFlag.BF16) - builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS) - if int8_enabled: - builder_flags |= 1 << int(trt.BuilderFlag.INT8) - if fp8_enabled: - builder_flags |= 1 << int(trt.BuilderFlag.FP8) - builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS) - return builder_flags - - -def auto_parallel(network: Network, config: AutoParallelConfig): - logger.warning( - "auto_parallel is deprecated, " - "please use explicit parallelism like tp_size/pp_size instead.") - debug_mode = config.debug_mode - memory_budget = config.get_cluster_info( - ).memory_budget_per_device * 1024 * 1024 * 1024 - enable_pipeline_parallelism = config.enable_pipeline_parallelism - if config.world_size < config.gpus_per_node: - num_hosts = 1 - num_devices_per_host = config.world_size - else: - assert config.world_size % config.gpus_per_node == 0 - num_hosts = config.world_size // config.gpus_per_node - num_devices_per_host = config.gpus_per_node - parallel_config_cache = config.parallel_config_cache - dump_path = config.dump_path if debug_mode else None - fill_weights = config.fill_weights - - if num_hosts == 1 and num_devices_per_host == 1: - return [network] - - if dump_path is not None: - if not os.path.exists(dump_path): - os.makedirs(dump_path) - - builder_flags = config.builder_flags or infer_builder_flags(network) - flags = [builder_flags, network.strongly_typed] - with current_flags(*flags): - simplifier = Simplifier(network, config) - network_hash = simplifier.get_network_hash() - - best_config = None - if parallel_config_cache is not None and Path( - parallel_config_cache).exists(): - parallel_config = ParallelConfig.from_file(parallel_config_cache) - if (ParallelConfig.VERSION == parallel_config.version - and network_hash == parallel_config.network_hash - and config == parallel_config.auto_parallel_config): - logger.info( - f"use cache of parallel config from {parallel_config_cache}" - ) - best_config = parallel_config - - if best_config is None: - num_devices = num_hosts * num_devices_per_host - phy_ids = [[ - i + j * num_devices_per_host - for i in range(num_devices_per_host) - ] for j in range(num_hosts)] - phy_mesh = PhysicalDeviceMesh(phy_ids, config) - if enable_pipeline_parallelism: - num_micro_batches_list = simplifier.list_all_num_micro_batches() - else: - num_micro_batches_list = [1] - - jobs = [] - 
for num_micro_batches in num_micro_batches_list: - simplifier.infer_shapes(num_micro_batches) - if enable_pipeline_parallelism: - pipeline_configs = phy_mesh.list_all_pipeline_configs() - else: - pipeline_configs = [(1, num_devices)] - for num_stages, num_devices_per_stage in pipeline_configs: - # TODO: add fallback path that allows num_micro_batches >= num_stages - # if no solution satisfies memory budget - if num_micro_batches < num_stages: - continue - simplified_graph, graph_config = simplifier.simplify_graph( - phy_mesh, - num_stages, - num_devices_per_stage, - ) - if simplified_graph is None: - continue - node_graph = NodeGraph(simplified_graph) - node_graph.assign_cost_weights(graph_config) - lmeshes = graph_config.stage_phy_meshes[ - 0].get_logical_meshes() - for lmesh in lmeshes: - jobs.append( - (node_graph, graph_config, lmesh, memory_budget * - (num_devices / num_devices_per_stage))) - - try: - with ThreadPoolExecutor() as executor: - best_config = sorted( - executor.map( - lambda x: find_solution( - *x, - flags, - torch.cuda.current_device(), - dump_path, - ), - jobs, - ), - key=lambda x: x.cost, - )[0] - finally: - phy_mesh.close() - - if parallel_config_cache is not None: - best_config.network_hash = network_hash - best_config.auto_parallel_config = config - best_config.save(parallel_config_cache) - - new_graphs = parallelize(simplifier, best_config) - - networks = [to_network(new_graph, network) for new_graph in new_graphs] - if debug_mode and fill_weights: - networks[0]._fill_weights() - - gc.collect() - torch.cuda.empty_cache() - - return networks diff --git a/tensorrt_llm/auto_parallel/cluster_info.py b/tensorrt_llm/auto_parallel/cluster_info.py deleted file mode 100644 index 6693c00753..0000000000 --- a/tensorrt_llm/auto_parallel/cluster_info.py +++ /dev/null @@ -1,539 +0,0 @@ -import copy -import re -from dataclasses import dataclass, field -from typing import Dict, Tuple, Union - -import pynvml -import torch - -try: - from cuda.bindings import runtime as cudart -except ImportError: - from cuda import cudart - -from tensorrt_llm._utils import DictConversion -from tensorrt_llm.logger import logger -from tensorrt_llm.profiler import PyNVMLContext, _device_get_memory_info_fn - - -@dataclass -class MathThroughput(DictConversion): - int4: int = 0 # Tflops - int8: int = 0 # Tflops - fp8: int = 0 # Tflops - float16: int = 0 # Tflops - bfloat16: int = 0 # Tflops - float32: int = 0 # Tflops - - @staticmethod - def to_tflops( - ipc_per_sm: "MathThroughput", - sm_count: int, - clock_mhz: int, - ) -> "MathThroughput": - tflops = MathThroughput() - for name in ipc_per_sm.__dataclass_fields__: - setattr( - tflops, name, - getattr(ipc_per_sm, name) * sm_count * clock_mhz // int(1e6)) - return tflops - - -@dataclass -class ClusterInfo(DictConversion): - inter_node_bw_per_device: int = 25 # GBps - intra_node_bw_per_device: int = 0 # GBps - inter_node_latency: int = 10 # us - intra_node_latency: int = 10 # us - intra_node_sharp: bool = False - inter_node_sharp: bool = True - - memory_bw: int = 0 # GBps - memory_budget_per_device: int = 0 # GB - - math_throughput: MathThroughput = field(default_factory=MathThroughput) - - memory_efficiency: float = 1.0 - math_efficiency: float = 1.0 - communication_efficiency: float = 1.0 - - -_math_throughputs = { - "A100": MathThroughput( - int8=624, - float16=312, - bfloat16=312, - float32=156, - ), -} - -_bandwidths = { - "PCIe-3": 16, - "PCIe-4": 32, - "PCIe-5": 64, -} - -cluster_infos = { - # from 
https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf - "A100-SXM-80GB": - ClusterInfo( - intra_node_bw_per_device=300, - memory_bw=2039, - memory_budget_per_device=80, - math_throughput=_math_throughputs["A100"], - ), - "A100-SXM-40GB": - ClusterInfo( - intra_node_bw_per_device=300, - memory_bw=1555, - memory_budget_per_device=40, - math_throughput=_math_throughputs["A100"], - ), - "A100-PCIe-80GB": - ClusterInfo( - intra_node_bw_per_device=_bandwidths["PCIe-4"], - memory_bw=1935, - memory_budget_per_device=80, - math_throughput=_math_throughputs["A100"], - ), - "A100-PCIe-40GB": - ClusterInfo( - intra_node_bw_per_device=_bandwidths["PCIe-4"], - memory_bw=1555, - memory_budget_per_device=40, - math_throughput=_math_throughputs["A100"], - ), - # from https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet - "H100-SXM": - ClusterInfo( - inter_node_bw_per_device=50, - intra_node_bw_per_device=450, - intra_node_sharp=True, - memory_bw=3350, - memory_budget_per_device=80, - math_throughput=MathThroughput( - int8=1979, - fp8=1979, - float16=989, - bfloat16=989, - float32=495, - ), - ), - "H100-PCIe": - ClusterInfo( - inter_node_bw_per_device=50, - intra_node_bw_per_device=_bandwidths["PCIe-5"], - memory_bw=2000, - memory_budget_per_device=80, - math_throughput=MathThroughput( - int8=1513, - fp8=1513, - float16=756, - bfloat16=756, - float32=378, - ), - ), - "H20": - ClusterInfo( - inter_node_bw_per_device=50, - intra_node_bw_per_device=450, - memory_bw=4000, - memory_budget_per_device=96, - math_throughput=MathThroughput( - int8=293, - fp8=293, - float16=147, - bfloat16=147, - float32=74, - ), - ), - # from https://nvdam.widen.net/s/nb5zzzsjdf/hpc-datasheet-sc23-h200-datasheet-3002446 - "H200-SXM": - ClusterInfo( - inter_node_bw_per_device=50, - intra_node_bw_per_device=450, - memory_bw=4800, - memory_budget_per_device=141, - math_throughput=MathThroughput( - int8=3958, - fp8=3958, - float16=1979, - bfloat16=1979, - float32=67, - ), - ), - "H200-NVL": - ClusterInfo( - inter_node_bw_per_device=50, - intra_node_bw_per_device=450, - memory_bw=4800, - memory_budget_per_device=141, - math_throughput=MathThroughput( - int8=3341, - fp8=3341, - float16=1671, - bfloat16=1671, - float32=60, - ), - ), - # from https://images.nvidia.com/content/Solutions/data-center/a40/nvidia-a40-datasheet.pdf - "A40": - ClusterInfo( - intra_node_bw_per_device=_bandwidths["PCIe-4"], - memory_bw=696, - memory_budget_per_device=48, - math_throughput=MathThroughput( - int4=600, - int8=300, - float16=150, - bfloat16=150, - float32=75, - ), - ), - # from https://www.nvidia.com/content/dam/en-zz/Solutions/data-center/products/a30-gpu/pdf/a30-datasheet.pdf - "A30": - ClusterInfo( - intra_node_bw_per_device=_bandwidths["PCIe-4"], - memory_bw=933, - memory_budget_per_device=24, - math_throughput=MathThroughput( - int4=661, - int8=330, - float16=165, - bfloat16=165, - float32=82, - ), - ), - # from https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a10/pdf/datasheet-new/nvidia-a10-datasheet.pdf - "A10": - ClusterInfo( - intra_node_bw_per_device=_bandwidths["PCIe-4"], - memory_bw=600, - memory_budget_per_device=24, - math_throughput=MathThroughput( - int4=500, - int8=250, - float16=125, - bfloat16=125, - float32=62.5, - ), - ), - "A10G": - ClusterInfo( - intra_node_bw_per_device=_bandwidths["PCIe-4"], - memory_bw=600, - memory_budget_per_device=24, - math_throughput=MathThroughput( - int4=280, - int8=140, - float16=70, - 
bfloat16=70, - float32=35, - ), - ), - # from https://resources.nvidia.com/en-us-l40s/l40s-datasheet-28413 - "L40S": - ClusterInfo( - intra_node_bw_per_device=_bandwidths["PCIe-4"], - memory_bw=864, - memory_budget_per_device=48, - math_throughput=MathThroughput( - int4=733, - int8=733, - fp8=733, - float16=362, - bfloat16=362, - float32=183, - ), - ), - # from https://images.nvidia.cn/content/Solutions/data-center/vgpu-L40-datasheet.pdf - "L40": - ClusterInfo( - intra_node_bw_per_device=_bandwidths["PCIe-4"], - memory_bw=864, - memory_budget_per_device=48, - math_throughput=MathThroughput( - int4=724, - int8=362, - fp8=362, - float16=181, - bfloat16=181, - float32=90, - ), - ), - "L20": - ClusterInfo( - intra_node_bw_per_device=_bandwidths["PCIe-4"], - memory_bw=864, - memory_budget_per_device=48, - math_throughput=MathThroughput( - int8=238, - fp8=238, - float16=119, - bfloat16=119, - float32=60, - ), - ), - # from https://nvdam.widen.net/s/rvq98gbwsw/l4-datasheet-2595652 - "L4": - ClusterInfo( - intra_node_bw_per_device=_bandwidths["PCIe-4"], - memory_bw=300, - memory_budget_per_device=24, - math_throughput=MathThroughput( - int8=242, - fp8=242, - float16=120, - bfloat16=120, - float32=60, - ), - ), - "L2": - ClusterInfo( - intra_node_bw_per_device=_bandwidths["PCIe-4"], - memory_bw=300, - memory_budget_per_device=24, - math_throughput=MathThroughput( - int8=193, - fp8=193, - float16=97, - bfloat16=97, - float32=48, - ), - ), -} - - -def infer_cluster_key() -> str: - - def match(product, name): - # Use A100 as example, the regex pattern matches for: - # - NVIDIA A100 80GB - # - NVIDIA A100-PCIE - # - NVIDIA A100 - # And does not match A1000 etc. - return re.match(f".*{product}([ -]|$).*", name) is not None - - def is_sxm(): - return "SXM" in device_name - - def is_80gb(): - return "80GB" in device_name - - def is_32gb(): - return "32GB" in device_name - - device_name = torch.cuda.get_device_name(torch.cuda.current_device()) - - if match("A100", device_name): - if is_sxm(): - if is_80gb(): - return "A100-SXM-80GB" - else: - return "A100-SXM-40GB" - else: - if is_80gb(): - return "A100-PCIe-80GB" - else: - return "A100-PCIe-40GB" - elif match("A10G", device_name): - return "A10G" - elif match("A10", device_name): - return "A10" - elif match("A30", device_name): - return "A30" - elif match("A40", device_name): - return "A40" - elif match("H100", device_name): - if is_sxm(): - return "H100-SXM" - else: - return "H100-PCIe" - elif match("H200", device_name): - if is_sxm(): - return "H200-SXM" - else: - return "H200-NVL" - elif match("L40S", device_name): - return "L40S" - elif match("L40", device_name): - return "L40" - elif match("L4", device_name): - return "L4" - return None - - -def ipc_per_sm(compute_cap: Tuple[int, int]) -> MathThroughput: - ipc_table = { - (9, 0): - MathThroughput( - int8=16384, - fp8=16384, - float16=8192, - bfloat16=8192, - float32=4096, - ), - (8, 0): - MathThroughput( - int4=8192, - int8=4096, - float16=2048, - bfloat16=2048, - float32=1024, - ), - (8, 6): - MathThroughput( - int4=4096, - int8=2048, - float16=1024, - bfloat16=1024, - float32=512, - ), - (8, 9): - MathThroughput( - int4=2048, - int8=1024, - fp8=1024, - float16=512, - bfloat16=512, - float32=256, - ), - (7, 0): - MathThroughput( - float16=1024, - float32=128, - ), - (7, 5): - MathThroughput( - int4=4096, - int8=2048, - float16=1024, - float32=128, - ), - } - return ipc_table.get(compute_cap, MathThroughput()) - - -def nvlink_version(version_enum: int) -> int: - nvl_version_table = { - 1: 1, - 2: 2, 
- 3: 2, # 2.2 - 4: 3, - 5: 3, # 3.1 - 6: 4, - 7: 5, - } - return nvl_version_table[version_enum] - - -def nvlink_bandwidth(nvlink_version: int) -> int: - nvl_bw_table = { - 1: 80, - 2: 150, - 3: 300, - 4: 450, - 5: 900, - } - return nvl_bw_table[nvlink_version] - - -def infer_cluster_info() -> ClusterInfo: - device = torch.cuda.current_device() - index = device.index if isinstance(device, torch.device) else device - with PyNVMLContext(): - handle = pynvml.nvmlDeviceGetHandleByIndex(index) - compute_cap = pynvml.nvmlDeviceGetCudaComputeCapability(handle) - logger.info(f"Compute capability: {compute_cap}") - err, properties = cudart.cudaGetDeviceProperties(index) - sm_count = properties.multiProcessorCount - logger.info(f"SM count: {sm_count}") - sm_clock = pynvml.nvmlDeviceGetMaxClockInfo( - handle, - pynvml.NVML_CLOCK_SM, - ) - logger.info(f"SM clock: {sm_clock} MHz") - math_throughput = MathThroughput.to_tflops( - ipc_per_sm(compute_cap), - sm_count, - sm_clock, - ) - for name in math_throughput.__dataclass_fields__: - tflops = getattr(math_throughput, name) - logger.info(f"{name} TFLOPS: {tflops}") - - mem_info = _device_get_memory_info_fn(handle) - memory_budget = mem_info.total // (1024**3) - logger.info(f"Total Memory: {memory_budget} GiB") - - mem_clock = pynvml.nvmlDeviceGetMaxClockInfo( - handle, - pynvml.NVML_CLOCK_MEM, - ) - logger.info(f"Memory clock: {mem_clock} MHz") - mem_bus_width = pynvml.nvmlDeviceGetMemoryBusWidth(handle) - logger.info(f"Memory bus width: {mem_bus_width}") - memory_bw = mem_bus_width * mem_clock * 2 // int(8e3) - logger.info(f"Memory bandwidth: {memory_bw} GB/s") - - try: - is_nvl_active = bool(pynvml.nvmlDeviceGetNvLinkState(handle, 0)) - logger.info(f"NVLink is active: {is_nvl_active}") - except pynvml.NVMLError: - is_nvl_active = False - - intra_node_sharp = False - if is_nvl_active: - nvl_version_enum = pynvml.nvmlDeviceGetNvLinkVersion(handle, 0) - nvl_version = nvlink_version(nvl_version_enum) - logger.info(f"NVLink version: {nvl_version}") - nvl_bw = nvlink_bandwidth(nvl_version) - logger.info(f"NVLink bandwidth (unidirectional): {nvl_bw} GB/s") - intra_node_bw = nvl_bw - if nvl_version >= 4: - intra_node_sharp = True - else: - pcie_speed = pynvml.nvmlDeviceGetPcieSpeed(handle) - logger.info(f"PCIe speed: {pcie_speed} Mbps") - pcie_link_width = pynvml.nvmlDeviceGetCurrPcieLinkWidth(handle) - logger.info(f"PCIe link width: {pcie_link_width}") - pcie_bw = pcie_speed * pcie_link_width // int(8e3) - logger.info(f"PCIe bandwidth: {pcie_bw} GB/s") - intra_node_bw = pcie_bw - - cluster_info = ClusterInfo( - math_throughput=math_throughput, - memory_bw=memory_bw, - memory_budget_per_device=memory_budget, - intra_node_bw_per_device=intra_node_bw, - intra_node_sharp=intra_node_sharp, - ) - return cluster_info - - -def infer_cluster_config() -> Dict[str, Union[str, ClusterInfo]]: - device_name = torch.cuda.get_device_name(torch.cuda.current_device()) - cluster_key = infer_cluster_key() - if cluster_key is not None: - return dict(cluster_key=cluster_key) - else: - try: - cluster_info = infer_cluster_info() - except pynvml.NVMLError: - fallback_cluster_key = "L40" - cluster_info = copy.copy(cluster_infos[fallback_cluster_key]) - memory_budget = torch.cuda.mem_get_info()[1] // (1024**3) - cluster_info.memory_budget_per_device = memory_budget - logger.warning( - f"Failed to infer cluster info for {device_name}, " - f"treat it as a {fallback_cluster_key} node with {memory_budget} GB memory. 
" - "This setting makes no effect if you do not use auto parallel.") - return dict( - cluster_key=device_name.replace(" ", "-"), - cluster_info=cluster_info, - ) - - -if __name__ == "__main__": - logger.set_level("info") - infer_cluster_info() diff --git a/tensorrt_llm/auto_parallel/config.py b/tensorrt_llm/auto_parallel/config.py deleted file mode 100644 index 1ef203fefe..0000000000 --- a/tensorrt_llm/auto_parallel/config.py +++ /dev/null @@ -1,61 +0,0 @@ -from dataclasses import dataclass, field -from enum import auto -from typing import Dict, List, Optional, Union - -from strenum import LowercaseStrEnum - -from tensorrt_llm._utils import BaseEnumMeta, DictConversion - -from .cluster_info import ClusterInfo, cluster_infos - - -class CostModel(LowercaseStrEnum, metaclass=BaseEnumMeta): - ALPHA_BETA = auto() - PROFILE = auto() - S_CURVE = auto() - # Zero cost model is for test purpose. - # Use zero cost model for communication will make solver prefer sharding - # Use zero cost model for computation will make solver prefer replication - ZERO = auto() - - -@dataclass -class AutoParallelConfig(DictConversion): - # cluster configuration - world_size: int = 1 - gpus_per_node: int = 8 - cluster_key: str = None - cluster_info: Optional[ClusterInfo] = None - - # cost model configuration - sharding_cost_model: str = CostModel.ALPHA_BETA - comm_cost_model: str = CostModel.ALPHA_BETA - - # strategy configuration - enable_pipeline_parallelism: bool = False - enable_shard_unbalanced_shape: bool = False - enable_shard_dynamic_shape: bool = False - enable_reduce_scatter: bool = True - - # parallelization configuration - builder_flags: Optional[int] = None - debug_mode: bool = False - infer_shape: bool = True - validation_mode: bool = False - same_buffer_io: Dict[str, str] = field(default_factory=dict) - same_spec_io: Dict[str, str] = field(default_factory=dict) - sharded_io_allowlist: List[str] = field(default_factory=list) - fill_weights: bool = False - - # debug configuration - parallel_config_cache: Optional[str] = None - profile_cache: Optional[str] = None - dump_path: Optional[str] = None - debug_outputs: Union[List[str], str] = field(default_factory=list) - - def get_cluster_info(self) -> ClusterInfo: - return self.cluster_info or cluster_infos[self.cluster_key] - - @property - def enabled(self) -> bool: - return self.world_size > 1 diff --git a/tensorrt_llm/auto_parallel/device_mesh.py b/tensorrt_llm/auto_parallel/device_mesh.py deleted file mode 100644 index d82f70516f..0000000000 --- a/tensorrt_llm/auto_parallel/device_mesh.py +++ /dev/null @@ -1,612 +0,0 @@ -import os -import re -from abc import ABC, abstractmethod -from typing import List - -import h5py -import numpy as np -from filelock import FileLock - -from .config import AutoParallelConfig, CostModel -from .tensor_parallel.shape_consistency import ShapeConsistencyManager - - -class ProfileDB(ABC): - """A database that stores profiling results for multiple device mesh - shapes.""" - - @abstractmethod - def query(self, cluster_key, data_key): - ... - - @abstractmethod - def update(self, cluster_key, data_key, mesh_result): - ... 
- - def close(self): - pass - - -class MemDB(ProfileDB): - - def __init__(self): - self.data = {} - - def query(self, cluster_key, data_key): - key = (cluster_key, data_key) - mesh_result = self.data.get(key, None) - if mesh_result is None: - return None - else: - return mesh_result[0] - - def update(self, cluster_key, data_key, mesh_result): - key = (cluster_key, data_key) - self.data[key] = mesh_result - - -class Hdf5DB(ProfileDB): - - def __init__(self, name): - self.name = name - lock_name = self.name + ".lock" - self.lock = FileLock(lock_name, thread_local=False) - - def query(self, cluster_key, data_key): - file_name = f"{self.name}.hdf5" - key = str((cluster_key, data_key)) - self.lock.acquire() - mesh_result = None - with h5py.File(file_name, 'a') as f: - if key in f: - self.lock.release() - mesh_result = f[key] - return mesh_result[0] - else: - return None - - def update(self, cluster_key, data_key, mesh_result): - key = str((cluster_key, data_key)) - file_name = f"{self.name}.hdf5" - with h5py.File(file_name, 'a') as f: - f[key] = mesh_result - - def close(self): - self.lock.release(force=True) - - -class LogicalDeviceMesh(object): - - def __init__(self, - phy_mesh_shape, - mesh_shape, - phy_ids, - config: AutoParallelConfig, - alpha, - beta, - sharp, - prof_database=None, - shape_consistency_manager=None, - host_ips=None): - self.phy_mesh_shape = phy_mesh_shape - self.mesh_shape = mesh_shape - self.phy_ids = phy_ids - self.host_ips = host_ips - self.cluster_key = config.cluster_key + '_mesh_shape{}'.format('_'.join( - [str(i) for i in mesh_shape])) - self.prof_min_max_size = [1, 2**34] - self.prof_comm_dtypes = [ - "int8", "uint8", "int32", "uint32", "int64", "uint64", "float16", - "float32", "float64", "bfloat16" - ] - self.devices_group = { - (0, ): [self.phy_ids.transpose(), self.mesh_shape[1] - 1], - (1, ): [self.phy_ids, self.mesh_shape[1]], - (0, 1): [self.phy_ids.reshape([1, self.phy_ids.size]), 0] - } - self.prof_database = prof_database - self.shape_consistency_manager = shape_consistency_manager - self.config = config - self.cluster_info = config.get_cluster_info() - self.hw_alpha = alpha - self.hw_beta = beta - self.hw_sharp = sharp - self.algo_alpha_beta = self._estimate_algo_alpha_beta() - self.comm_op_to_nccl_test_func_name = { - 'all_reduce': 'all_reduce_perf_mpi', - 'all_gather': 'all_gather_perf_mpi', - 'all_to_all': 'alltoall_perf_mpi', - 'reduce_scatter': 'reduce_scatter_perf_mpi', - 'split': 'split', - } - - @property - def size(self) -> int: - return self.phy_ids.size - - def _estimate_algo_alpha_beta(self): - ret = {} - ar_alpha, ar_beta = {}, {} - ag_alpha, ag_beta = {}, {} - rs_alpha, rs_beta = {}, {} - a2a_alpha, a2a_beta = {}, {} - phy_num_hosts, phy_num_devices_per_host = self.phy_mesh_shape - if phy_num_hosts == 1 or phy_num_devices_per_host == 1: - for dims in [(0, ), (1, ), (0, 1), (1, 0)]: - num_devices = 1 - for dim in dims: - num_devices = self.mesh_shape[dim] * num_devices - if num_devices != 1: - ar_alpha[dims] = self.hw_alpha[0] if self.hw_sharp[ - 0] else self.hw_alpha[0] * num_devices / 2 / ( - num_devices - 1) - ar_beta[dims] = self.hw_beta[0] - ag_alpha[dims] = self.hw_alpha[0] * num_devices / ( - num_devices - 1) - ag_beta[dims] = self.hw_beta[0] - rs_alpha[dims] = self.hw_alpha[0] * num_devices / ( - num_devices - 1) - rs_beta[dims] = self.hw_beta[0] - a2a_alpha[dims] = self.hw_alpha[0] * num_devices / ( - num_devices - 1) - a2a_beta[dims] = self.hw_beta[0] - # phy and logical have the same mesh shape if num_hosts > 1 and 
num_devices_per_host > 1 - else: - for dims in [(0, ), (1, ), (0, 1), (1, 0)]: - num_devices = 1 - for dim in dims: - num_devices = self.mesh_shape[dim] * num_devices - if num_devices != 1: - if len(dims) == 1: - dim = dims[0] - ar_alpha[dims] = self.hw_alpha[dim] if self.hw_sharp[ - dim] else self.hw_alpha[dim] * num_devices / 2 / ( - num_devices - 1) - ar_beta[dims] = self.hw_beta[dim] - ag_alpha[dims] = self.hw_alpha[dim] * num_devices / ( - num_devices - 1) - ag_beta[dims] = self.hw_beta[dim] - rs_alpha[dims] = self.hw_alpha[dim] * num_devices / ( - num_devices - 1) - rs_beta[dims] = self.hw_beta[dim] - a2a_alpha[dims] = self.hw_alpha[dim] * num_devices / ( - num_devices - 1) - a2a_beta[dims] = self.hw_beta[dim] - elif len(dims) == 2: # two level communication - num_hosts, num_devices_per_host = phy_num_hosts, phy_num_devices_per_host - inter_node_col_alpha = self.hw_alpha[ - 0] * num_devices_per_host - inter_node_ar_alpha = inter_node_col_alpha if self.hw_sharp[ - 0] else inter_node_col_alpha * num_hosts / 2 / ( - num_hosts - 1) - intra_node_ar_alpha = self.hw_alpha[1] - intra_node_ar_alpha = intra_node_ar_alpha if self.hw_sharp[ - 1] else intra_node_ar_alpha * num_devices_per_host / 2 / ( - num_devices_per_host - 1) - ar_alpha[dims] = min(inter_node_ar_alpha, - intra_node_ar_alpha) - ar_beta[dims] = max(self.hw_beta) - ag_alpha[dims] = min( - inter_node_col_alpha * num_hosts / (num_hosts - 1), - self.hw_alpha[1] * num_devices_per_host / - (num_devices_per_host - 1)) - ag_beta[dims] = max(self.hw_beta) - rs_alpha[dims] = ag_alpha[dims] - rs_beta[dims] = ag_beta[dims] - a2a_alpha[dims] = min( - num_hosts * self.hw_alpha[0] / (num_hosts - 1), - self.hw_alpha[1] * num_hosts) - a2a_beta[dims] = max(self.hw_beta) - else: - pass - ret['all_to_all'] = [a2a_alpha, a2a_beta] - ret['all_reduce'] = [ar_alpha, ar_beta] - ret['all_gather'] = [ag_alpha, ag_beta] - ret['reduce_scatter'] = [rs_alpha, rs_beta] - ret['p2p_cross_device'] = [ - self.cluster_info.intra_node_bw_per_device, - self.cluster_info.intra_node_latency - ] - ret['p2p_cross_host'] = [ - self.cluster_info.inter_node_bw_per_device, - self.cluster_info.inter_node_latency - ] - return ret - - #[ToDo] stub functions here - def _profile_split(self, min_max_comm_size): - comm_size, elapsed_time = [], [] - size = min_max_comm_size[0] - while size <= min_max_comm_size[1]: - time = size * 2 / self.cluster_info.memory_bw - comm_size.append(size) - elapsed_time.append(time) - size = size * 2 - return np.array([comm_size, elapsed_time]) - - def _prase_nccl_test_results(self, f_nccl_test_out_log): - '''[ToDo] There is some dtye that may not been supported by nccl test, using default dtype (float)''' - start_parse = False - comm_size, elapsed_time = [], [] - try: - with open(f_nccl_test_out_log, 'r') as lines: - for line in lines: - if start_parse: - prof_data = re.split(r"[ ]+", line.strip()) - if len(prof_data) != 13: - continue - comm_size.append(float(prof_data[0])) - elapsed_time.append(float(prof_data[5])) - if 'GB/s' in line and 'us' in line: - start_parse = True - except Exception: - print(f'failed to parse {f_nccl_test_out_log}') - return comm_size, elapsed_time - - def _profile_with_nccl_test(self, min_max_comm_size, dtype, device_group, - func_name, step, workload_key): - - if func_name == 'split': - if 2 == step: - return self._profile_split(min_max_comm_size) - else: - return None - workspace_dir = self.config['profiling_workspace'] + f'/{workload_key}' - os.makedirs(workspace_dir, exist_ok=True) - outfile, errfile = workspace_dir + 
'/profile.out', workspace_dir + '/profile.err' - if 1 == step: - num_nodes = len(self.host_ips) - num_gpus = self.mesh_shape[0] * self.mesh_shape[1] - ntasks_per_node = num_gpus // num_nodes - nccl_test_command = '"export NCCL_TESTS_SPLIT_MASK={} && export NCCL_COLLNET_ENABLE=1 && {} -b {} -e {} -g 1 -d {} -f {}"'.format( - device_group[1], func_name, min_max_comm_size[0], - min_max_comm_size[1], dtype, 2) - sbatch_command = '#!/bin/bash\n' - sbatch_command += '#SBATCH -p {}\n'.format(self.config['partition']) - sbatch_command += '#SBATCH -A {}\n'.format(self.config['account']) - sbatch_command += '#SBATCH -J {}\n'.format(self.config['jobname']) - sbatch_command += '#SBATCH -N {}\n'.format(num_nodes) - sbatch_command += '#SBATCH -t {}\n'.format(self.config['time']) - sbatch_command += '#SBATCH --ntasks-per-node={}\n'.format( - ntasks_per_node) - sbatch_command += '#SBATCH --exclusive\n' - sbatch_command += '#SBATCH --mem=0\n' - sbatch_command += '#SBATCH --network=sharp\n' - sbatch_command += '#SBATCH --mail-type=FAIL\n' - srun_command = 'srun --nodes={} --mpi=pmix --ntasks-per-node={} --network=sharp -o {} -e {} --container-image={} bash -c '.format( - num_nodes, ntasks_per_node, outfile, errfile, - self.config['container']) - command = sbatch_command + srun_command + nccl_test_command - with open(workspace_dir + '/workload.sub', 'w') as f: - f.write(command) - with open('./preprofiling_step1.sh', 'a') as f: - f.write(f'sbatch {workspace_dir}/workload.sub\n') - return None - - else: - comm_size, elapsed_time = self._prase_nccl_test_results(outfile) - if len(comm_size) < 2: - assert 0, 'the profiling for {} was failed at step1, please try again'.format( - workload_key) - else: - print(workload_key, comm_size, elapsed_time) - return np.array([comm_size, elapsed_time]) - - def _profile_single_comm_perf(self, device_group, comm_op, step, data_key): - results = {} - func_name = self.comm_op_to_nccl_test_func_name[comm_op] - for dtype in self.prof_comm_dtypes: - size_time = self._profile_with_nccl_test( - self.prof_min_max_size, dtype, device_group, func_name, step, - data_key + f'_dtype{dtype}') - results[dtype] = size_time - return results - - def profile_all_comms_perf(self, step): - if self.mesh_shape == (1, 1): - return None - mesh_results = self.prof_database.query(self.cluster_key, - self.mesh_shape) - if mesh_results: - return mesh_results - - mesh_results = {} - data_key = self.cluster_key + f'_mesh_shape{self.mesh_shape[0]}x{self.mesh_shape[1]}' - for comm_op in [ - 'all_reduce', 'all_to_all', 'all_gather', 'reduce_scatter', - 'split' - ]: - comm_perf = {} - for dim, device_group in self.devices_group.items(): - # don't need to profile for mesh dim == 1 - if len(dim) == 1 and self.mesh_shape[dim[0]] == 1: - continue - - comm_perf[dim] = self._profile_single_comm_perf( - device_group, comm_op, step, data_key + - '_comm_op{}_dim{}'.format(comm_op, ''.join(map(str, dim)))) - mesh_results[comm_op] = comm_perf - if 2 == step: - self.prof_database.update(self.cluster_key, self.mesh_shape, - mesh_results) - - return mesh_results - - def _model_comm_cost_from_s_curve(self, size_time_array, realsize): - assert size_time_array[0][0] <= realsize <= size_time_array[0][-1],\ - 'the comm_size: {} is not in the profile range: [{}{}]'\ - .format(realsize, size_time_array[0][0], size_time_array[0][-1]) - return np.interp(realsize, size_time_array[0], size_time_array[1]) - - def _model_comm_cost_from_alpha_beta(self, comm_op, dim_key, size_in_bytes): - elapsed_time = 0.0 - if 'split' == comm_op: - 
elapsed_time = size_in_bytes * 2 / ( - self.cluster_info.memory_bw * - self.cluster_info.memory_efficiency) * 1e-3 - else: - dict_alpha, dict_beta = self.algo_alpha_beta[comm_op] - alpha, beta = dict_alpha[dim_key], dict_beta[dim_key] - elapsed_time = (size_in_bytes / - (alpha * self.cluster_info.communication_efficiency) - * 1e-3) + beta - return elapsed_time - - def _input_size_to_comm_size(self, comm_op, dims, input_size): - ret = input_size - if 'all_gather' == comm_op: - for dim in dims: - ret = ret * self.mesh_shape[dim] - return ret - - def estimate_comm_cost(self, comm_op, dim, input_size, dtype): - - size = self._input_size_to_comm_size(comm_op, dim, input_size) - if self.config.comm_cost_model == CostModel.S_CURVE: - mesh_perf = self.prof_database.query(self.cluster_key, - self.mesh_shape) - assert mesh_perf is not None, 'the mesh is not profiled, mesh_shape = {}'.format( - self.mesh_shape) - comm_op_perf = mesh_perf.get(comm_op, None) - assert comm_op_perf is not None, '{} is not profiled'.format( - comm_op) - elapsed_time = self._model_comm_cost_from_s_curve( - comm_op_perf[tuple(dim)][dtype], size) - return elapsed_time - elif self.config.comm_cost_model == CostModel.ALPHA_BETA: - elapsed_time = self._model_comm_cost_from_alpha_beta( - comm_op, tuple(dim), size) - elif self.config.comm_cost_model == CostModel.PROFILE: - assert False, 'Unsupported profile based communication cost model now' - elif self.config.comm_cost_model == CostModel.ZERO: - elapsed_time = 0.0 - - return elapsed_time # us - - -class PhysicalDeviceMesh(object): - - def __init__(self, - phy_devices_id, - config: AutoParallelConfig, - prof_database=None, - shape_consistency_manager=None, - host_ips=None): - self.phy_devices_id = np.array(phy_devices_id) - self.num_hosts, self.num_devices_per_host = self.phy_devices_id.shape - self.host_ips = host_ips - if host_ips is None: - self.host_ips = [''] * self.num_hosts - self.config = config - self.cluster_info = config.get_cluster_info() - self.prof_database: ProfileDB = prof_database - self.shape_consistency_manager = shape_consistency_manager - if self.config.comm_cost_model not in CostModel: - raise ValueError( - f'unsupported communication cost model: {self.config.comm_cost_model}' - ) - if self.config.sharding_cost_model not in CostModel: - raise ValueError( - f'unsupported sharding cost model: {self.config.sharding_cost_model}' - ) - if self.config.comm_cost_model == CostModel.S_CURVE or self.config.sharding_cost_model == CostModel.PROFILE: - if self.prof_database is None: - profile_cache = config.profile_cache - if profile_cache is None: - self.prof_database = MemDB() - else: - self.prof_database = Hdf5DB(profile_cache) - elif self.config.comm_cost_model == CostModel.ALPHA_BETA: - assert self.cluster_info.intra_node_bw_per_device > 0, 'intra_node_bw_per_device is needed for alpha_beta method' - assert self.cluster_info.inter_node_bw_per_device > 0, 'inter_node_bw_per_device is needed for alpha_beta method' - if self.config.sharding_cost_model == CostModel.ALPHA_BETA: - assert self.cluster_info.memory_bw > 0, 'memory_bw is needed for alpha_beta method' - - if not shape_consistency_manager: - self.shape_consistency_manager = ShapeConsistencyManager() - - @property - def size(self) -> int: - return self.phy_devices_id.size - - def close(self): - if self.prof_database is not None: - self.prof_database.close() - - def split_pipeline_meshes( - self, num_stages, - num_devices_per_stage) -> List["PhysicalDeviceMesh"]: - sub_meshes = [] - if num_devices_per_stage <= 
self.num_devices_per_host: - assert self.num_devices_per_host % num_devices_per_stage == 0, \ - "num_devices_per_host ({}) % num_devices_per_stage ({}) != 0"\ - .format(self.num_devices_per_host, num_devices_per_stage) - num_clusters_per_host = self.num_devices_per_host // num_devices_per_stage - num_clusters = self.num_hosts * num_clusters_per_host - assert num_stages % num_clusters == 0, \ - "num_stages({}) % num_clusters({}) !=0".format(num_stages, num_clusters) - for mesh_id in range(num_stages): - cluster_id = mesh_id % num_clusters - cluster_col = cluster_id % num_clusters_per_host - cluster_row = cluster_id // num_clusters_per_host - sub_devices_id = [ - self.phy_devices_id[cluster_row][cluster_col * - num_devices_per_stage:( - (cluster_col + 1) * - num_devices_per_stage)] - ] - sub_meshes.append( - PhysicalDeviceMesh(sub_devices_id, self.config, - self.prof_database, - self.shape_consistency_manager, - [self.host_ips[cluster_row]])) - else: - assert num_devices_per_stage % self.num_devices_per_host == 0, \ - "num_devices_per_stage ({}) % num_devices_per_host ({}) != 0"\ - .format(num_devices_per_stage, self.num_devices_per_host) - num_host_per_cluster = num_devices_per_stage // self.num_devices_per_host - assert self.num_hosts % num_host_per_cluster == 0, \ - "num_hosts ({}) % num_host_per_cluster({}) != 0".format(self.num_hosts, num_host_per_cluster) - num_clusters = self.num_hosts // num_host_per_cluster - for mesh_id in range(num_stages): - cluster_id = mesh_id % num_clusters - cluster_row = cluster_id * num_host_per_cluster - sub_devices_id = self.phy_devices_id[cluster_row:( - cluster_row + num_host_per_cluster)] - host_ips = self.host_ips[cluster_row:(cluster_row + - num_host_per_cluster)] - sub_meshes.append( - PhysicalDeviceMesh(sub_devices_id, self.config, - self.prof_database, - self.shape_consistency_manager, - host_ips)) - return sub_meshes - - def _profile_logical_meshes(self, logical_meshes, step): - for lmesh in logical_meshes: - lmesh.profile_all_comms_perf(step) - - def as_logical_mesh(self) -> LogicalDeviceMesh: - alpha = [ - self.cluster_info.inter_node_bw_per_device, - self.cluster_info.intra_node_bw_per_device - ] - beta = [ - self.cluster_info.inter_node_latency, - self.cluster_info.intra_node_latency - ] - sharp = [ - self.cluster_info.inter_node_sharp, - self.cluster_info.intra_node_sharp - ] - return LogicalDeviceMesh( - self.phy_devices_id.shape, - self.phy_devices_id.shape, - self.phy_devices_id, - self.config, - alpha, - beta, - sharp, - self.prof_database, - self.shape_consistency_manager, - self.host_ips, - ) - - def get_logical_meshes(self): - logical_meshes = [] - # (1, 2) -> (1, 2) - # (1, 4) -> (2, 2) - # (1, 8) -> (2, 4) - # (1, 16) -> (2, 8), (4, 4) - # (1, 32) -> (2, 16), (4, 8) - # (1, 48) -> (2, 24), (3, 16), (4, 12), (6, 8) - # (1, 64) -> (2, 32), (4, 16), (8, 8) - # we will traverse logical shape's axis in sharding spec, thus (2, 8) contains (8, 2) - # we will merge logical shapes' axis, thus (2, 8) contains (1, 16) and (16, 1) - if self.num_hosts == 1: - alpha = [self.cluster_info.intra_node_bw_per_device] - beta = [self.cluster_info.intra_node_latency] - sharp = [self.cluster_info.intra_node_sharp] - for i in range(2, self.num_devices_per_host): - if self.num_devices_per_host % i == 0 and i * i <= self.num_devices_per_host: - lmesh_shape = (i, self.num_devices_per_host // i) - lmesh_phy_ids = self.phy_devices_id.reshape(lmesh_shape) - logical_meshes.append( - LogicalDeviceMesh(self.phy_devices_id.shape, - lmesh_shape, lmesh_phy_ids, - 
self.config, alpha, beta, sharp, - self.prof_database, - self.shape_consistency_manager, - self.host_ips)) - # (8, 1) -> (2, 4) - # (16, 1) -> (2, 8), (4, 4) - elif self.num_devices_per_host == 1: - alpha = [self.cluster_info.inter_node_bw_per_device] - beta = [self.cluster_info.inter_node_latency] - sharp = [self.cluster_info.inter_node_sharp] - for i in range(2, self.num_hosts): - if self.num_hosts % i == 0 and i * i <= self.num_hosts: - lmesh_shape = (i, self.num_hosts // i) - lmesh_phy_ids = self.phy_devices_id.reshape(lmesh_shape) - logical_meshes.append( - LogicalDeviceMesh(self.phy_devices_id.shape, - lmesh_phy_ids, self.config, alpha, - beta, sharp, self.prof_database, - self.shape_consistency_manager, - self.host_ips)) - # (2, 1) -> (2, 1) - # (2, 8) -> (2, 8) - # (1, 2) -> (1, 2) - # (1, 3) -> (1, 3) - # (1, 5) -> (1, 5) - if 0 == len(logical_meshes): - logical_meshes.append(self.as_logical_mesh()) - return logical_meshes - - ''' - we assume we can evenly split the pipeline and deviceMesh - ''' - - def _list_all_sub_meshes(self): - sub_meshes = [] - for num_devices_per_stage in range(1, self.num_devices_per_host + 1): - if self.num_devices_per_host % num_devices_per_stage == 0: - num_stages = self.num_hosts * self.num_devices_per_host // num_devices_per_stage - sub_meshes.append( - self.split_pipeline_meshes(num_stages, - num_devices_per_stage)[0]) - for num_hosts_per_stage in range(2, self.num_hosts + 1): - if self.num_hosts % num_hosts_per_stage == 0: - num_stages = self.num_hosts // num_hosts_per_stage - sub_meshes.append( - self.split_pipeline_meshes( - num_stages, - num_hosts_per_stage * self.num_devices_per_host)[0]) - return sub_meshes - - def list_all_pipeline_configs(self): - configs = [] - for num_devices_per_stage in range(1, self.num_devices_per_host + 1): - if self.num_devices_per_host % num_devices_per_stage == 0: - num_stages = self.num_hosts * self.num_devices_per_host // num_devices_per_stage - configs.append((num_stages, num_devices_per_stage)) - for num_hosts_per_stage in range(2, self.num_hosts + 1): - if self.num_hosts % num_hosts_per_stage == 0: - num_stages = self.num_hosts // num_hosts_per_stage - configs.append( - (num_stages, - num_hosts_per_stage * self.num_devices_per_host)) - return configs - - def profile_s_curve(self, step): - sub_phy_device_meshes = self._list_all_sub_meshes() - for phy_mesh in sub_phy_device_meshes: - lmeshes = phy_mesh.get_logical_meshes() - self._profile_logical_meshes(lmeshes, step) - if 2 == step: - self.save_profile_database() - - def profile_alpha_beta(self): - alpha = [250, 25] - beta = [100, 100] - return alpha, beta diff --git a/tensorrt_llm/auto_parallel/node_graph.py b/tensorrt_llm/auto_parallel/node_graph.py deleted file mode 100644 index 503925c9d3..0000000000 --- a/tensorrt_llm/auto_parallel/node_graph.py +++ /dev/null @@ -1,347 +0,0 @@ -from typing import List - -import pandas as pd -import tensorrt as trt - -from .pipeline_graph import PipelineGraph -from .runtime_profiling import RuntimeProfiler -from .simplifier import GraphConfig, StageType -from .solver import CostGraph, Solver -from .tensor_parallel.activation_node import Activation -from .tensor_parallel.assertion_node import Assertion -from .tensor_parallel.cast_node import Cast -from .tensor_parallel.concatenation_node import Concatenation -from .tensor_parallel.constant_node import Constant -from .tensor_parallel.elementwise_node import ElementWise -from .tensor_parallel.fill_node import Fill -from .tensor_parallel.gather_node import Gather -from 
.tensor_parallel.identity_node import Identity -from .tensor_parallel.input_node import InputNode -from .tensor_parallel.matmul_node import MatrixMultiply -from .tensor_parallel.node import Node -from .tensor_parallel.normalization_node import Normalization -from .tensor_parallel.output_node import OuputNode -from .tensor_parallel.p2p_node import P2PNode, P2PType -from .tensor_parallel.plugin_node import PluginNode -from .tensor_parallel.plugin_nodes.gemm_node import GemmPlugin -from .tensor_parallel.plugin_nodes.gpt_attention_node import GPTAttentionPlugin -from .tensor_parallel.plugin_nodes.identity_node import IdentityPlugin -from .tensor_parallel.plugin_nodes.look_up_node import LookupPlugin -from .tensor_parallel.plugin_nodes.normalization_node import (LayernormPlugin, - RMSnormPlugin) -from .tensor_parallel.reduce_node import Reduce -from .tensor_parallel.select_node import Select -from .tensor_parallel.shape_node import Shape -from .tensor_parallel.shuffle_node import Shuffle -from .tensor_parallel.slice_node import Slice -from .tensor_parallel.softmax_node import SoftMax -from .tensor_parallel.unary_node import Unary - -LAYER_TYPE_2_NODE_TYPE = { - trt.LayerType.ACTIVATION: Activation, - trt.LayerType.ASSERTION: Assertion, - trt.LayerType.CAST: Cast, - trt.LayerType.CONCATENATION: Concatenation, - trt.LayerType.CONSTANT: Constant, - trt.LayerType.ELEMENTWISE: ElementWise, - trt.LayerType.FILL: Fill, - trt.LayerType.GATHER: Gather, - trt.LayerType.IDENTITY: Identity, - trt.LayerType.MATRIX_MULTIPLY: MatrixMultiply, - trt.LayerType.NORMALIZATION: Normalization, - trt.LayerType.PLUGIN_V2: PluginNode, - trt.LayerType.REDUCE: Reduce, - trt.LayerType.SELECT: Select, - trt.LayerType.SHAPE: Shape, - trt.LayerType.SHUFFLE: Shuffle, - trt.LayerType.SLICE: Slice, - trt.LayerType.SOFTMAX: SoftMax, - trt.LayerType.UNARY: Unary, -} -# TODO: BertAttention/All Quant plugins -PLUGIN_LAYER_TYPE_2_NODE_TYPE = { - 'GPTAttention': GPTAttentionPlugin, - 'Gemm': GemmPlugin, - 'Layernorm': LayernormPlugin, - 'Rmsnorm': RMSnormPlugin, - 'Lookup': LookupPlugin, - 'Identity': IdentityPlugin, -} - - -class NodeGraph: - - def __init__(self, graph: PipelineGraph): - self._nodes = {} - - # construct nodes - for input in graph.inputs: - self._nodes[input.name] = InputNode(input) - for layer in graph.layers: - layer.to_base_class() - if "p2p_type" in layer.attrs: - self._nodes[layer.name] = P2PNode(layer) - elif layer.type == trt.LayerType.PLUGIN_V2: - layer.to_subclass() - plugin_type = layer.as_trt().plugin.plugin_type - layer.to_base_class() - if plugin_type in PLUGIN_LAYER_TYPE_2_NODE_TYPE: - node = PLUGIN_LAYER_TYPE_2_NODE_TYPE[plugin_type](layer) - else: - node = LAYER_TYPE_2_NODE_TYPE[layer.type](layer) - self._nodes[layer.name] = node - else: - node = LAYER_TYPE_2_NODE_TYPE[layer.type](layer) - self._nodes[layer.name] = node - for output in graph.outputs: - self._nodes[output.name] = OuputNode(output) - for node in self.nodes: - node.post_init(self) - node.node_runtime_profiler = RuntimeProfiler() - - def get_node(self, name): - return self._nodes[name] - - @property - def nodes(self) -> List[Node]: - return [*self._nodes.values()] - - def assign_cost_weights(self, graph_config: GraphConfig): - layer_mapping = graph_config.graph_mapping.layer_mapping - for layer_name in layer_mapping.values(): - node = self.get_node(layer_name) - node.sharding_weight += 1 - node.resharding_weight += 1 - same_spec_layer_mapping = graph_config.graph_mapping.same_spec_layer_mapping - for same_spec_layer_name, layer_name in 
same_spec_layer_mapping.items(): - node = self.get_node(layer_name) - same_spec_node = self.get_node(same_spec_layer_name) - same_spec_node.sharding_weight = node.sharding_weight - same_spec_node.resharding_weight = node.resharding_weight - - def set_slowest_stage(self, stage_type: StageType, - graph_config: GraphConfig): - num_micro_batches = graph_config.num_micro_batches - block_per_stage = graph_config.num_blocks // graph_config.num_stages - block_pipeline_weight = block_per_stage * (num_micro_batches - 1) - for node in self.nodes: - node.pipeline_weight = 0 - node.cost_level = -1 - if node.stage_type == StageType.START: - if stage_type == StageType.START: - node.pipeline_weight = num_micro_batches - 1 - node.cost_level = 1 - else: - node.cost_level = 0 - if stage_type == StageType.START and node.in_start_block: - node.pipeline_weight = block_pipeline_weight - if node.stage_type == StageType.END: - if stage_type == StageType.END: - node.pipeline_weight = num_micro_batches - 1 - node.cost_level = 1 - else: - node.cost_level = 0 - if stage_type == StageType.END and node.in_end_block: - node.pipeline_weight = block_pipeline_weight - if isinstance(node, P2PNode): - if (graph_config.has_cross_host - and node.p2p_type == P2PType.CROSS_HOST) or ( - not graph_config.has_cross_host - and node.p2p_type == P2PType.CROSS_DEVICE): - if stage_type == StageType.BLOCK: - node.pipeline_weight += num_micro_batches - 1 - node.cost_level = 1 - else: - node.cost_level = 0 - elif (graph_config.has_cross_device - and node.p2p_type == P2PType.CROSS_DEVICE) or ( - not graph_config.has_cross_device - and node.p2p_type == P2PType.CROSS_HOST): - node.pipeline_weight += num_micro_batches - 1 - if stage_type == StageType.BLOCK and node.in_slowest_block: - node.pipeline_weight = block_pipeline_weight - - def get_cost_graph(self, lmesh): - leaf_strategies = [] - for node in self.nodes: - if node.is_replicated: - node.set_strategy(None, lmesh) - else: - node.collect_strategies(lmesh) - for node in self.nodes: - strategies_vector = node.update_resharding_cost() - if len(strategies_vector) != 0: - leaf_strategies.append(strategies_vector) - cost_graph = CostGraph(leaf_strategies) - return cost_graph - - def find_solution(self, cost_graph, memory_budget): - solver = Solver(cost_graph, memory_budget=memory_budget) - solution = solver.find_solution()[1] - - graph_strategy = solution.node_best_strategy - for node_name, strategy in graph_strategy.items(): - node = self._nodes[node_name] - for idx, pre_node in enumerate(node.predecessor_nodes): - if pre_node is None: - continue - if pre_node.node_name not in strategy.best_resharding_cost: - continue - strategy.best_resharding_cost[ - idx] = strategy.best_resharding_cost[pre_node.node_name] - strategy.node_names[idx] = pre_node.node_name - for key in list(strategy.best_resharding_cost.keys()): - if isinstance(key, str): - del strategy.best_resharding_cost[key] - - return solution - - def visualize(self, name='pp_graph'): - with open(name + '.dot', 'w') as f: - f.write("digraph {\n") - ''' - f.write(" // Value Nodes\n") - for name, tensor in self._tensors.items(): - f.write(" \"{}\" [fillcolor = \"green\", label = \"{}\", shape = \"box\", style = \"filled\"];\n".format(name, tensor.shape)) - ''' - f.write(" // Operation Nodes\n") - for name, node in self._nodes.items(): - fillcolor = 'white' - if 'MATRIX_MULTIPLY' in name: - fillcolor = 'green' - label = name - if len(node.outputs) > 0: - label = name + '\\n' + str(node.outputs[0].shape) - f.write( - " \"{}\" [fillcolor = 
\"{}\", label = \"{}\", shape = \"box\", style = \"filled\"];\n" - .format(name, fillcolor, label)) - f.write(" // Edges\n") - for name, node in self._nodes.items(): - for successor_node in node.successor_nodes: - if successor_node: - f.write(" \"{}\" ->\"{}\";\n".format( - name, successor_node.node_name)) - f.write(" }\n") - - def visualize_solution(self, - solution, - fname='pp_graph_solution', - ignore_shape_io=True): - with open(fname + '.dot', 'w') as f: - names, costs, block_ids = [], [], [] - f.write("digraph {\n") - f.write(" // Operation Nodes\n") - for name, node in self._nodes.items(): - if ignore_shape_io and node.layer is not None and node.layer.is_shape_io: - continue - cost = 0.0 - fillcolor = 'white' - if 'MATRIX_MULTIPLY' in name or 'PLUGIN_V2_Gemm' in name: - fillcolor = 'orange' - elif '_same_spec' in name: - fillcolor = 'gray' - elif 'p2p_block' in name: - fillcolor = 'blue' - elif 'PLUGIN' in name: - fillcolor = 'yellow' - - shape = 'box' - if 'output_node' == node.node_type or 'input_node' == node.node_type: - shape = 'ellipse' - fillcolor = 'green' - - label = name + f'_block{node.building_block_id}_weight{node.sharding_weight}' - if len(node.inputs) > 0: - for idx, input in enumerate(node.inputs): - if not input: - continue - label = label + f'\\ninput{idx}_' + str( - input.shape) + f'_{input.dtype_str_size[0]}_' - if node.node_name in solution.node_best_strategy: - best_strategy = solution.node_best_strategy[ - node.node_name] - shard_seq = str( - best_strategy.sharding_specs[f'input{idx}']. - sharding_sequence) - label = label + shard_seq - if idx not in best_strategy.best_resharding_cost: - continue - rcosts = best_strategy.best_resharding_cost[idx][0] - comm_action_sequence, resharding_cost = rcosts[ - 1], rcosts[2] - if len(comm_action_sequence) > 0: - label = label + '|' - for commspec in comm_action_sequence: - comm = [ - commspec.comm_pattern, commspec.gather_dim, - commspec.shard_dim, - commspec.logical_process_axis - ] - label = label + '->' + str(comm) - if resharding_cost > 0: - label = label + '_rcost{:.2}'.format( - resharding_cost) - cost = cost + resharding_cost - if len(node.outputs) > 0: - best_strategy = None - for idx, output in enumerate(node.outputs): - label = label + f'\\noutput{idx}_' + str( - output.shape) + f'_{output.dtype_str_size[0]}' - if node.node_name in solution.node_best_strategy: - best_strategy = solution.node_best_strategy[ - node.node_name] - shard_seq = str( - best_strategy.sharding_specs[f'output{idx}']. 
- sharding_sequence) - comm = None - if f'output{idx}' in best_strategy.communication_actions: - commspec = best_strategy.communication_actions[ - f'output{idx}'] - comm = [ - commspec.comm_pattern, commspec.gather_dim, - commspec.shard_dim, - commspec.logical_process_axis - ] - label = label + '_' + shard_seq - if comm: - label = label + f' | {comm}' - if best_strategy: - cost = cost + best_strategy.sharding_cost + best_strategy.communication_cost - label = label + '| scost{:.2}'.format( - best_strategy.sharding_cost) - if best_strategy.communication_cost > 0: - label = label + ' | ccost{:.2}'.format( - best_strategy.communication_cost) - names.append(name) - costs.append(cost) - block_ids.append([ - node.building_block_id, node.cost_level, - node.sharding_weight + node.pipeline_weight, - node.same_spec_id - ]) - f.write( - " \"{}\" [fillcolor = \"{}\", label = \"{}\", shape = \"{}\", style = \"filled\"];\n" - .format(name, fillcolor, label, shape)) - f.write(" // Edges\n") - for name, node in self._nodes.items(): - if ignore_shape_io and node.layer is not None and node.layer.is_shape_io: - continue - for successor_node in node.successor_nodes: - if successor_node: - if ignore_shape_io and successor_node.layer is not None and successor_node.layer.is_shape_io: - continue - f.write(" \"{}\" ->\"{}\";\n".format( - name, successor_node.node_name)) - f.write(" }\n") - df = pd.DataFrame.from_dict({ - 'node': - names, - 'cost': - costs, - 'block_id': [block[0] for block in block_ids], - 'cost_level': [block[1] for block in block_ids], - 'sharding_weight': [block[2] for block in block_ids], - 'same_spec_id': [block[3] for block in block_ids] - }) - df['weight_cost'] = df['sharding_weight'] * df['cost'] - df.to_csv(fname + '.csv') diff --git a/tensorrt_llm/auto_parallel/parallelization.py b/tensorrt_llm/auto_parallel/parallelization.py deleted file mode 100644 index 0e0e0d78c3..0000000000 --- a/tensorrt_llm/auto_parallel/parallelization.py +++ /dev/null @@ -1,2299 +0,0 @@ -import contextlib -import copy -import itertools -import re -from abc import ABC, abstractmethod -from dataclasses import dataclass, field -from typing import Any, ClassVar, Dict, List, Sequence, Set, Tuple, Union - -import numpy as np -import tensorrt as trt -import torch -from filelock import FileLock - -from tensorrt_llm import serialization -from tensorrt_llm._utils import (str_dtype_to_trt, trt_dtype_to_np, - trt_dtype_to_torch) -from tensorrt_llm.functional import AllReduceParams, create_allreduce_plugin -from tensorrt_llm.logger import logger -from tensorrt_llm.mapping import Mapping -from tensorrt_llm.network import (PluginInfo, delete_plugin_info, get_np_weight, - get_plugin_info, set_plugin_info) -from tensorrt_llm.plugin import TRT_LLM_PLUGIN_NAMESPACE, init_all_reduce_helper -from tensorrt_llm.plugin.plugin import CustomAllReduceHelper -from tensorrt_llm.version import __version__ - -from .config import AutoParallelConfig -from .device_mesh import LogicalDeviceMesh -from .pipeline_graph import Layer, PipelineGraph, Tensor -from .shape_info import (ShapeInfo, get_per_layer_graph, get_shape_layers, - infer_per_layer_shapes) -from .simplifier import GraphConfig, GraphMapping, Simplifier, StageType -from .tensor_parallel.comm_spec import CommSpec -from .tensor_parallel.plugin_nodes.gpt_attention_node import ( - GPTAttentionPlugin, IdxEntry, IdxEntryParser) -from .tensor_parallel.sharding_spec import ShardingSpec, get_sharding_sequence -from .tensor_parallel.sharding_strategy import ShardingStrategy -from .utils import 
(get_updated_plugin, to_base_class_layer, to_subclass_layer, - to_trt_weights) - -default_int_dtype = trt.int64 - -# These dataclasses are used in ParallelConfig serialization. If there are other classes need to be serialized, please add to this list. -BASE_AUTOPP_CLASSES = { - "tensorrt_llm.auto_parallel.parallelization": ["ParallelConfig"], - "tensorrt_llm.auto_parallel.config": ["AutoParallelConfig", "CostModel"], - "tensorrt_llm.auto_parallel.simplifier": ["GraphConfig", "StageType"] -} - - -@dataclass -class ParallelConfig: - VERSION: ClassVar[str] = __version__ - - version: str = VERSION - network_hash: str = None - auto_parallel_config: AutoParallelConfig = None - graph_config: GraphConfig = None - lmesh: LogicalDeviceMesh = None - cost: float = None - graph_strategy: Dict[str, ShardingStrategy] = None - stage_type: StageType = None - - def save(self, filename): - with open(filename, 'wb') as file: - serialization.dump(self, file) - - @staticmethod - def from_file(filename) -> "ParallelConfig": - with open(filename, "rb") as file: - return serialization.load(file, - approved_imports=BASE_AUTOPP_CLASSES) - - def print_graph_strategy(self, file=None): - for index, (node_name, - strategy) in enumerate(self.graph_strategy.items()): - print(f'\n[{index}]: node_name = {node_name}', file=file) - strategy.print_strategy(best_resharding_cost_only=True, file=file) - - -def desimplify_strategy( - graph: PipelineGraph, - graph_strategy: Dict[str, ShardingStrategy], - graph_mapping: GraphMapping, -): - for strategy in graph_strategy.values(): - for name, commspec in list(strategy.communication_actions.items()): - strategy.communication_actions[name] = [commspec] - strategy.sharding_specs[ - f"{name}_after_comm"] = strategy.sharding_specs[name] - - # insert same spec layers' communication actions after - # its producer's communication actions - same_spec_layer_mapping = graph_mapping.same_spec_layer_mapping - for same_spec_layer_name in same_spec_layer_mapping.keys(): - same_spec_strategy = graph_strategy[same_spec_layer_name] - same_spec_commspecs = same_spec_strategy.best_resharding_cost[0][0][1] - if len(same_spec_commspecs) == 0: - continue - output_name = same_spec_layer_name[:-len("_same_spec")] - output = graph.get_tensor(output_name) - layer_name = output.producer.name - output_index = output.output_index - strategy = graph_strategy[layer_name] - commspecs = strategy.communication_actions.get(f"output{output_index}", - []) - commspecs.extend(same_spec_commspecs) - strategy.communication_actions[f"output{output_index}"] = commspecs - strategy.sharding_specs[ - f"output{output_index}_after_comm"] = same_spec_strategy.sharding_specs[ - "output0"] - - layer_mapping = graph_mapping.layer_mapping - for removed_layer_name, layer_name in layer_mapping.items(): - if layer_name in graph_strategy: - strategy = copy.copy(graph_strategy[layer_name]) - layer = graph.get_layer(removed_layer_name) - if layer is not None: - strategy.node_names = strategy.node_names.copy() - for index, name in list(strategy.node_names.items()): - input = layer.get_input(index) - node_name = input.name if input.producer is None else input.producer.name - strategy.node_names[index] = node_name - graph_strategy[removed_layer_name] = strategy - - -@dataclass -class SplitInfo: - input_dim: Union[int, trt.ITensor] - partition: int - - def __deepcopy__(self, memo) -> "SplitInfo": - return SplitInfo(self.input_dim, self.partition) - - -@dataclass -class TensorInfo: - name: str = None - split_infos: Dict[int, SplitInfo] = 
field(default_factory=dict) - - def set_split_info(self, dim, split_info): - self.split_infos[dim] = split_info - - def __deepcopy__(self, memo) -> "TensorInfo": - return TensorInfo(self.name, copy.deepcopy(self.split_infos)) - - -@dataclass -class TensorContext: - info_by_device: Dict[int, TensorInfo] = field(default_factory=dict) - device_dims_for_shape: Set[int] = field(default_factory=set) - - def update_name_mapping(self, device_id, new_name): - if device_id not in self.info_by_device: - self.info_by_device[device_id] = TensorInfo() - self.info_by_device[device_id].name = new_name - - def set_split_info(self, device_id, dim, split_info): - if device_id not in self.info_by_device: - self.info_by_device[device_id] = TensorInfo() - self.info_by_device[device_id].set_split_info(dim, split_info) - - def set_split_infos(self, device_id, split_infos: Dict[int, SplitInfo]): - if device_id not in self.info_by_device: - self.info_by_device[device_id] = TensorInfo() - self.info_by_device[device_id].split_infos = split_infos - - def __deepcopy__(self, memo) -> "TensorContext": - return TensorContext(copy.deepcopy(self.info_by_device), - set(self.device_dims_for_shape)) - - -@dataclass -class LayerUpdate: - updated_attrs: Dict[str, Any] = field(default_factory=dict) - updated_inputs: Dict[int, trt.ITensor] = field(default_factory=dict) - split_info_updated: bool = False - - @staticmethod - def none() -> "LayerUpdate": - return LayerUpdate() - - -@dataclass -class GraphContext: - tensor_contexts: Dict[str, TensorContext] = field(default_factory=dict) - - def get_name(self, tensor_name, device_id): - if tensor_name not in self.tensor_contexts: - return None - if device_id not in self.tensor_contexts[tensor_name].info_by_device: - return None - return self.tensor_contexts[tensor_name].info_by_device[device_id].name - - def update_name_mapping(self, tensor_name, device_id, new_name): - if tensor_name not in self.tensor_contexts: - self.tensor_contexts[tensor_name] = TensorContext() - self.tensor_contexts[tensor_name].update_name_mapping( - device_id, new_name) - - def get_name_mapping(self, device_id, prefix: str) -> Dict[str, str]: - name_mapping = {} - for tensor_name in self.tensor_contexts.keys(): - new_name = self.get_name(tensor_name, device_id) - if new_name is not None: - name_mapping[f"{prefix}{tensor_name}"] = new_name - return name_mapping - - def add_device_dims_for_shape(self, tensor_name: str, - device_dims: Sequence[int]): - if tensor_name not in self.tensor_contexts: - self.tensor_contexts[tensor_name] = TensorContext() - self.tensor_contexts[tensor_name].device_dims_for_shape.update( - device_dims) - - def get_device_dims_for_shape(self, tensor_name: str): - if tensor_name not in self.tensor_contexts: - return set() - return self.tensor_contexts[tensor_name].device_dims_for_shape - - def get_split_infos(self, tensor_name, device_id): - if tensor_name not in self.tensor_contexts: - return None - if device_id not in self.tensor_contexts[tensor_name].info_by_device: - return None - return self.tensor_contexts[tensor_name].info_by_device[ - device_id].split_infos - - def set_split_info(self, tensor_name, device_id, dim, split_info): - if tensor_name not in self.tensor_contexts: - self.tensor_contexts[tensor_name] = TensorContext() - self.tensor_contexts[tensor_name].set_split_info( - device_id, dim, split_info) - - def set_split_infos(self, tensor_name, device_id, - split_infos: Dict[int, SplitInfo]): - if tensor_name not in self.tensor_contexts: - self.tensor_contexts[tensor_name] = 
TensorContext() - self.tensor_contexts[tensor_name].set_split_infos( - device_id, split_infos) - - def update_layer_context(self, wrapped_layer: Layer, - layer_update: LayerUpdate, - local_context: "GraphContext", device_id: int, - device_ids: np.ndarray, - sharding_specs: Dict[str, ShardingSpec]): - layer = wrapped_layer.as_trt() - for i in range(layer.num_outputs): - output = layer.get_output(i) - new_name = local_context.get_name(output.name, device_id) - if new_name is not None: - self.update_name_mapping(output.name, device_id, new_name) - if layer_update.split_info_updated: - for i in range(layer.num_outputs): - output = layer.get_output(i) - split_infos = local_context.get_split_infos( - output.name, device_id) - if split_infos is not None: - self.set_split_infos(output.name, device_id, split_infos) - return - split_info_by_device_dim = {} - for i in range(layer.num_inputs): - input = layer.get_input(i) - if input is None: - continue - sharding_spec = sharding_specs[f"input{i}"] - split_infos = local_context.get_split_infos(input.name, device_id) - if split_infos is None: - continue - for dim, split_info in split_infos.items(): - device_dim = tuple(sharding_spec.dim_partition_dict[dim]) - split_info_by_device_dim[device_dim] = split_info - for i in range(layer.num_outputs): - output = layer.get_output(i) - sharding_spec = sharding_specs[f"output{i}"] - for dim, device_dim in sharding_spec.dim_partition_dict.items(): - split_info = split_info_by_device_dim.get(tuple(device_dim)) - if split_info is None: - if device_dim == [0, 1] or device_dim == [1, 0]: - if (0, ) in split_info_by_device_dim and ( - 1, ) in split_info_by_device_dim: - split_info = SplitInfo( - split_info_by_device_dim[(0, )].input_dim * - split_info_by_device_dim[(1, )].input_dim, - split_info_by_device_dim[(0, )].partition * - split_info_by_device_dim[(1, )].partition, - ) - assert split_info is not None - partition = get_partition(device_dim, device_ids) - if split_info.input_dim != output.shape[dim]: - assert output.shape[ - dim] > 0 and output.shape[dim] % partition == 0 - output_split_info = SplitInfo(output.shape[dim], partition) - self.set_split_info(output.name, device_id, dim, - output_split_info) - - def get_local_context(self, layer: trt.ILayer) -> "GraphContext": - local_context = GraphContext() - for i in range(layer.num_inputs): - input = layer.get_input(i) - if input is None: - continue - local_context.tensor_contexts[input.name] = copy.deepcopy( - self.tensor_contexts[input.name]) - return local_context - - def get_local_context_for_output(self, - output: trt.ITensor) -> "GraphContext": - local_context = GraphContext() - local_context.tensor_contexts[output.name] = copy.deepcopy( - self.tensor_contexts[output.name]) - return local_context - - def merge_context(self, context: "GraphContext"): - self.tensor_contexts.update(context.tensor_contexts) - - -@dataclass -class ShardContext: - graph_context: GraphContext - layer: Layer - nditer: np.nditer - device_ids: np.ndarray - strategy: ShardingStrategy - - -def get_partition(device_dim, device_ids): - if device_dim == [0]: - partition = device_ids.shape[0] - elif device_dim == [1]: - partition = device_ids.shape[1] - else: - assert device_dim == [0, 1] or device_dim == [1, 0] - partition = device_ids.size - return partition - - -def get_index(device_dim, iter): - if device_dim == [0]: - index = iter.multi_index[0] - elif device_dim == [1]: - index = iter.multi_index[1] - else: - assert device_dim == [0, 1] or device_dim == [1, 0] - index = iter.iterindex 
- return index - - -def get_full_sharding_spec(sharding_spec): - return ShardingSpec(sharding_spec.device_mesh, - sharding_spec.data_type_size, - sharding_spec.entire_shape, - sharding_spec.max_entire_shape, - sharding_spec.raw_shape, - dim_partition_dict={}) - - -def get_comm_action_sequence(from_sharding_sepc, to_sharding_sepc): - comm_action_sequence = from_sharding_sepc.device_mesh.shape_consistency_manager.shape_consistency( - from_sharding_sepc, to_sharding_sepc)[1] - # TODO: should merged by shape_consistency - if len(comm_action_sequence) == 2: - if comm_action_sequence[0].comm_pattern == comm_action_sequence[ - 1].comm_pattern == "all_gather": - if comm_action_sequence[0].gather_dim == comm_action_sequence[ - 1].gather_dim: - comm_action_sequence = [ - CommSpec( - comm_action_sequence[0].comm_pattern, - comm_action_sequence[0].sharding_spec, - comm_action_sequence[0].gather_dim, - comm_action_sequence[0].shard_dim, [[ - *comm_action_sequence[0].logical_process_axis[0], - *comm_action_sequence[1].logical_process_axis[0] - ]], comm_action_sequence[0].mix_gather, - comm_action_sequence[0].forward_only) - ] - assert len(comm_action_sequence[0].logical_process_axis[0]) <= 2 - assert len(comm_action_sequence) <= 1 - return comm_action_sequence - - -class GraphGroup(ABC): - - @staticmethod - def from_graph( - graph: PipelineGraph, - config: ParallelConfig, - auto_parallel_config: AutoParallelConfig, - ) -> "GraphGroup": - if auto_parallel_config.debug_mode: - return PrefixedGraphGroup(graph, config, auto_parallel_config) - else: - return DistributedGraphGroup(graph, config, auto_parallel_config) - - @property - @abstractmethod - def auto_parallel_config(self) -> AutoParallelConfig: - ... - - @abstractmethod - def add_input(self, tensor, device_ids, strategy: ShardingStrategy): - ... - - @abstractmethod - def add_layer(self, layer, device_ids, strategy: ShardingStrategy): - ... - - @abstractmethod - def add_output(self, tensor, device_ids, sharding_spec: ShardingSpec): - ... - - @abstractmethod - def get_network(self, device_id) -> trt.INetworkDefinition: - ... - - @abstractmethod - def get_graph(self, device_id) -> PipelineGraph: - ... - - @property - @abstractmethod - def full_graph(self) -> PipelineGraph: - ... - - @abstractmethod - def get_prefix(self, device_id) -> str: - ... - - @abstractmethod - def get_shapes(self, device_id) -> Dict[str, Tuple[int, ...]]: - ... - - @abstractmethod - def get_values(self, device_id) -> Dict[str, List[int]]: - ... - - @abstractmethod - def add_all_reduce_layer(self, context: GraphContext, input_name, - output_name, device_ids, to_reduce_tensors): - ... - - @abstractmethod - def add_all_gather_layer(self, context: GraphContext, input_name, - output_name, device_ids, to_gather_tensors): - ... - - @abstractmethod - def register_layer(self, - layer, - base_name, - input_name, - output_name=None, - device_id=None, - keep_tensor_name=False) -> Layer: - ... 
- - def get_tensor(self, context: GraphContext, tensor_name: str, - device_id: int) -> Tensor: - name = context.get_name(tensor_name, device_id) - return self.get_graph(device_id).get_tensor(name) - - def add_comm(self, - context: GraphContext, - input_name, - device_ids, - commspec, - output_name=None, - is_singleton=False): - remove_index = [] - for i, device_dim in enumerate(commspec.logical_process_axis): - partition = get_partition(device_dim, device_ids) - if partition == 1: - remove_index.append(i) - if len(remove_index) > 0: - if commspec.comm_pattern in ["all_gather", "all_to_all"]: - commspec.gather_dim = [ - dim for i, dim in enumerate(commspec.gather_dim) - if i not in remove_index - ] - if commspec.comm_pattern in [ - "split", "reduce_scatter", "all_to_all" - ]: - commspec.shard_dim = [ - dim for i, dim in enumerate(commspec.shard_dim) - if i not in remove_index - ] - commspec.logical_process_axis = [ - dim for i, dim in enumerate(commspec.logical_process_axis) - if i not in remove_index - ] - flatten_device_dim = list( - itertools.chain.from_iterable(commspec.logical_process_axis)) - if flatten_device_dim == []: - return - if flatten_device_dim == [0, 1] or flatten_device_dim == [1, 0]: - self._add_comm(context, input_name, device_ids, commspec, - output_name, is_singleton) - elif flatten_device_dim == [0]: - for i in range(device_ids.shape[1]): - self._add_comm(context, input_name, device_ids[:, i:i + 1], - commspec, output_name, is_singleton) - elif flatten_device_dim == [1]: - for i in range(device_ids.shape[0]): - self._add_comm(context, input_name, device_ids[i:i + 1, :], - commspec, output_name, is_singleton) - else: - raise RuntimeError( - f"Invalid flatten device_dim: {flatten_device_dim}") - - def _add_comm(self, - context: GraphContext, - input_name, - device_ids, - commspec, - output_name=None, - is_singleton=False): - input_tensors = [ - self.get_tensor(context, input_name, device_id.item()) - for device_id in np.nditer(device_ids) - ] - comm_pattern = commspec.comm_pattern - if comm_pattern == "split": - self.add_split(context, input_name, output_name, device_ids, - commspec.shard_dim, commspec.logical_process_axis) - elif comm_pattern == "all_gather": - self.add_all_gather(context, input_name, output_name, device_ids, - commspec.gather_dim, - commspec.logical_process_axis, is_singleton) - elif comm_pattern == "all_reduce": - self.add_all_reduce(context, input_name, output_name, device_ids) - elif comm_pattern == "reduce_scatter": - self.add_reduce_scatter(context, input_name, output_name, - device_ids, commspec.shard_dim, - commspec.logical_process_axis) - elif comm_pattern == "all_to_all": - self.add_all_to_all(context, input_name, output_name, device_ids, - commspec.gather_dim, commspec.shard_dim, - commspec.logical_process_axis) - else: - raise NotImplementedError - output_tensors = [ - self.get_tensor(context, input_name, device_id.item()) - for device_id in np.nditer(device_ids) - ] - for input_tensor, output_tensor in zip(input_tensors, output_tensors): - if input_tensor.dtype != output_tensor.dtype: - raise ValueError( - f"Input tensor and output tensor should have the same dtype for communication layers, " - f"input dtype is {input_tensor.dtype} for {input_tensor.name}, " - f"but output dtype is {output_tensor.dtype} for {output_tensor.name}" - ) - - def add_all_reduce(self, context: GraphContext, input_name, output_name, - device_ids): - dtype = str_dtype_to_trt(self.full_graph._plugin_config.dtype) - to_reduce_tensors = [] - for device_id in 
np.nditer(device_ids): - device_id = device_id.item() - layer_info = (input_name, output_name, device_id) - network = self.get_network(device_id) - input_tensor = self.get_tensor(context, input_name, - device_id).as_trt() - input_dtype = input_tensor.dtype - if input_dtype != dtype: - to_reduce_tensor = self.cast( - network, - input_tensor, - dtype, - layer_info, - ) - else: - to_reduce_tensor = input_tensor - to_reduce_tensors.append(to_reduce_tensor) - self.add_all_reduce_layer(context, input_name, output_name, device_ids, - to_reduce_tensors) - if input_dtype != dtype: - for device_id in np.nditer(device_ids): - device_id = device_id.item() - layer_info = (input_name, output_name, device_id) - network = self.get_network(device_id) - input_tensor = self.get_tensor( - context, - input_name, - device_id, - ).as_trt() - output_tensor = self.cast( - network, - input_tensor, - input_dtype, - layer_info, - ) - context.update_name_mapping( - input_name, - device_id, - output_tensor.name, - ) - - def add_reduce_scatter(self, context: GraphContext, input_name, output_name, - device_ids, shard_dims, device_dims): - self.add_all_reduce(context, input_name, output_name, device_ids) - self.add_split(context, input_name, output_name, device_ids, shard_dims, - device_dims) - - # TODO: use native all_to_all operation - def add_all_to_all(self, context: GraphContext, input_name, output_name, - device_ids, gather_dims, shard_dims, device_dims): - self.add_all_gather(context, input_name, output_name, device_ids, - gather_dims, device_dims) - self.add_split(context, input_name, output_name, device_ids, shard_dims, - device_dims) - - def get_item(self, network, tensor, index, layer_info): - get_item_layer = network.add_slice(tensor, [index], [1], [1]) - self.register_layer(get_item_layer, f"get_item{index}", *layer_info) - return get_item_layer.get_output(0) - - def get_shape(self, network, tensor, layer_info): - shape_layer = network.add_shape(tensor) - self.register_layer(shape_layer, "shape", *layer_info) - return shape_layer.get_output(0) - - def concat(self, network, tensors, layer_info): - concat_layer = network.add_concatenation(tensors) - self.register_layer(concat_layer, "concat", *layer_info) - return concat_layer.get_output(0) - - def flatten(self, network, tensor, layer_info): - shuffle_layer = network.add_shuffle(tensor) - shuffle_layer.reshape_dims = [-1] - shuffle_layer.zero_is_placeholder = False - self.register_layer(shuffle_layer, "flatten", *layer_info) - return shuffle_layer.get_output(0) - - def reshape(self, network, tensor, reshape_dims, layer_info): - reshape_layer = network.add_shuffle(tensor) - reshape_layer.set_input(1, reshape_dims) - reshape_layer.zero_is_placeholder = False - self.register_layer(reshape_layer, "reshape", *layer_info) - return reshape_layer.get_output(0) - - def cast(self, network, tensor, dtype, layer_info): - if tensor.dtype == dtype: - return tensor - cast_layer = network.add_cast(tensor, dtype) - self.register_layer(cast_layer, "cast", *layer_info) - return cast_layer.get_output(0) - - def const_int(self, network, name, value, layer_info): - const_layer = network.add_constant( - [1], np.array([value], dtype=trt_dtype_to_np(default_int_dtype))) - self.register_layer(const_layer, name, *layer_info) - return const_layer.get_output(0) - - def get_dim_size(self, network, tensor, dim, layer_info, shape_tensor=None): - raw_shape = tensor.shape - dim_size = raw_shape[dim] - if dim_size != -1: - return dim_size - else: - if shape_tensor is None: - shape_tensor = 
self.get_shape(network, tensor, layer_info) - return self.get_item(network, shape_tensor, dim, layer_info) - - def add_split(self, context: GraphContext, input_name, output_name, - device_ids, shard_dims, device_dims): - it = np.nditer(device_ids, flags=['multi_index']) - for device_id in it: - device_id = device_id.item() - layer_info = (input_name, output_name, device_id) - network = self.get_network(device_id) - input_tensor = self.get_tensor(context, input_name, - device_id).as_trt() - raw_input_shape = input_tensor.shape - start = [] - output_dims = [] - stride = [] - input_shape_tensor = self.get_shape(network, input_tensor, - layer_info) - for dim in range(len(raw_input_shape)): - stride.append(1) - if dim not in shard_dims: - start.append(0) - output_dims.append( - self.get_item(network, input_shape_tensor, dim, - layer_info)) - else: - start.append(None) - output_dims.append(None) - - for dim, device_dim in zip(shard_dims, device_dims): - partition = get_partition(device_dim, device_ids) - index = get_index(device_dim, it) - input_dim = raw_input_shape[dim] - assert input_dim != -1 - assert input_dim % partition == 0 - quotient = input_dim // partition - start[dim] = index * quotient - output_dims[dim] = self.const_int(network, f"output_dim{dim}", - quotient, layer_info) - context.set_split_info(input_name, device_id, dim, - SplitInfo(input_dim, partition)) - output_dims_tensor = self.concat(network, output_dims, layer_info) - split_layer = network.add_slice(input_tensor, start, [], stride) - split_layer.set_input(2, output_dims_tensor) - wrapped_layer = self.register_layer(split_layer, "split", - *layer_info) - wrapped_layer.attrs["strategy"] = get_sharding_sequence( - len(raw_input_shape), - shard_dims, - device_dims, - ) - - output_tensor = split_layer.get_output(0) - context.update_name_mapping(input_name, device_id, - output_tensor.name) - - def add_all_gather(self, - context: GraphContext, - input_name, - output_name, - device_ids, - gather_dims, - device_dims, - is_singleton=False): - to_gather_tensors = [] - for device_id in np.nditer(device_ids): - device_id = device_id.item() - layer_info = (input_name, output_name, device_id) - network = self.get_network(device_id) - input_tensor = self.get_tensor(context, input_name, - device_id).as_trt() - to_gather_tensor = self.flatten(network, input_tensor, layer_info) - to_gather_tensors.append(to_gather_tensor) - - all_gather_layers = self.add_all_gather_layer( - context, - input_name, - output_name, - device_ids, - to_gather_tensors, - ) - - if len(device_dims) == 1: - gather_indices = [0] - elif len(device_dims) == 2 and device_dims[0] == [1]: - gather_indices = [1, 0] - else: - gather_indices = [0, 1] - - for device_id, all_gather_layer in zip(np.nditer(device_ids), - all_gather_layers): - device_id = device_id.item() - layer_info = (input_name, output_name, device_id) - network = self.get_network(device_id) - input_tensor = self.get_tensor(context, input_name, - device_id).as_trt() - permutation = [] - gathered_dims = [] - output_dims = [] - partitions = [] - raw_input_shape = input_tensor.shape - - wrapped_layer = self.get_graph(device_id).get_layer( - all_gather_layer.name) - wrapped_layer.attrs["strategy"] = get_sharding_sequence( - len(raw_input_shape), - gather_dims, - device_dims, - ) - - input_shape_layer = network.add_shape(input_tensor) - self.register_layer(input_shape_layer, "input_shape", *layer_info) - input_shape_tensor = input_shape_layer.get_output(0) - split_infos = context.get_split_infos(input_name, 
device_id) - for index in gather_indices: - gather_dim = gather_dims[index] - device_dim = device_dims[index] - partition = get_partition(device_dim, device_ids) - assert partition == split_infos[gather_dim].partition - partitions.append( - self.const_int(network, f"partition_num{gather_dim}", - partition, layer_info)) - for dim in range(len(raw_input_shape)): - if dim in gather_dims: - gather_index = gather_dims.index(dim) - device_dim = device_dims[gather_index] - permutation.append(gather_indices.index(gather_index)) - permutation.append(dim + len(gather_dims)) - if dim not in split_infos: - output_dim_layer = network.add_slice( - input_shape_tensor, [dim], [1], [1]) - self.register_layer(output_dim_layer, f"output_dim{dim}", - *layer_info) - dim_tensor = output_dim_layer.get_output(0) - output_dims.append(dim_tensor) - gathered_dims.append(dim_tensor) - else: - input_dim = split_infos[dim].input_dim - partition = split_infos[dim].partition - assert input_dim != -1 - assert input_dim % partition == 0 - quotient = input_dim // partition - output_dims.append( - self.const_int(network, f"output_dim{dim}", quotient, - layer_info)) - if dim in gather_dims: - gathered_dims.append( - self.const_int(network, f"gathered_dim{dim}", - quotient * partition, layer_info)) - del split_infos[dim] - else: - gathered_dims.append(output_dim_layer.get_output(0)) - - reshape_dims_for_transpose_layer = network.add_concatenation( - [*partitions, *output_dims]) - self.register_layer(reshape_dims_for_transpose_layer, - "reshape_dims_for_transpose", *layer_info) - reshape_dims_tensor = reshape_dims_for_transpose_layer.get_output(0) - transpose_layer = network.add_shuffle( - all_gather_layer.get_output(0)) - transpose_layer.set_input(1, reshape_dims_tensor) - transpose_layer.second_transpose = permutation - transpose_layer.zero_is_placeholder = False - self.register_layer(transpose_layer, "transpose", *layer_info) - - reshape_dims_for_reshape_layer = network.add_concatenation( - gathered_dims) - self.register_layer(reshape_dims_for_reshape_layer, - "reshape_dims_for_reshape", *layer_info) - reshape_dims_tensor = reshape_dims_for_reshape_layer.get_output(0) - output_tensor = self.reshape( - network, - transpose_layer.get_output(0), - reshape_dims_tensor, - layer_info, - ) - context.update_name_mapping(input_name, device_id, - output_tensor.name) - if is_singleton: - break - - def register_unfilled_weights(self, graph, layer): - if (layer.name in self.full_graph._unfilled_weights - and layer.name not in graph._unfilled_weights): - weights, values = self.full_graph._unfilled_weights[layer.name] - graph._register_unfilled_weights( - layer.name, - weights, - values, - ) - - def shard_constant(self, context: ShardContext): - sharding_spec = context.strategy.sharding_specs["output0"] - shard_dims = sharding_spec.dim_partition_dict - device_id = context.nditer.value.item() - device_ids = context.device_ids - layer = context.layer.as_trt() - graph = self.get_graph(device_id) - if len(shard_dims) == 0: - self.register_unfilled_weights(graph, layer) - return LayerUpdate(split_info_updated=True) - flatten_device_dim = list( - itertools.chain.from_iterable(shard_dims.values())) - output_name = layer.get_output(0).name - output_dtype = layer.get_output(0).dtype - output_shape = layer.shape - output_dims = [] - weight_index = [] - for dim in range(len(output_shape)): - output_dim = output_shape[dim] - if dim in shard_dims: - device_dim = shard_dims[dim] - partition = get_partition(device_dim, device_ids) - index = 
get_index(device_dim, context.nditer) - assert output_dim % partition == 0 - quotient = output_dim // partition - output_dims.append(quotient) - weight_index.append( - slice(index * quotient, (index + 1) * quotient)) - context.graph_context.set_split_info( - output_name, device_id, dim, - SplitInfo(output_dim, partition)) - else: - output_dims.append(output_dim) - weight_index.append(slice(None)) - if layer.name in self.full_graph._unfilled_weights: - values = self.full_graph._unfilled_weights[layer.name][1] - else: - values = layer.weights - if isinstance(values, trt.Weights): - values = values.numpy() - # TODO: remove this WAR after https://nvbugs/4359151 fixed. - if isinstance(values, trt.Weights): - network = context.layer.graph.as_trt() - values = get_np_weight(network, layer.name) - if values is not None: - values = values.reshape(layer.shape) - assert values.size == np.prod(layer.shape) - sharded_values = values[tuple(weight_index)] - assert sharded_values.size * get_partition( - flatten_device_dim, device_ids) == np.prod(layer.shape) - else: - sharded_values = None - dtype = trt_dtype_to_np(output_dtype) - sharded_weights = np.empty(tuple(output_dims), dtype) - graph._register_unfilled_weights( - f"device{device_id}_{layer.name}", - sharded_weights, - sharded_values, - ) - sharded_weights = to_trt_weights(sharded_weights) - return LayerUpdate( - updated_attrs=dict( - shape=trt.Dims(output_dims), - weights=sharded_weights, - ), - split_info_updated=True, - ) - - def shard_fill(self, context: ShardContext): - sharding_spec = context.strategy.sharding_specs["output0"] - shard_dims = sharding_spec.dim_partition_dict - if len(shard_dims) == 0: - return LayerUpdate(split_info_updated=True) - device_id = context.nditer.value.item() - device_ids = context.device_ids - layer = context.layer.as_trt() - output_name = layer.get_output(0).name - output_shape = layer.shape - output_dims = [] - for dim in range(len(output_shape)): - output_dim = output_shape[dim] - if dim in shard_dims: - device_dim = shard_dims[dim] - partition = get_partition(device_dim, device_ids) - assert output_dim % partition == 0 - quotient = output_dim // partition - output_dims.append(quotient) - context.graph_context.set_split_info( - output_name, device_id, dim, - SplitInfo(output_dim, partition)) - else: - output_dims.append(output_dim) - return LayerUpdate( - updated_attrs=dict(shape=trt.Dims(output_dims), ), - split_info_updated=True, - ) - - def update_shape(self, context: ShardContext): - if not context.layer.is_shape_io: - return - layer = context.layer.as_trt() - input_name = layer.get_input(0).name - output_name = layer.get_output(0).name - device_id = context.nditer.value.item() - layer_info = (output_name, None, device_id) - split_infos = context.graph_context.get_split_infos( - input_name, device_id) - if len(split_infos) == 0: - return - network = self.get_network(device_id) - shape_tensor = self.get_tensor(context.graph_context, output_name, - device_id).as_trt() - output_dims = [] - for dim in range(len(context.layer.get_input(0).shape)): - if dim not in split_infos: - output_dim_layer = network.add_slice(shape_tensor, [dim], [1], - [1]) - else: - input_dim = split_infos[dim].input_dim - output_dim_layer = network.add_constant( - [1], np.array([input_dim], dtype=default_int_dtype)) - self.register_layer(output_dim_layer, f"output_dim{dim}", - *layer_info) - output_dims.append(output_dim_layer.get_output(0)) - new_shape_layer = network.add_concatenation(output_dims) - self.register_layer(new_shape_layer, 
"new_shape", *layer_info) - new_shape_tensor = new_shape_layer.get_output(0) - context.graph_context.update_name_mapping(output_name, device_id, - new_shape_tensor.name) - - def shard_slice(self, context: ShardContext): - sharding_spec = context.strategy.sharding_specs["output0"] - shard_dims = sharding_spec.dim_partition_dict - if len(shard_dims) == 0: - return LayerUpdate.none() - device_id = context.nditer.value.item() - network = self.get_network(device_id) - device_ids = context.device_ids - layer = context.layer.as_trt() - output_dims = [] - updated_attrs = {} - updated_inputs = {} - - if layer.num_inputs >= 3: - raw_output_shape = layer.get_output(0).shape - input_name = layer.get_input(2).name - layer_info = (input_name, layer.name, device_id) - shape_tensor = self.get_tensor(context.graph_context, input_name, - device_id).as_trt() - for dim in range(len(raw_output_shape)): - output_dim_layer = network.add_slice(shape_tensor, [dim], [1], - [1]) - self.register_layer(output_dim_layer, f"output_dim{dim}", - *layer_info) - if dim in shard_dims: - device_dim = shard_dims[dim] - partition = get_partition(device_dim, device_ids) - partition_num_tensor = self.const_int( - network, f"partition_num{dim}", partition, layer_info) - quotient_layer = network.add_elementwise( - output_dim_layer.get_output(0), partition_num_tensor, - trt.ElementWiseOperation.FLOOR_DIV) - self.register_layer(quotient_layer, f"quotient{dim}", - *layer_info) - output_dim = self.cast(network, - quotient_layer.get_output(0), - default_int_dtype, layer_info) - output_dims.append(output_dim) - else: - output_dims.append(output_dim_layer.get_output(0)) - output_dims_layer = network.add_concatenation(output_dims) - self.register_layer(output_dims_layer, "output_dims", *layer_info) - updated_inputs[2] = output_dims_layer.get_output(0) - else: - output_shape = layer.shape - for dim in range(len(output_shape)): - output_dim = output_shape[dim] - assert output_dim != -1 - if dim in shard_dims: - device_dim = shard_dims[dim] - partition = get_partition(device_dim, device_ids) - assert output_dim % partition == 0 - quotient = output_dim // partition - output_dims.append(quotient) - else: - output_dims.append(output_dim) - updated_attrs["shape"] = trt.Dims(output_dims) - return LayerUpdate(updated_attrs, updated_inputs) - - def shard_shuffle(self, context: ShardContext): - sharding_spec = context.strategy.sharding_specs["output0"] - shard_dims = sharding_spec.dim_partition_dict - if len(shard_dims) == 0: - return LayerUpdate.none() - device_id = context.nditer.value.item() - network = self.get_network(device_id) - device_ids = context.device_ids - layer = context.layer.as_trt() - updated_attrs = {} - updated_inputs = {} - updated_reshape_dims = {} - second_transpose = layer.second_transpose - - if layer.num_inputs >= 2: - raw_output_shape = layer.get_output(0).shape - input_name = layer.get_input(1).name - layer_info = (input_name, layer.name, device_id) - reshape_dims_tensor = self.get_tensor(context.graph_context, - input_name, device_id) - reshape_dims = context.layer.get_input(1).value - reshape_dims_tensor = reshape_dims_tensor.as_trt() - for dim in range(len(raw_output_shape)): - if second_transpose is not None: - reshape_dim = second_transpose[dim] - else: - reshape_dim = dim - output_dim_layer = network.add_slice(reshape_dims_tensor, - [reshape_dim], [1], [1]) - self.register_layer(output_dim_layer, f"output_dim{dim}", - *layer_info) - output_dim = reshape_dims[reshape_dim] - if dim in shard_dims and output_dim != -1: - 
device_dim = shard_dims[dim] - partition = get_partition(device_dim, device_ids) - partition_num_tensor = self.const_int( - network, f"partition_num{dim}", partition, layer_info) - quotient_layer = network.add_elementwise( - output_dim_layer.get_output(0), partition_num_tensor, - trt.ElementWiseOperation.FLOOR_DIV) - self.register_layer(quotient_layer, f"quotient{dim}", - *layer_info) - updated_reshape_dims[reshape_dim] = self.cast( - network, - quotient_layer.get_output(0), - default_int_dtype, - layer_info, - ) - else: - updated_reshape_dims[ - reshape_dim] = output_dim_layer.get_output(0) - updated_reshape_dims = list( - map(lambda x: x[1], sorted(updated_reshape_dims.items()))) - reshape_dims_layer = network.add_concatenation(updated_reshape_dims) - self.register_layer(reshape_dims_layer, "reshape_dims", *layer_info) - updated_inputs[1] = reshape_dims_layer.get_output(0) - else: - reshape_dims = layer.reshape_dims - if reshape_dims.__len__() < 0: - return LayerUpdate.none() - for dim in range(len(reshape_dims)): - if second_transpose is not None: - reshape_dim = second_transpose[dim] - else: - reshape_dim = dim - output_dim = reshape_dims[reshape_dim] - if dim in shard_dims and output_dim != -1: - device_dim = shard_dims[dim] - partition = get_partition(device_dim, device_ids) - quotient = output_dim // partition - updated_reshape_dims[reshape_dim] = quotient - else: - updated_reshape_dims[reshape_dim] = output_dim - updated_reshape_dims = list( - map(lambda x: x[1], sorted(updated_reshape_dims.items()))) - updated_attrs["reshape_dims"] = trt.Dims(updated_reshape_dims) - return LayerUpdate(updated_attrs, updated_inputs) - - def shard_gpt_attention(self, context: ShardContext): - layer = context.layer.as_trt() - plugin_info = get_plugin_info( - self.full_graph.as_trt(), - layer.name, - ) - parser = IdxEntryParser(plugin_info) - head_dim = 1 if parser.remove_input_padding else 2 - sharding_spec = context.strategy.sharding_specs[ - f"input{parser.get_index(IdxEntry.QKV_TENSOR)}"] - shard_dims = sharding_spec.dim_partition_dict - if head_dim not in shard_dims: - return LayerUpdate.none() - device_id = context.nditer.value.item() - network = self.get_network(device_id) - device_ids = context.device_ids - updated_attrs = {} - updated_inputs = {} - device_dim = shard_dims[head_dim] - partition = get_partition(device_dim, device_ids) - index = get_index(device_dim, context.nditer) - if parser.is_entry_used(IdxEntry.K_TENSOR): - kv_sharding_spec = context.strategy.sharding_specs[ - f"input{parser.get_index(IdxEntry.K_TENSOR)}"] - kv_shard_dims = kv_sharding_spec.dim_partition_dict - if head_dim in kv_shard_dims: - kv_device_dim = kv_shard_dims[head_dim] - kv_partition = get_partition(kv_device_dim, device_ids) - else: - kv_partition = 1 - else: - kv_partition = 1 - num_heads = plugin_info.pfc_as_ndarray["num_heads"].copy() - num_kv_heads = plugin_info.pfc_as_ndarray["num_kv_heads"].copy() - tp_size = plugin_info.pfc_as_ndarray["tp_size"].copy() - tp_rank = plugin_info.pfc_as_ndarray["tp_rank"].copy() - num_kv_heads = np.maximum(num_kv_heads // kv_partition, 1) - num_heads = np.maximum(num_heads // partition, 1) - tp_size[0] = partition - tp_rank[0] = index - - new_plugin, new_plugin_info = get_updated_plugin( - plugin_info, - dict( - num_heads=num_heads, - num_kv_heads=num_kv_heads, - tp_size=tp_size, - tp_rank=tp_rank, - )) - prefix = self.get_prefix(device_id) - new_layer_name = f"{prefix}{layer.name}" - set_plugin_info(network, new_layer_name, new_plugin_info) - updated_attrs["plugin"] = 
new_plugin - return LayerUpdate(updated_attrs, updated_inputs) - - def shard_lookup(self, context: ShardContext): - sharding_spec = context.strategy.sharding_specs["input1"] - shard_dims = sharding_spec.dim_partition_dict - if 0 not in shard_dims: - return LayerUpdate.none() - layer = context.layer.as_trt() - plugin_info = get_plugin_info( - self.full_graph.as_trt(), - layer.name, - ) - device_id = context.nditer.value.item() - network = self.get_network(device_id) - updated_attrs = {} - device_dim = shard_dims[0] - index = get_index(device_dim, context.nditer) - rank = plugin_info.pfc_as_ndarray["rank"].copy() - rank[0] = index - - new_plugin, new_plugin_info = get_updated_plugin( - plugin_info, dict(rank=rank, )) - prefix = self.get_prefix(device_id) - new_layer_name = f"{prefix}{layer.name}" - set_plugin_info(network, new_layer_name, new_plugin_info) - updated_attrs["plugin"] = new_plugin - return LayerUpdate(updated_attrs) - - -class GraphGroupBase(GraphGroup): - - def __init__( - self, - full_graph: PipelineGraph, - config: ParallelConfig, - auto_parallel_config: AutoParallelConfig, - ) -> None: - self._full_graph = full_graph - self.config = config - self._auto_parallel_config = auto_parallel_config - self.infer_shape = auto_parallel_config.infer_shape - self.global_context = GraphContext() - self.shape_cache = {} - self.suffix = 0 - self.current_block_id = -1 - - @property - def auto_parallel_config(self) -> AutoParallelConfig: - return self._auto_parallel_config - - @property - def full_graph(self) -> PipelineGraph: - return self._full_graph - - def register_layer(self, - layer, - base_name, - input_name, - output_name=None, - device_id=None, - keep_tensor_name=False) -> Layer: - layer_name = f"{base_name}_{input_name}" - if device_id is not None: - layer_name = f"{self.get_prefix(device_id)}{layer_name}" - if output_name is not None: - layer_name = f"{layer_name}_to_{output_name}" - suffix = self.suffix - self.suffix += 1 - layer_name = f"{layer_name}_{suffix}" - if layer.type == trt.LayerType.PLUGIN_V2: - network = self.get_network(device_id) - plugin_info = get_plugin_info(network, layer.name) - if plugin_info is not None: - set_plugin_info(network, layer_name, plugin_info) - delete_plugin_info(network, layer.name) - layer.name = layer_name - layer.metadata = layer.name - if not keep_tensor_name: - for i in range(layer.num_outputs): - output_tensor = layer.get_output(i) - assert output_tensor.shape.__len__() >= 0 - output_tensor.name = f"{layer.name}_output_{i}" - wrapped_layer = self.get_graph(device_id).register_layer(layer) - if self.current_block_id != -1: - wrapped_layer.attrs["block_id"] = self.current_block_id - wrapped_layer.attrs["role"] = "helper" - if self.infer_shape: - infer_per_layer_shapes( - layer, - self.get_shapes(device_id), - self.get_values(device_id), - self.shape_cache, - is_shape_io=True, - ) - wrapped_layer.assign_shapes( - self.get_shapes(device_id), - self.get_values(device_id), - ) - return wrapped_layer - - def add_layer(self, wrapped_layer: Layer, device_ids, - strategy: ShardingStrategy): - layer = wrapped_layer.as_trt() - local_context = self.global_context.get_local_context(layer) - self.current_block_id = wrapped_layer.attrs["block_id"] - - for i, input in enumerate(wrapped_layer.inputs): - if input is None: - continue - if i not in strategy.best_resharding_cost: - continue - comm_action_sequence = strategy.best_resharding_cost[i][0][1] - for commspec in comm_action_sequence: - self.add_comm(local_context, - input.name, - device_ids, - commspec, 
- output_name=layer.name) - - it = np.nditer(device_ids, flags=['multi_index']) - for device_id in it: - device_id = device_id.item() - - layer_type = layer.type - to_subclass_layer(layer) - shard_context = ShardContext( - local_context, - wrapped_layer, - it, - device_ids, - strategy, - ) - if layer_type == trt.LayerType.CONSTANT: - layer_update = self.shard_constant(shard_context) - elif layer_type == trt.LayerType.FILL: - layer_update = self.shard_fill(shard_context) - elif layer_type == trt.LayerType.SLICE: - layer_update = self.shard_slice(shard_context) - elif layer_type == trt.LayerType.SHUFFLE: - layer_update = self.shard_shuffle(shard_context) - elif layer_type == trt.LayerType.PLUGIN_V2: - if layer.plugin.plugin_type == "GPTAttention": - layer_update = self.shard_gpt_attention(shard_context) - elif layer.plugin.plugin_type == "Lookup": - layer_update = self.shard_lookup(shard_context) - else: - layer_update = LayerUpdate.none() - else: - layer_update = LayerUpdate.none() - to_base_class_layer(layer) - - for i, updated_input in layer_update.updated_inputs.items(): - input_name = layer.get_input(i).name - local_context.update_name_mapping(input_name, device_id, - updated_input.name) - if layer.get_input(i).dtype != updated_input.dtype: - raise ValueError( - f"Input dtype mismatch for {layer.name}, " - f"expect {layer.get_input(i).dtype} for {input_name}, " - f"get {updated_input.dtype} for {updated_input.name}") - - prefix = self.get_prefix(device_id) - new_wrapped_layer = self.get_graph(device_id).add_layer( - layer, - prefix=prefix, - input_mapping=local_context.get_name_mapping(device_id, - prefix=prefix), - updated_attrs=layer_update.updated_attrs, - ) - new_wrapped_layer.attrs["strategy"] = strategy.name - new_wrapped_layer.attrs["block_id"] = self.current_block_id - new_layer = new_wrapped_layer.as_trt() - - if self.infer_shape: - infer_per_layer_shapes( - new_layer, - self.get_shapes(device_id), - self.get_values(device_id), - self.shape_cache, - is_shape_io=wrapped_layer.is_shape_io, - ) - new_wrapped_layer.assign_shapes( - self.get_shapes(device_id), - self.get_values(device_id), - ) - - for i in range(layer.num_outputs): - output_tensor = new_layer.get_output(i) - assert output_tensor.shape.__len__() >= 0 - local_context.update_name_mapping( - layer.get_output(i).name, device_id, output_tensor.name) - - if layer.type == trt.LayerType.SHAPE: - self.update_shape(shard_context) - - self.global_context.update_layer_context( - wrapped_layer, - layer_update, - local_context, - device_id, - device_ids, - strategy.sharding_specs, - ) - - for i in range(layer.num_outputs): - commspecs = strategy.communication_actions.get(f"output{i}") - if commspecs is None: - continue - output = layer.get_output(i) - for commspec in commspecs: - self.add_comm( - self.global_context, - output.name, - device_ids, - commspec, - ) - - self.current_block_id = -1 - - -class DistributedGraphGroup(GraphGroupBase): - - def __init__( - self, - full_graph: PipelineGraph, - config: ParallelConfig, - auto_parallel_config: AutoParallelConfig, - ) -> None: - super().__init__(full_graph, config, auto_parallel_config) - self.graphs = {} - self.io_tensor_shards = {} - self.shapes_by_device = {} - self.values_by_device = {} - phy_mesh = config.graph_config.phy_mesh - device_ids = phy_mesh.phy_devices_id - for device_id in np.nditer(device_ids): - device_id = device_id.item() - graph = PipelineGraph.create_graph() - graph._auto_parallel_config = { - "io_shards": {}, - "mapping": - Mapping( - 
world_size=device_ids.size, - rank=device_id, - gpus_per_node=device_ids.shape[1], - tp_size=device_ids.size // config.graph_config.num_stages, - pp_size=config.graph_config.num_stages, - ), - } - self.graphs[device_id] = graph - self.shapes_by_device[device_id] = {} - self.values_by_device[device_id] = {} - - @contextlib.contextmanager - def disable_infer_shape(self): - infer_shape = self.infer_shape - self.infer_shape = False - yield - self.infer_shape = infer_shape - - def get_network(self, device_id) -> trt.INetworkDefinition: - return self.graphs[device_id].as_trt() - - def get_graph(self, device_id) -> PipelineGraph: - return self.graphs[device_id] - - def get_prefix(self, device_id) -> str: - return "" - - def get_shapes(self, device_id) -> Dict[str, Tuple[int, ...]]: - return self.shapes_by_device[device_id] - - def get_values(self, device_id) -> Dict[str, List[int]]: - return self.values_by_device[device_id] - - def add_reduce_scatter(self, context: GraphContext, input_name, output_name, - device_ids, shard_dims, device_dims): - dtype = str_dtype_to_trt(self.full_graph._plugin_config.dtype) - it = np.nditer(device_ids, flags=['multi_index']) - for device_id in it: - device_id = device_id.item() - layer_info = (input_name, output_name, device_id) - network = self.get_network(device_id) - input_tensor = self.get_tensor(context, input_name, - device_id).as_trt() - raw_input_shape = input_tensor.shape - input_shape_tensor = self.get_shape(network, input_tensor, - layer_info) - if shard_dims != [0]: - permutation = list(range(len(raw_input_shape))) - for dim in shard_dims: - permutation.remove(dim) - permutation = shard_dims + permutation - transpose_layer = network.add_shuffle(input_tensor) - transpose_layer.second_transpose = permutation - self.register_layer(transpose_layer, "input_transpose", - *layer_info) - input_tensor = transpose_layer.get_output(0) - flatten_tensor = self.flatten(network, input_tensor, layer_info) - input_dtype = flatten_tensor.dtype - if input_dtype != dtype: - to_reduce_tensor = self.cast( - network, - flatten_tensor, - dtype, - layer_info, - ) - else: - to_reduce_tensor = flatten_tensor - - reduce_scatter_plg_creator = trt.get_plugin_registry( - ).get_plugin_creator('ReduceScatter', '1', TRT_LLM_PLUGIN_NAMESPACE) - assert reduce_scatter_plg_creator is not None - - group = trt.PluginField( - "group", - np.ascontiguousarray(device_ids.reshape(-1).astype(np.int32)), - trt.PluginFieldType.INT32) - pf_type = trt.PluginField( - "type_id", np.array([int(to_reduce_tensor.dtype)], np.int32), - trt.PluginFieldType.INT32) - - pfc = trt.PluginFieldCollection([group, pf_type]) - rs_plug = reduce_scatter_plg_creator.create_plugin( - "reduce_scatter", pfc) - - reduce_scatter_layer = network.add_plugin_v2([to_reduce_tensor], - rs_plug) - plugin_info = PluginInfo(reduce_scatter_plg_creator, - "reduce_scatter", pfc) - set_plugin_info(network, reduce_scatter_layer.name, plugin_info) - with self.disable_infer_shape(): - wrapped_tensor = self.register_layer( - reduce_scatter_layer, - "reduce_scatter", - *layer_info, - ).get_output(0) - reduce_scatter_tensor = reduce_scatter_layer.get_output(0) - if self.infer_shape: - shape = self.shapes_by_device[device_id][to_reduce_tensor.name] - assert len(shape) == 1 - output_shape = (shape[0] // device_ids.size, ) - self.shapes_by_device[device_id][ - reduce_scatter_tensor.name] = output_shape - wrapped_tensor.shape = output_shape - if input_dtype != dtype: - reduce_scatter_tensor = self.cast( - network, - reduce_scatter_tensor, - 
input_dtype, - layer_info, - ) - - start = [] - output_dims = [] - stride = [] - for dim in range(len(raw_input_shape)): - stride.append(1) - if dim not in shard_dims: - start.append(0) - output_dims.append( - self.get_item(network, input_shape_tensor, dim, - layer_info)) - else: - start.append(None) - output_dims.append(None) - - for dim, device_dim in zip(shard_dims, device_dims): - partition = get_partition(device_dim, device_ids) - index = get_index(device_dim, it) - input_dim = raw_input_shape[dim] - assert input_dim != -1 - assert input_dim % partition == 0 - quotient = input_dim // partition - start[dim] = index * quotient - output_dims[dim] = self.const_int(network, f"output_dim{dim}", - quotient, layer_info) - context.set_split_info(input_name, device_id, dim, - SplitInfo(input_dim, partition)) - if shard_dims != [0]: - output_dims = [ - output_dims[permutation[i]] for i in range(len(output_dims)) - ] - output_dims_tensor = self.concat(network, output_dims, layer_info) - output_tensor = self.reshape( - network, - reduce_scatter_tensor, - output_dims_tensor, - layer_info, - ) - if shard_dims != [0]: - transpose_layer = network.add_shuffle(output_tensor) - transpose_layer.second_transpose = permutation - self.register_layer(transpose_layer, "output_transpose", - *layer_info) - output_tensor = transpose_layer.get_output(0) - context.update_name_mapping(input_name, device_id, - output_tensor.name) - - def add_all_reduce_layer(self, context: GraphContext, input_name, - output_name, device_ids, to_reduce_tensors): - for device_id, to_reduce_tensor in zip(np.nditer(device_ids), - to_reduce_tensors): - device_id = device_id.item() - layer_info = (input_name, output_name, device_id) - network = self.get_network(device_id) - graph = self.get_graph(device_id) - workspace = graph.get_input("all_reduce_workspace").as_trt() - - all_reduce_layer, allreduce_plg_creator, pfc = create_allreduce_plugin( - network=network, - tensor=to_reduce_tensor, - workspace=workspace, - group=np.ascontiguousarray( - device_ids.reshape(-1).astype(np.int32)), - dtype=to_reduce_tensor.dtype, - all_reduce_params=AllReduceParams(), - ) - plugin_info = PluginInfo(allreduce_plg_creator, "allreduce", pfc) - set_plugin_info(network, all_reduce_layer.name, plugin_info) - with self.disable_infer_shape(): - wrapped_tensor = self.register_layer( - all_reduce_layer, - "all_reduce", - *layer_info, - ).get_output(0) - output_tensor = all_reduce_layer.get_output(0) - if self.infer_shape: - shape = self.shapes_by_device[device_id][to_reduce_tensor.name] - self.shapes_by_device[device_id][output_tensor.name] = shape - wrapped_tensor.shape = shape - context.update_name_mapping(input_name, device_id, - output_tensor.name) - - def add_all_gather_layer(self, context: GraphContext, input_name, - output_name, device_ids, to_gather_tensors): - all_gather_layers = [] - for device_id, to_gather_tensor in zip(np.nditer(device_ids), - to_gather_tensors): - device_id = device_id.item() - layer_info = (input_name, output_name, device_id) - network = self.get_network(device_id) - - allgather_plg_creator = trt.get_plugin_registry( - ).get_plugin_creator('AllGather', '1', TRT_LLM_PLUGIN_NAMESPACE) - assert allgather_plg_creator is not None - - group = trt.PluginField( - "group", - np.ascontiguousarray(device_ids.reshape(-1).astype(np.int32)), - trt.PluginFieldType.INT32) - pf_type = trt.PluginField( - "type_id", np.array([int(to_gather_tensor.dtype)], np.int32), - trt.PluginFieldType.INT32) - pfc = trt.PluginFieldCollection([group, pf_type]) - 
allgather = allgather_plg_creator.create_plugin("allgather", pfc) - - all_gather_layer = network.add_plugin_v2([to_gather_tensor], - allgather) - plugin_info = PluginInfo(allgather_plg_creator, "allgather", pfc) - set_plugin_info(network, all_gather_layer.name, plugin_info) - with self.disable_infer_shape(): - wrapped_tensor = self.register_layer( - all_gather_layer, - "all_gather", - *layer_info, - ).get_output(0) - if self.infer_shape: - output_tensor = all_gather_layer.get_output(0) - shape = self.shapes_by_device[device_id][to_gather_tensor.name] - assert len(shape) == 1 - output_shape = (shape[0] * device_ids.size, ) - self.shapes_by_device[device_id][ - output_tensor.name] = output_shape - wrapped_tensor.shape = output_shape - all_gather_layers.append(all_gather_layer) - return all_gather_layers - - def set_shard_num(self, tensor_name, dim, shard_num): - for graph in self.graphs.values(): - io_shards = graph._auto_parallel_config["io_shards"] - if tensor_name not in io_shards: - io_shards[tensor_name] = {} - io_shards[tensor_name][dim] = shard_num - - def add_input(self, tensor: Tensor, device_ids, strategy: ShardingStrategy): - context = self.global_context - sharding_spec = strategy.sharding_specs["output0"] - shard_dims = sharding_spec.dim_partition_dict - for dim, device_dim in shard_dims.items(): - partition = get_partition(device_dim, device_ids) - self.set_shard_num(tensor.name, dim, partition) - for device_id in np.nditer(device_ids): - device_id = device_id.item() - graph = self.get_graph(device_id) - new_input = graph.add_input(tensor.as_trt()) - shape = [*tensor.shape] - if len(shard_dims) != 0: - output_shape = [*tensor.raw_shape] - for dim, device_dim in shard_dims.items(): - partition = get_partition(device_dim, device_ids) - output_dim = output_shape[dim] - assert output_dim != -1 - assert output_dim % partition == 0 - quotient = output_dim // partition - output_shape[dim] = quotient - shape[dim] = quotient - assert tensor.value is None - context.set_split_info(tensor.name, device_id, dim, - SplitInfo(output_dim, partition)) - new_input.raw_shape = output_shape - context.update_name_mapping(tensor.name, device_id, tensor.name) - if self.infer_shape: - self.shapes_by_device[device_id][tensor.name] = tuple(shape) - new_input.shape = tuple(shape) - if tensor.value is not None: - self.values_by_device[device_id][tensor.name] = tensor.value - new_input.value = tensor.value - - def add_output(self, tensor: Tensor, device_ids, - strategy: ShardingStrategy): - comm_action_sequence = strategy.best_resharding_cost[0][0][1] - for commspec in comm_action_sequence: - self.add_comm(self.global_context, tensor.name, device_ids, - commspec) - for device_id in np.nditer(device_ids): - device_id = device_id.item() - graph = self.get_graph(device_id) - output_name = tensor.name - new_output_name = self.global_context.get_name( - output_name, device_id) - if new_output_name != output_name: - suffix = self.suffix - self.suffix += 1 - original_name = f"original_{output_name}_{suffix}" - original_tensor = graph.get_tensor(output_name) - original_tensor.as_trt().name = original_name - output_tensor = graph.get_tensor(new_output_name) - output_tensor.as_trt().name = output_name - graph._tensors[original_name] = original_tensor - graph._tensors[output_name] = output_tensor - del graph._tensors[new_output_name] - else: - output_tensor = graph.get_tensor(output_name) - trt_output = output_tensor.as_trt() - if trt_output.is_shape_tensor: - graph.add_output_shape(trt_output) - else: - 
graph.add_output(trt_output) - trt_output.dtype = tensor.dtype - if tensor.dtype != output_tensor.dtype: - raise ValueError( - f"Output dtype mismatch, " - f"expect {tensor.dtype} for {tensor.name}, " - f"get {output_tensor.dtype} for {output_tensor.name}") - - shard_dims = strategy.sharding_specs["input0"].dim_partition_dict - for dim, device_dim in shard_dims.items(): - partition = get_partition(device_dim, device_ids) - self.set_shard_num(tensor.name, dim, partition) - - -class PrefixedGraphGroup(GraphGroupBase): - - def __init__( - self, - full_graph: PipelineGraph = None, - config: ParallelConfig = None, - auto_parallel_config: AutoParallelConfig = None, - ) -> None: - auto_parallel_config = auto_parallel_config or dict( - infer_shape=False, - validation_mode=False, - ) - super().__init__(full_graph, config, auto_parallel_config) - self.validation_mode = auto_parallel_config.validation_mode - if not self.infer_shape: - self.validation_mode = False - self.prefixed_graph = PipelineGraph.create_graph() - if self.validation_mode: - self.layer_mapping = config.graph_config.graph_mapping.layer_mapping - self.graph_strategy = config.graph_strategy - self.shapes = {} - self.values = {} - self.timing_cache = None - - def get_network(self, device_id) -> trt.INetworkDefinition: - return self.prefixed_graph.as_trt() - - def get_graph(self, device_id) -> PipelineGraph: - return self.prefixed_graph - - def get_prefix(self, device_id) -> str: - return f"device{device_id}_" - - def get_shapes(self, device_id) -> Dict[str, Tuple[int, ...]]: - return self.shapes - - def get_values(self, device_id) -> Dict[str, List[int]]: - return self.values - - def add_all_reduce_layer(self, context: GraphContext, input_name, - output_name, device_ids, to_reduce_tensors): - reshaped_tensors = [] - for device_id, to_reduce_tensor in zip(np.nditer(device_ids), - to_reduce_tensors): - device_id = device_id.item() - layer_info = (input_name, output_name, device_id) - network = self.get_network(device_id) - reshape_dims_tensor = self.concat( - network, - [ - self.get_shape(network, to_reduce_tensor, layer_info), - self.const_int(network, "expanded_dim", 1, layer_info) - ], - layer_info, - ) - reshaped_tensor = self.reshape( - network, - to_reduce_tensor, - reshape_dims_tensor, - layer_info, - ) - reshaped_tensors.append(reshaped_tensor) - - for device_id in np.nditer(device_ids): - device_id = device_id.item() - layer_info = (input_name, output_name, device_id) - input_tensor = self.get_tensor(context, input_name, 0).as_trt() - num_dims = len(input_tensor.shape) - network = self.get_network(device_id) - concat_layer = network.add_concatenation(reshaped_tensors) - concat_layer.axis = num_dims - self.register_layer(concat_layer, "concat", *layer_info) - reduce_layer = network.add_reduce(concat_layer.get_output(0), - trt.ReduceOperation.SUM, - axes=1 << num_dims, - keep_dims=False) - dtype = to_reduce_tensors[0].dtype - reduce_layer.precision = dtype - reduce_layer.set_output_type(0, dtype) - self.register_layer(reduce_layer, "reduce", *layer_info) - output_tensor = reduce_layer.get_output(0) - - context.update_name_mapping(input_name, device_id, - output_tensor.name) - - def add_all_gather_layer(self, context: GraphContext, input_name, - output_name, device_ids, to_gather_tensors): - all_gather_layers = [] - for device_id in np.nditer(device_ids): - device_id = device_id.item() - layer_info = (input_name, output_name, device_id) - network = self.get_network(device_id) - all_gather_layer = 
network.add_concatenation(to_gather_tensors) - all_gather_layer.axis = 0 - self.register_layer(all_gather_layer, "all_gather", *layer_info) - all_gather_layers.append(all_gather_layer) - return all_gather_layers - - def add_input(self, tensor: Tensor, device_ids, strategy: ShardingStrategy): - - def add_identity(): - identity_layer = network.add_identity(input.as_trt()) - return identity_layer - - input = self.prefixed_graph.add_input(tensor.as_trt()) - if self.infer_shape: - self.shapes[tensor.name] = tensor.shape - input.shape = tensor.shape - if tensor.value is not None: - self.values[tensor.name] = tensor.value - input.value = tensor.value - network = self.get_network(None) - if self.validation_mode: - identity_layer = add_identity() - identity_layer.get_output(0).name = f"ref_{tensor.name}" - layer_info = (tensor.name, None, None) - self.register_layer(identity_layer, - "identity", - *layer_info, - keep_tensor_name=True) - input.attrs["strategy"] = strategy.name - sharding_spec = strategy.sharding_specs["output0"] - pre_sharding_sepc = get_full_sharding_spec(sharding_spec) - comm_action_sequence = get_comm_action_sequence(pre_sharding_sepc, - sharding_spec) - context = self.global_context - for device_id in np.nditer(device_ids): - device_id = device_id.item() - layer_info = (tensor.name, None, device_id) - context.update_name_mapping(tensor.name, device_id, tensor.name) - if len(comm_action_sequence - ) == 0 and not tensor.as_trt().is_shape_tensor: - identity_layer = add_identity() - self.register_layer(identity_layer, "identity", *layer_info) - context.update_name_mapping( - tensor.name, - device_id, - identity_layer.get_output(0).name, - ) - for commspec in comm_action_sequence: - self.add_comm(context, tensor.name, device_ids, commspec) - - def get_graph_in_range(self, graph_group, src_layer, layer_range, - device_ids, shapes, values): - src_network = self.prefixed_graph.as_trt() - graph = graph_group.prefixed_graph - network = graph.as_trt() - input_mapping = {} - for device_id in np.nditer(device_ids): - device_id = device_id.item() - for i in range(src_layer.num_inputs): - src_input = src_layer.get_input(i) - if src_input is not None: - input = self.get_tensor( - self.global_context, - src_input.name, - device_id, - ).as_trt() - if graph.get_input(src_input.name) is not None: - new_input = graph_group.get_tensor( - graph_group.global_context, - src_input.name, - device_id, - ).as_trt() - input_mapping[input.name] = new_input.name - continue - if graph.get_tensor(input.name) is not None: - continue - shape = shapes[input.name] - assert input.name in values - value = values[input.name] - weights = np.asarray(value, - dtype=trt_dtype_to_np(input.dtype)) - weights = to_trt_weights(weights) - input_layer = network.add_constant(shape, weights) - new_input = input_layer.get_output(0) - new_input.name = input.name - graph.register_layer(input_layer) - for i in layer_range: - layer = src_network.get_layer(i) - graph.add_layer(layer, input_mapping=input_mapping) - - def add_layer_singleton(self, output, device_ids, sharding_spec): - assert self.prefixed_graph.get_tensor(output.name) is None - network = self.prefixed_graph.as_trt() - full_sharding_sepc = get_full_sharding_spec(sharding_spec) - comm_action_sequence = get_comm_action_sequence(sharding_spec, - full_sharding_sepc) - output_context = self.global_context.get_local_context_for_output( - output) - if len(comm_action_sequence) != 0: - for commspec in comm_action_sequence[:-1]: - self.add_comm(output_context, output.name, 
device_ids, commspec) - self.add_comm( - output_context, - output.name, - device_ids, - comm_action_sequence[-1], - is_singleton=True, - ) - device_id = next(np.nditer(device_ids)).item() - layer_info = (output.name, None, device_id) - output_tensor = self.get_tensor(output_context, output.name, - device_id).as_trt() - singleton_layer = network.add_identity(output_tensor) - singleton_layer.get_output(0).name = output.name - self.register_layer(singleton_layer, - "singleton", - *layer_info, - keep_tensor_name=True) - - def add_layer(self, wrapped_layer: Layer, device_ids, - strategy: ShardingStrategy): - graph = self.prefixed_graph - network = graph.as_trt() - start_layer_id = network.num_layers - - super().add_layer(wrapped_layer, device_ids, strategy) - - layer = wrapped_layer.as_trt() - - if self.validation_mode: - is_shape = (wrapped_layer.is_shape_io - or layer.type == trt.LayerType.SHAPE) - - if not is_shape: - self.current_block_id = wrapped_layer.attrs["block_id"] - for i, wrapped_output in enumerate(wrapped_layer.outputs): - if wrapped_output.is_graph_output: - continue - output = wrapped_output.as_trt() - output_name = f"output{i}" - if strategy.communication_actions.get( - output_name) is not None: - output_name += "_after_comm" - sharding_spec = strategy.sharding_specs[output_name] - self.add_layer_singleton(output, device_ids, sharding_spec) - self.current_block_id = -1 - end_layer_id = network.num_layers - - is_skip = (is_shape or layer.type == trt.LayerType.CONSTANT - or layer.name in self.layer_mapping) - sharded = False - for sharding_spec in strategy.sharding_specs.values(): - if len(sharding_spec.dim_partition_dict) > 0: - sharded = True - break - if not sharded: - is_skip = True - - ref_layer = graph.add_layer(layer, prefix="ref_") - ref_layer.attrs["strategy"] = strategy.name - ref_layer.attrs["block_id"] = wrapped_layer.attrs["block_id"] - if layer.type == trt.LayerType.CONSTANT: - self.register_unfilled_weights(graph, layer) - - if is_skip: - return - - logger.debug(f"validating layer {layer.name}") - - layer_type = layer.type - generated_input_values = {} - to_subclass_layer(layer) - if layer_type == trt.LayerType.PLUGIN_V2: - if layer.plugin.plugin_type == "GPTAttention": - sharding_specs = {} - for name, sharding_spec in strategy.sharding_specs.items(): - sharding_specs[name] = get_full_sharding_spec( - sharding_spec) - plugin_info = get_plugin_info( - self.full_graph.as_trt(), - layer.name, - ) - generated_input_values = GPTAttentionPlugin.parameter_generator( - sharding_specs, plugin_info) - to_base_class_layer(layer) - - validation_graph_group = PrefixedGraphGroup() - validation_graph = validation_graph_group.prefixed_graph - validation_graph._io_buffer_mapping = self.full_graph._io_buffer_mapping - extra_input_values = {} - validation_shapes = {} - for i, wrapped_input in enumerate(wrapped_layer.inputs): - if wrapped_input is None: - continue - input = wrapped_input.as_trt() - validation_shapes[input.name] = wrapped_input.shape - if wrapped_input.value is None: - if i in generated_input_values: - extra_input_value = generated_input_values[i] - else: - extra_input_value = torch.empty( - tuple(wrapped_input.shape), - dtype=trt_dtype_to_torch(input.dtype), - device=torch.cuda.current_device(), - ) - if torch.is_floating_point(extra_input_value): - extra_input_value.normal_() - # extra_input_value[:] = random.choice([2, 3, 5, 7]) - extra_input_values[input.name] = extra_input_value - self.values[input.name] = extra_input_value - if wrapped_input.producer is not 
None: - node_name = wrapped_input.producer.name - output_index = wrapped_input.output_index - else: - node_name = wrapped_input.name - output_index = 0 - sharding_spec = self.graph_strategy[ - node_name].sharding_specs[f"output{output_index}"] - validation_graph_group.add_input( - wrapped_input, - device_ids, - ShardingStrategy( - sharding_specs={"output0": sharding_spec}), - ) - validation_graph.get_input( - input.name).raw_shape = wrapped_input.shape - - self.get_graph_in_range( - validation_graph_group, - layer, - range(start_layer_id, end_layer_id), - device_ids, - self.shapes, - self.values, - ) - - for i, wrapped_output in enumerate(wrapped_layer.outputs): - output = wrapped_output.as_trt() - if wrapped_output.is_graph_output: - output_name = f"output{i}" - if strategy.communication_actions.get( - output_name) is not None: - output_name += "_after_comm" - sharding_spec = strategy.sharding_specs[output_name] - validation_graph_group.global_context.merge_context( - self.global_context.get_local_context_for_output( - output)) - validation_graph_group.add_layer_singleton( - output, device_ids, sharding_spec) - validation_graph.add_output(output) - validation_shapes[output.name] = wrapped_output.shape - if not self.timing_cache: - self.timing_cache = network.builder.create_builder_config( - ).create_timing_cache(b"") - logger.debug(f"run validation graph for layer {layer.name}") - validation_runner = validation_graph.get_runner( - validation_shapes, - self.values, - timing_cache=self.timing_cache, - opt_level=0, - ) - values = validation_runner.run() - refer_input_values = {} - for wrapped_input in wrapped_layer.inputs: - if wrapped_input is None: - continue - if wrapped_input.value is not None: - refer_input_values[wrapped_input.name] = wrapped_input.value - refer_graph, output_mapping = get_per_layer_graph( - layer, - validation_shapes, - refer_input_values, - is_shape_io=False, - ) - refer_graph._io_buffer_mapping = self.full_graph._io_buffer_mapping - for proxy_output, output in output_mapping.items(): - validation_shapes[proxy_output] = validation_shapes[output] - logger.debug(f"run refer graph for layer {layer.name}") - refer_runner = refer_graph.get_runner( - validation_shapes, - self.values, - timing_cache=self.timing_cache, - opt_level=0, - ) - refer_outputs = refer_runner.run() - for name, refer_output in refer_outputs.items(): - if name in output_mapping: - refer_output = refer_output.bool() - output = values[name] - # ∣output−refer_output∣ <= atol+rtol*∣refer_output∣ - atol = 1e-02 - rtol = 1e-02 - if not torch.allclose( - output, - refer_output, - rtol=rtol, - atol=atol, - equal_nan=True, - ): - size = output.nelement() - diff = (output - refer_output).abs() - diff_index = (~torch.isnan(diff)) & (diff > ( - atol + rtol * refer_output.abs())) - diff_output = diff[diff_index] - diff_size = diff_output.nelement() - logger.warning( - f"output {name} of {layer.name} is not accurate after parallelization. " - f"{diff_size} out of {size} elements ({diff_size / size * 100:.2f}%) are not close. " - f"max: {diff_output.max():.5f}, mean: {diff_output.float().mean():.5f}, std: {diff_output.float().std():.5f}. " - f"mean of reference: {refer_output.float().mean():.5f}, mean of output: {output.float().mean():.5f}." 
- ) - for name in extra_input_values.keys(): - del self.values[name] - - def add_output(self, tensor: Tensor, device_ids, - strategy: ShardingStrategy): - trt_output = tensor.as_trt() - comm_action_sequence = strategy.best_resharding_cost[0][0][1] - for commspec in comm_action_sequence: - self.add_comm(self.global_context, tensor.name, device_ids, - commspec) - self.add_layer_singleton(trt_output, device_ids, - strategy.sharding_specs["input0"]) - if trt_output.is_shape_tensor: - output = self.prefixed_graph.add_output_shape(trt_output) - else: - output = self.prefixed_graph.add_output(trt_output) - trt_output.dtype = tensor.dtype - output.attrs["strategy"] = strategy.name - - def assign_shapes(self, shape_info: ShapeInfo): - if self.validation_mode: - shapes = { - f"ref_{name}": shape - for name, shape in shape_info.shapes.items() - } - values = { - f"ref_{name}": value - for name, value in shape_info.values.items() - } - self.shapes.update(shapes) - self.values.update(values) - shape_layers = get_shape_layers(self.prefixed_graph.as_trt()) - shape_info = ShapeInfo(self.shapes, self.values, shape_layers) - self.prefixed_graph.assign_shapes(shape_info) - - -def parallelize( - simplifier: Simplifier, - config: ParallelConfig, -): - auto_parallel_config = simplifier.config - debug_mode = auto_parallel_config.debug_mode - dump_path = auto_parallel_config.dump_path - debug_outputs = auto_parallel_config.debug_outputs - - simplifier.infer_shapes(config.graph_config.num_micro_batches) - network = simplifier.network - graph = simplifier.graph - phy_mesh = config.graph_config.phy_mesh - # TODO: test device_ids = [[0]] - device_ids = phy_mesh.phy_devices_id - stage_phy_meshes = config.graph_config.stage_phy_meshes - block_to_stage = config.graph_config.graph_mapping.block_to_stage - graph_strategy = config.graph_strategy - desimplify_strategy( - graph, - graph_strategy, - config.graph_config.graph_mapping, - ) - graph._plugin_config = simplifier.llm_network.plugin_config - graph_group = GraphGroup.from_graph(graph, config, auto_parallel_config) - - if not debug_mode: - init_all_reduce_helper() - tp_size = phy_mesh.size // config.graph_config.num_stages - shape = (CustomAllReduceHelper.POINTERS_PER_RANK * tp_size + - CustomAllReduceHelper.POINTERS_OF_COUNTER, ) - workspace = graph.as_trt().add_input( - name="all_reduce_workspace", - dtype=trt.int64, - shape=shape, - ) - tensor = graph.register_input(workspace) - tensor.shape = shape - graph_strategy["all_reduce_workspace"] = ShardingStrategy( - sharding_specs={ - "output0": - ShardingSpec( - device_mesh=phy_mesh.as_logical_mesh(), - data_type_size=tensor.dtype_str_size, - data_shape=shape, - max_data_shape=shape, - raw_data_shape=shape, - dim_partition_dict={}, - ) - }) - - if dump_path is not None: - lock = FileLock(f"{dump_path}/path.lock", thread_local=False) - with lock: - with open(f'{dump_path}/sharded_graph.log', 'w+') as file: - config.print_graph_strategy(file) - - for input in graph.inputs: - graph_group.add_input(input, device_ids, graph_strategy[input.name]) - for block in simplifier.blocks: - stage_id = block_to_stage[block.block_id] - stage_phy_mesh = stage_phy_meshes[stage_id] - stage_device_ids = stage_phy_mesh.phy_devices_id.reshape( - config.lmesh.mesh_shape) - for i in block.sorted_layer_ids: - layer = graph.get_layer(network.get_layer(i).name) - layer.attrs["block_id"] = block.block_id - graph_group.add_layer( - layer, - stage_device_ids, - graph_strategy[layer.name], - ) - for output in graph.outputs: - 
graph_group.add_output(output, device_ids, graph_strategy[output.name]) - - if debug_mode: - new_graph = graph_group.prefixed_graph - debug_outputs = debug_outputs or [] - if isinstance(debug_outputs, str): - if debug_outputs == 'validation': - debug_outputs = [] - for tensor in new_graph.tensors: - if tensor.name.startswith('ref_'): - original_name = tensor.name[4:] - original_tensor = new_graph.get_tensor(original_name) - if original_tensor is not None: - if not original_tensor.is_graph_io: - debug_outputs.append(tensor.name) - debug_outputs.append(original_name) - if original_tensor.is_graph_output: - debug_outputs.append(tensor.name) - else: - pattern = debug_outputs - debug_outputs = [] - for tensor in new_graph.tensors: - if tensor.as_trt().is_shape_tensor: - continue - if tensor.producer is not None: - layer = tensor.producer - if layer.type == trt.LayerType.SHAPE: - continue - if re.match(pattern, tensor.name): - debug_outputs.append(tensor.name) - for output_name in debug_outputs: - trt_output = new_graph.get_tensor(output_name).as_trt() - if trt_output.is_shape_tensor: - output = new_graph.add_output_shape(trt_output) - else: - output = new_graph.add_output(trt_output) - graph_group.assign_shapes(simplifier.shape_info) - if dump_path is not None: - with lock: - new_graph.to_dot( - f'{dump_path}/sharded_graph.dot', - per_device=True, - per_block=True, - # ignore_shape_io=True, - extra_attrs=['strategy'], - ) - return [new_graph] - else: - graphs = [] - for device_id in np.nditer(device_ids): - device_id = device_id.item() - graph = graph_group.graphs[device_id] - graphs.append(graph) - return graphs diff --git a/tensorrt_llm/auto_parallel/pipeline_graph.py b/tensorrt_llm/auto_parallel/pipeline_graph.py deleted file mode 100644 index 80b4c41e76..0000000000 --- a/tensorrt_llm/auto_parallel/pipeline_graph.py +++ /dev/null @@ -1,1035 +0,0 @@ -from dataclasses import dataclass -from typing import Dict, List, Optional - -import numpy as np -import tensorrt as trt -import torch - -from tensorrt_llm._utils import trt_dtype_to_str, trt_dtype_to_torch -from tensorrt_llm.logger import logger -from tensorrt_llm.network import Network, get_plugin_info, set_plugin_info -from tensorrt_llm.plugin.plugin import PluginConfig -from tensorrt_llm.runtime.session import Session - -from .utils import (current_flags, get_builder_flags, get_sorted_layer_ids, - get_strongly_typed, get_trt_network, set_trt_network, - to_base_class_layer, to_subclass_layer) - - -class Tensor: - - def __init__(self, graph: "PipelineGraph"): - self._graph = graph - self._trt = None - self._shape = None - self._max_shape = None - self._value = None - self.producer: Layer = None - self.output_index = None - self.consumers = [] - self.graph_input_index = -1 - self.graph_output_index = -1 - self.attrs = {} - - @staticmethod - def from_trt(graph: "PipelineGraph", trt_tensor: trt.ITensor): - tensor = Tensor(graph) - tensor._trt = trt_tensor - return tensor - - def as_trt(self) -> trt.ITensor: - return self._trt - - def copy(self) -> "Tensor": - tensor = Tensor(self._graph) - tensor._trt = self._trt - tensor._shape = self._shape - tensor._max_shape = self._max_shape - tensor._value = self._value - tensor.producer = self.producer - tensor.output_index = self.output_index - tensor.consumers = [*self.consumers] - tensor.graph_input_index = self.graph_input_index - tensor.graph_output_index = self.graph_output_index - tensor.attrs = self.attrs.copy() - return tensor - - @property - def graph(self) -> "PipelineGraph": - return self._graph 
-
-    @property
-    def name(self) -> str:
-        return self._trt.name
-
-    @name.setter
-    def name(self, name: str):
-        old_name = self._trt.name
-        if name != old_name:
-            self._trt.name = name
-            self.graph._tensors[name] = self
-            del self.graph._tensors[old_name]
-            if self.is_graph_input:
-                self.graph._inputs[name] = self
-                del self.graph._inputs[old_name]
-            elif self.is_graph_output:
-                self.graph._outputs[name] = self
-                del self.graph._outputs[old_name]
-
-    @property
-    def shape(self):
-        return self._shape
-
-    @property
-    def max_shape(self):
-        return self._max_shape
-
-    @property
-    def raw_shape(self):
-        assert isinstance(self._trt, trt.ITensor)
-        return self._trt.shape
-
-    @shape.setter
-    def shape(self, shape):
-        self._shape = shape
-
-    @max_shape.setter
-    def max_shape(self, max_shape):
-        self._max_shape = max_shape
-
-    @raw_shape.setter
-    def raw_shape(self, raw_shape):
-        assert isinstance(self._trt, trt.ITensor)
-        self._trt.shape = raw_shape
-
-    @property
-    def value(self):
-        return self._value
-
-    @value.setter
-    def value(self, value):
-        self._value = value
-
-    @property
-    def dtype(self):
-        return self._trt.dtype
-
-    @property
-    def broadcast_across_batch(self):
-        return self._trt.broadcast_across_batch
-
-    @property
-    def dtype_size(self):
-        return self.dtype.itemsize
-
-    @property
-    def dtype_str(self):
-        return trt_dtype_to_str(self.dtype)
-
-    @property
-    def dtype_str_size(self):
-        return [trt_dtype_to_str(self.dtype), self.dtype.itemsize]
-
-    @property
-    def is_graph_input(self) -> bool:
-        return self.graph_input_index != -1
-
-    @property
-    def is_graph_output(self) -> bool:
-        return self.graph_output_index != -1
-
-    @property
-    def is_graph_io(self) -> bool:
-        return self.is_graph_input or self.is_graph_output
-
-
-class Layer:
-
-    def __init__(self, graph):
-        self._graph = graph
-        self._trt = None
-        self._index = None
-        self._inputs = []
-        self._outputs = []
-        self._is_shape_io = False
-        self.attrs = {}
-
-    @staticmethod
-    def from_trt(graph, trt_layer, index):
-        layer = Layer(graph)
-        layer._trt = trt_layer
-        layer._index = index
-        for i in range(trt_layer.num_inputs):
-            input = trt_layer.get_input(i)
-            if input is not None:
-                layer._inputs.append(graph.get_tensor(input.name))
-                layer._inputs[i].consumers.append((layer, i))
-            else:
-                layer._inputs.append(None)
-        for i in range(trt_layer.num_outputs):
-            output = trt_layer.get_output(i)
-            layer._outputs.append(graph.get_tensor(output.name))
-            layer._outputs[i].producer = layer
-            layer._outputs[i].output_index = i
-        set_trt_network(trt_layer, graph.as_trt())
-        return layer
-
-    def as_trt(self) -> trt.ILayer:
-        return self._trt
-
-    @property
-    def graph(self) -> "PipelineGraph":
-        return self._graph
-
-    @property
-    def name(self) -> str:
-        return self._trt.name
-
-    @name.setter
-    def name(self, name: str):
-        old_name = self._trt.name
-        if name != old_name:
-            self._trt.name = name
-            self.graph._layers[name] = self
-            del self.graph._layers[old_name]
-
-    @property
-    def type(self) -> trt.LayerType:
-        return self._trt.type
-
-    @property
-    def index(self) -> int:
-        return self._index
-
-    @property
-    def inputs(self) -> List[Tensor]:
-        return self._inputs
-
-    @property
-    def outputs(self) -> List[Tensor]:
-        return self._outputs
-
-    def get_input(self, index: int) -> Tensor:
-        return self._inputs[index]
-
-    def get_output(self, index: int) -> Tensor:
-        return self._outputs[index]
-
-    @property
-    def num_inputs(self) -> int:
-        return self._trt.num_inputs
-
-    @property
-    def num_outputs(self) -> int:
-        return self._trt.num_outputs
-
-    @property
-    def is_shape_io(self) -> bool:
-        return self._is_shape_io
-
-    def to_subclass(self):
-        to_subclass_layer(self._trt)
-
-    def to_base_class(self):
-        to_base_class_layer(self._trt)
-
-    def assign_shapes(self, shapes, values):
-        for output in self.outputs:
-            output.shape = shapes[output.name]
-            output.value = values.get(output.name)
-
-
-@dataclass
-class GraphRunner:
-    session: Session
-    inputs: Dict[str, torch.Tensor]
-    outputs: Dict[str, torch.Tensor]
-    stream: torch.Stream
-
-    def run(self):
-        cuda_stream = self.stream.cuda_stream
-        assert self.session.run(self.inputs, self.outputs, cuda_stream)
-        self.stream.synchronize()
-        return self.outputs
-
-
-class PipelineGraph:
-
-    def __init__(self):
-        self._trt = None
-        self._inputs: Dict[str, Tensor] = {}
-        self._outputs: Dict[str, Tensor] = {}
-        self._layers: Dict[str, Layer] = {}
-        self._tensors: Dict[str, Tensor] = {}
-        self._io_buffer_mapping = {}
-        self._unfilled_weights = {}
-        self._auto_parallel_config = None
-        self._plugin_config: PluginConfig = None
-
-    @staticmethod
-    def create_graph():
-        graph = PipelineGraph()
-        trt_builder = trt.Builder(logger.trt_logger)
-        explicit_batch_flag = 0
-        # Explicit batch flag will be deprecated in TRT 10
-        if "EXPLICIT_BATCH" in trt.NetworkDefinitionCreationFlag.__members__.keys(
-        ):
-            explicit_batch_flag = 1 << int(
-                trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
-        if get_strongly_typed():
-            network = trt_builder.create_network(
-                explicit_batch_flag
-                | (1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)))
-        else:
-            network = trt_builder.create_network(explicit_batch_flag)
-        graph._trt = network
-        return graph
-
-    def _register_unfilled_weights(self, layer_name, weights, values):
-        self._unfilled_weights[layer_name] = (weights, values)
-
-    def _add_tensor(self, tensor, old_tensor, prefix):
-        if prefix is not None:
-            tensor.name = prefix + old_tensor.name
-        else:
-            tensor.name = old_tensor.name
-        tensor.location = old_tensor.location
-        if old_tensor.dynamic_range is not None:
-            tensor.dynamic_range = old_tensor.dynamic_range
-        if tensor.is_network_input:
-            tensor.shape = old_tensor.shape
-            for i in range(len(old_tensor.shape)):
-                name = old_tensor.get_dimension_name(i)
-                if name is not None:
-                    tensor.set_dimension_name(i, name)
-        return self._register_tensor(tensor)
-
-    def _register_tensor(self, tensor):
-        wrapped_tensor = Tensor.from_trt(self, tensor)
-        assert tensor.name not in self._tensors
-        self._tensors[tensor.name] = wrapped_tensor
-        return wrapped_tensor
-
-    def add_input(self, tensor, prefix=None):
-        tensor_name = tensor.name
-        if prefix is not None:
-            tensor_name = prefix + tensor_name
-        input = self._trt.add_input(tensor_name, tensor.dtype, tensor.shape)
-        new_tensor = self._add_tensor(input, tensor, prefix)
-        new_tensor.graph_input_index = len(self._inputs)
-        self._inputs[tensor_name] = new_tensor
-        return new_tensor
-
-    def register_input(self, tensor, index=None):
-        if index is None:
-            index = self.num_inputs - 1
-        assert self._trt.get_input(index).name == tensor.name
-        wrapped_input = self._register_tensor(tensor)
-        wrapped_input.graph_input_index = index
-        self._inputs[tensor.name] = wrapped_input
-        return wrapped_input
-
-    def add_output(self, tensor, prefix=None):
-        tensor_name = tensor.name
-        if prefix is not None:
-            tensor_name = prefix + tensor_name
-        output = self.get_tensor(tensor_name)
-        output.graph_output_index = len(self._outputs)
-        trt_output = output.as_trt()
-        self._trt.mark_output(trt_output)
-        trt_output.dtype = tensor.dtype
-
self._outputs[tensor_name] = output - return output - - def add_output_shape(self, tensor, prefix=None): - tensor_name = tensor.name - if prefix is not None: - tensor_name = prefix + tensor_name - output = self.get_tensor(tensor_name) - trt_output = output.as_trt() - self._trt.mark_output_for_shapes(trt_output) - trt_output.dtype = tensor.dtype - self._outputs[tensor_name] = output - return output - - def add_layer( - self, - layer, - input_mapping=None, - prefix=None, - updated_attrs=None, - ) -> Layer: - - def get_input(i): - name = layer.get_input(i).name - if prefix is not None: - name = prefix + name - if input_mapping is not None and name in input_mapping: - name = input_mapping[name] - return self.get_tensor(name).as_trt() - - network = self._trt - layer_type = layer.type - to_subclass_layer(layer) - if layer_type == trt.LayerType.ACTIVATION: - trt_input = get_input(0) - new_layer = network.add_activation(trt_input, layer.type) - new_layer.alpha = layer.alpha - new_layer.beta = layer.beta - elif layer_type == trt.LayerType.CONCATENATION: - trt_inputs = [get_input(i) for i in range(layer.num_inputs)] - new_layer = network.add_concatenation(trt_inputs) - new_layer.axis = layer.axis - elif layer_type == trt.LayerType.CONSTANT: - new_layer = network.add_constant(layer.shape, layer.weights) - elif layer_type == trt.LayerType.ELEMENTWISE: - new_layer = network.add_elementwise(get_input(0), get_input(1), - layer.op) - elif layer_type == trt.LayerType.FILL: - if layer.num_inputs >= 1 and layer.get_input(0) is not None: - shape_input = get_input(0) - shape = [1] - else: - shape_input = None - shape = layer.shape - new_layer = network.add_fill(shape, layer.operation, layer.to_type) - if shape_input is not None: - new_layer.set_input(0, shape_input) - if layer.num_inputs >= 1 and layer.get_input(0) is not None: - new_layer.set_input(0, get_input(0)) - if layer.num_inputs >= 2 and layer.get_input(1) is not None: - new_layer.set_input(1, get_input(1)) - else: - new_layer.alpha = layer.alpha - if layer.num_inputs >= 3 and layer.get_input(2) is not None: - new_layer.set_input(2, get_input(2)) - else: - new_layer.beta = layer.beta - elif layer_type == trt.LayerType.GATHER: - trt_input = get_input(0) - trt_indices = get_input(1) - new_layer = network.add_gather_v2(trt_input, trt_indices, - layer.mode) - new_layer.axis = layer.axis - new_layer.num_elementwise_dims = layer.num_elementwise_dims - new_layer.mode = layer.mode - elif layer_type == trt.LayerType.MATRIX_MULTIPLY: - new_layer = network.add_matrix_multiply(get_input(0), layer.op0, - get_input(1), layer.op1) - elif layer_type == trt.LayerType.REDUCE: - new_layer = network.add_reduce(get_input(0), layer.op, layer.axes, - layer.keep_dims) - elif layer_type == trt.LayerType.SELECT: - trt_condition = get_input(0) - trt_then = get_input(1) - trt_else = get_input(2) - new_layer = network.add_select(trt_condition, trt_then, trt_else) - elif layer_type == trt.LayerType.SHUFFLE: - new_layer = network.add_shuffle(get_input(0)) - new_layer.first_transpose = layer.first_transpose - new_layer.second_transpose = layer.second_transpose - new_layer.zero_is_placeholder = layer.zero_is_placeholder - if layer.num_inputs >= 2: - trt_reshape_dims_tensor = get_input(1) - new_layer.set_input(1, trt_reshape_dims_tensor) - else: - new_layer.reshape_dims = layer.reshape_dims - elif layer_type == trt.LayerType.SLICE: - if layer.num_inputs >= 2 and layer.get_input(1) is not None: - trt_start = get_input(1) - start = [] - else: - trt_start = None - start = layer.start - 
if layer.num_inputs >= 3 and layer.get_input(2) is not None: - trt_shape = get_input(2) - shape = [] - else: - trt_shape = None - shape = layer.shape - if layer.num_inputs >= 4 and layer.get_input(3) is not None: - trt_stride = get_input(3) - stride = [] - else: - trt_stride = None - stride = layer.stride - new_layer = network.add_slice(get_input(0), start, shape, stride) - new_layer.mode = layer.mode - if trt_start is not None: - new_layer.set_input(1, trt_start) - if trt_shape is not None: - new_layer.set_input(2, trt_shape) - if trt_stride is not None: - new_layer.set_input(3, trt_stride) - elif layer_type == trt.LayerType.SOFTMAX: - new_layer = network.add_softmax(get_input(0)) - new_layer.axes = layer.axes - elif layer_type == trt.LayerType.UNARY: - new_layer = network.add_unary(get_input(0), layer.op) - elif layer_type == trt.LayerType.SHAPE: - new_layer = network.add_shape(get_input(0)) - elif layer_type == trt.LayerType.ASSERTION: - new_layer = network.add_assertion(get_input(0), layer.message) - elif layer_type == trt.LayerType.CAST: - new_layer = network.add_cast(get_input(0), layer.to_type) - elif layer_type == trt.LayerType.NORMALIZATION: - trt_input = get_input(0) - trt_scale = get_input(1) - trt_bias = get_input(2) - new_layer = network.add_normalization(trt_input, trt_scale, - trt_bias, layer.axes) - new_layer.epsilon = layer.epsilon - new_layer.num_groups = layer.num_groups - new_layer.compute_precision = layer.compute_precision - elif layer_type == trt.LayerType.IDENTITY: - new_layer = network.add_identity(get_input(0)) - elif layer_type == trt.LayerType.PLUGIN_V2: - plugin = layer.plugin - updated = False - if (updated_attrs is not None - and updated_attrs.get("plugin") is not None): - plugin = updated_attrs["plugin"] - updated = True - updated_attrs = None - new_layer = network.add_plugin_v2( - [get_input(i) for i in range(layer.num_inputs)], - plugin, - ) - else: - raise NotImplementedError( - "Unsupported layer type: {}".format(layer_type)) - - if updated_attrs is not None: - for attr_name, attr_value in updated_attrs.items(): - setattr(new_layer, attr_name, attr_value) - - to_base_class_layer(layer) - to_base_class_layer(new_layer) - layer_index = network.num_layers - 1 - layer_name = layer.name - if prefix is not None: - layer_name = prefix + layer_name - new_layer.name = layer_name - new_layer.metadata = new_layer.name - if layer.precision_is_set: - new_layer.precision = layer.precision - for i in range(layer.num_outputs): - if layer.output_type_is_set(i): - new_layer.set_output_type(i, layer.get_output_type(i)) - output = new_layer.get_output(i) - self._add_tensor(output, layer.get_output(i), prefix) - wrapped_layer = Layer.from_trt(self, new_layer, layer_index) - assert layer_name not in self._layers - self._layers[layer_name] = wrapped_layer - if layer_type == trt.LayerType.PLUGIN_V2: - if not updated: - plugin_info = get_plugin_info(get_trt_network(layer), - layer.name) - set_plugin_info(self.as_trt(), new_layer.name, plugin_info) - return wrapped_layer - - def register_layer(self, layer, index=None): - if index is None: - index = self.num_layers - 1 - assert self._trt.get_layer(index).name == layer.name - to_base_class_layer(layer) - for i in range(layer.num_outputs): - output = layer.get_output(i) - self._register_tensor(output) - wrapped_layer = Layer.from_trt(self, layer, index) - assert layer.name not in self._layers - self._layers[layer.name] = wrapped_layer - to_subclass_layer(layer) - return wrapped_layer - - def get_runner( - self, - shapes=None, - 
values=None, - profile=None, - timing_cache=None, - opt_level=None, - ) -> GraphRunner: - shapes = shapes or {} - values = values or {} - inputs = {} - outputs = {} - for input in self.inputs: - if input is not None: - value = values.get(input.name) - if value is None: - value = input.value - if value is not None: - if not isinstance(value, torch.Tensor): - value = torch.tensor( - value, - dtype=trt_dtype_to_torch(input.dtype), - device='cpu', - ) - inputs[input.name] = value - else: - shape = shapes.get(input.name) - if shape is None: - shape = input.shape - assert shape is not None - inputs[input.name] = torch.empty( - tuple(shape), - dtype=trt_dtype_to_torch(input.dtype), - device=torch.cuda.current_device(), - ) - if torch.is_floating_point(inputs[input.name]): - inputs[input.name].normal_() - # inputs[input.name][:] = random.choice([2, 3, 5, 7]) - for output in self.outputs: - if output.as_trt().is_shape_tensor: - continue - if output.name in self._io_buffer_mapping: - input_name = self._io_buffer_mapping[output.name] - if input_name in inputs: - outputs[output.name] = inputs[input_name] - continue - value = values.get(output.name) - if value is not None and isinstance(value, torch.Tensor): - outputs[output.name] = value - else: - shape = shapes.get(output.name) - if shape is None: - shape = output.shape - assert shape is not None - outputs[output.name] = torch.empty( - tuple(shape), - dtype=trt_dtype_to_torch(output.dtype), - device=torch.cuda.current_device(), - ) - network = self.as_trt() - config = network.builder.create_builder_config() - if opt_level is not None: - config.builder_optimization_level = opt_level - config.flags = get_builder_flags() - profile = profile or network.builder.create_optimization_profile() - profile_index = config.add_optimization_profile(profile) - if timing_cache is not None: - config.set_timing_cache(timing_cache, ignore_mismatch=False) - plan = network.builder.build_serialized_network(network, config) - if plan is None: - logger.error('Engine building failed, please check the error log.') - session = Session.from_serialized_engine(plan) - stream = torch.cuda.current_stream() - cuda_stream = stream.cuda_stream - context = session.context - context.set_optimization_profile_async(profile_index, cuda_stream) - runner = GraphRunner(session, inputs, outputs, stream) - return runner - - def run( - self, - shapes=None, - values=None, - profile=None, - timing_cache=None, - opt_level=None, - ): - return self.get_runner( - shapes, - values, - profile, - timing_cache, - opt_level, - ).run() - - def duplicate_graph(self): - graph = PipelineGraph.create_graph() - network = self.as_trt() - for i in range(network.num_inputs): - input = network.get_input(i) - graph.add_input(input) - sorted_layer_ids = get_sorted_layer_ids(network) - for i in sorted_layer_ids: - layer = network.get_layer(i) - graph.add_layer(layer) - for i in range(network.num_outputs): - output = network.get_output(i) - if output.is_shape_tensor: - graph.add_output_shape(output) - else: - graph.add_output(output) - return graph - - @staticmethod - def from_trt(trt_network): - graph = PipelineGraph() - graph._trt = trt_network - - # construct inputs and tensors - for i in range(trt_network.num_inputs): - trt_input = trt_network.get_input(i) - tensor = Tensor.from_trt(graph, trt_input) - tensor.graph_input_index = i - graph._tensors[tensor.name] = tensor - graph._inputs[tensor.name] = tensor - for i in range(trt_network.num_layers): - trt_layer = trt_network.get_layer(i) - for i in 
range(trt_layer.num_outputs): - trt_output = trt_layer.get_output(i) - tensor = Tensor.from_trt(graph, trt_output) - graph._tensors[tensor.name] = tensor - - # construct layers and outputs - for i in range(trt_network.num_layers): - layer = Layer.from_trt(graph, trt_network.get_layer(i), i) - graph._layers[layer.name] = layer - for i in range(trt_network.num_outputs): - tensor_name = trt_network.get_output(i).name - output_tensor = graph._tensors[tensor_name] - output_tensor.graph_output_index = i - graph._outputs[tensor_name] = output_tensor - - return graph - - @staticmethod - def from_network(network: Network, builder_config): - builder_flags = builder_config.trt_builder_config.flags - with current_flags(builder_flags, network.strongly_typed): - graph = PipelineGraph.from_trt(network.trt_network) - graph.infer_shapes(network._generate_optimization_profiles()[-1]) - return graph - - def assign_shapes(self, shape_info=None, is_partial=False): - if shape_info is None: - for tensor in self.tensors: - tensor.shape = tensor.raw_shape - return - for tensor in self.tensors: - if tensor.name in shape_info.shapes: - tensor.shape = shape_info.shapes[tensor.name] - elif not is_partial: - raise ValueError(f"Cannot find shape for tensor: {tensor.name}") - if shape_info.max_shapes is not None: - if tensor.name in shape_info.max_shapes: - tensor.max_shape = shape_info.max_shapes[tensor.name] - elif not is_partial: - raise ValueError( - f"Cannot find max shape for tensor: {tensor.name}") - if tensor.name in shape_info.values: - tensor.value = shape_info.values[tensor.name] - for layer in self.layers: - if layer.name in shape_info.shape_layers: - layer._is_shape_io = True - - def infer_shapes(self, profile=None): - from .shape_info import get_shape_info - - shape_info = get_shape_info(self._trt, profile) - self.assign_shapes(shape_info) - - def as_trt(self) -> trt.INetworkDefinition: - return self._trt - - def get_input(self, name: str) -> Tensor: - return self._inputs.get(name) - - def is_input(self, name: str) -> bool: - return name in self._inputs - - @property - def inputs(self) -> List[Tensor]: - return [*self._inputs.values()] - - @property - def num_inputs(self) -> int: - return self._trt.num_inputs - - def get_output(self, name: str) -> Tensor: - return self._outputs.get(name) - - def is_output(self, name: str) -> bool: - return name in self._outputs - - @property - def outputs(self) -> List[Tensor]: - return [*self._outputs.values()] - - @property - def num_outputs(self) -> int: - return self._trt.num_outputs - - def get_tensor(self, name: str) -> Tensor: - return self._tensors.get(name) - - @property - def tensors(self) -> List[Tensor]: - return [*self._tensors.values()] - - def get_layer(self, name: str) -> Layer: - return self._layers.get(name) - - @property - def layers(self) -> List[Layer]: - return [*self._layers.values()] - - @property - def sorted_layers(self) -> List[Layer]: - sorted_layer_ids = get_sorted_layer_ids(self.as_trt()) - return [ - self.get_layer(self.as_trt().get_layer(layer_id).name) - for layer_id in sorted_layer_ids - ] - - @property - def num_layers(self) -> int: - return self._trt.num_layers - - def to_dot(self, - path=None, - per_device=False, - per_block=False, - ignore_shape_io=False, - no_style=False, - extra_attrs=None) -> Optional[str]: - ''' - Get a graphviz representation of the graph. 
- - Parameters: - path: the path to save the graphviz file, if not provided, will return the graphviz source code - ''' - try: - import graphviz - except ImportError: - logger.error( - "Failed to import graphviz, please install graphviz to enable PipelineGraph.to_dot()" - ) - return - - extra_attrs = extra_attrs or [] - - graph = graphviz.Digraph() - input_block_graph = graphviz.Digraph(name='cluster_inputs') - output_block_graph = graphviz.Digraph(name='cluster_outputs') - device_graphs = {} - block_graphs = {} - block_graph_mapping = [] - tensor_names = set() - layer_names = set() - - common_style = dict(fontname='Arial', ) - node_style = dict( - **common_style, - style='rounded,filled,bold', - ) - tensor_style = dict( - **node_style, - shape='ellipse', - fillcolor='white', - ) - input_tensor_style = {**tensor_style, 'fillcolor': 'green'} - output_tensor_style = {**tensor_style, 'fillcolor': 'lightgreen'} - layer_style = dict( - **node_style, - shape='box', - fillcolor='white', - ) - shape_layer_style = {**layer_style, 'fillcolor': 'grey'} - helper_layer_style = {**layer_style, 'fillcolor': 'lightgrey'} - graph_style = dict( - **common_style, - style='rounded', - penwidth='5', - fontsize='28', - ) - device_graph_style = dict( - **graph_style, - color='cornflowerblue', - ) - block_graph_style = dict( - **graph_style, - color='darkcyan', - ) - input_block_style = dict( - **graph_style, - color='green', - ) - output_block_style = dict( - **graph_style, - color='lightgreen', - ) - if no_style: - device_graph_style = {} - block_graph_style = {} - input_block_style = {} - output_block_style = {} - input_block_graph.attr(label='inputs', **input_block_style) - output_block_graph.attr(label='outputs', **output_block_style) - - def get_tensor_labels(tensor): - labels = [] - if tensor.value is not None: - labels.append(f"value={tensor.value}") - else: - labels.append(f"dtype={tensor.dtype.name}{tensor.shape}") - for attr_name in extra_attrs: - if attr_name in tensor.attrs: - labels.append(f"{attr_name}={tensor.attrs[attr_name]}") - return labels - - def get_device_graph(name): - if per_device and name.startswith('device'): - device_name = name.split('_')[0] - if device_name not in device_graphs: - device_graph = graphviz.Digraph(name='cluster_' + - device_name) - device_graph.attr(label=device_name, **device_graph_style) - device_graphs[device_name] = device_graph - return device_graphs[device_name] - return None - - def get_block_graph(layer, current_graph): - if per_block and 'block_id' in layer.attrs: - block_label = f"block{layer.attrs['block_id']}" - if current_graph.name is not None: - graph_label = current_graph.name[len('cluster_'):] - else: - graph_label = '' - block_name = f"{graph_label}{block_label}" - if block_name not in block_graphs: - block_graph = graphviz.Digraph(name='cluster_' + block_name) - block_graph.attr(label=block_label, **block_graph_style) - block_graphs[block_name] = block_graph - block_graph_mapping.append((current_graph, block_graph)) - return block_graphs[block_name] - return current_graph - - for name, tensor in self._tensors.items(): - style = tensor_style - if tensor.is_graph_input: - style = input_tensor_style - current_graph = input_block_graph - elif tensor.is_graph_output: - style = output_tensor_style - current_graph = output_block_graph - elif tensor.producer.num_outputs == 1: - continue - else: - current_graph = get_device_graph(name) or graph - current_graph = get_block_graph(tensor.producer, current_graph) - if no_style: - style = {} - labels = [name, 
*get_tensor_labels(tensor)] - content = "\n".join(labels) - current_graph.node(name, content, **style) - tensor_names.add(name) - - for layer in self.sorted_layers: - name = layer.name - - style = layer_style - if layer.is_shape_io: - if ignore_shape_io: - continue - style = shape_layer_style - elif layer.attrs.get("role", None) == "helper": - style = helper_layer_style - fillcolor = None - plugin_type = None - if layer.type == trt.LayerType.PLUGIN_V2: - fillcolor = 'yellow' - layer.to_subclass() - plugin_type = layer.as_trt().plugin.plugin_type - layer.to_base_class() - if layer.type == trt.LayerType.MATRIX_MULTIPLY or plugin_type == 'Gemm': - fillcolor = 'orange' - if fillcolor is not None: - style = {**style, 'fillcolor': fillcolor} - if no_style: - style = {} - - layer_attrs = {} - layer_type = layer.type - layer.to_subclass() - if layer_type == trt.LayerType.CONSTANT: - if not layer.is_shape_io: - if trt.volume(layer.get_output(0).shape) <= 8: - weights = layer.as_trt().weights - if isinstance(weights, trt.Weights): - weights = weights.numpy() - value = np.array2string( - weights, - formatter={'float_kind': lambda x: f"{x:.2e}"}) - layer_attrs['value'] = value - elif layer_type == trt.LayerType.SHUFFLE: - for attr_name in ['first_transpose', 'second_transpose']: - attr_value = getattr(layer.as_trt(), attr_name) - if tuple(attr_value) != (0, 1, 2, 3, 4, 5, 6, 7): - tensor = layer.get_input( - 0 - ) if attr_name == 'first_transpose' else layer.get_output( - 0) - layer_attrs[attr_name] = tuple( - attr_value)[:len(tensor.shape)] - if layer.num_inputs < 2: - attr_value = layer.as_trt().reshape_dims - layer_attrs['reshape_dims'] = attr_value - elif layer_type == trt.LayerType.SLICE: - if layer.num_inputs < 2 or layer.get_input(1) is None: - layer_attrs['start'] = layer.as_trt().start - if layer.num_inputs < 4 or layer.get_input(3) is None: - attr_value = layer.as_trt().stride - if attr_value != tuple([1] * - len(layer.get_output(0).shape)): - layer_attrs['stride'] = attr_value - layer.to_base_class() - - if layer.is_shape_io: - labels = [layer.type.name] - else: - labels = [name, layer.type.name] - for key, value in layer_attrs.items(): - labels.append(f"{key}={value}") - for attr_name in extra_attrs: - if attr_name in layer.attrs: - labels.append(f"{attr_name}={layer.attrs[attr_name]}") - if layer.num_outputs == 1: - output = layer.get_output(0) - if output.name != f'{layer.name}_output_0': - labels.append(f"output={output.name}") - labels.extend(get_tensor_labels(output)) - content = "\n".join(labels) - - current_graph = get_device_graph(name) or graph - current_graph = get_block_graph(layer, current_graph) - current_graph.node(name, content, **style) - layer_names.add(name) - - for index, input in enumerate(layer.inputs): - if input is not None: - if input.is_graph_input or input.producer.num_outputs > 1: - if input.name in tensor_names: - graph.edge(input.name, name, str(index)) - else: - if input.producer.name in layer_names: - graph.edge(input.producer.name, name, str(index)) - if layer.num_outputs > 1 or (layer.num_outputs == 1 and - layer.get_output(0).is_graph_output): - for index, output in enumerate(layer.outputs): - graph.edge(name, output.name, str(index)) - - graph.subgraph(input_block_graph) - graph.subgraph(output_block_graph) - for parent_graph, block_graph in block_graph_mapping: - parent_graph.subgraph(block_graph) - for device_graph in device_graphs.values(): - graph.subgraph(device_graph) - - if not path: - return graph.source - graph.save(path) - - @staticmethod - def 
trt_to_dot(trt_network, path=None): - graph = PipelineGraph.from_trt(trt_network) - graph.assign_shapes() - dot = graph.to_dot(no_style=True) - if path is not None: - with open(path, "w") as f: - f.write(dot) - else: - return dot diff --git a/tensorrt_llm/auto_parallel/runtime_profiling.py b/tensorrt_llm/auto_parallel/runtime_profiling.py deleted file mode 100644 index 8f6c8d9cc4..0000000000 --- a/tensorrt_llm/auto_parallel/runtime_profiling.py +++ /dev/null @@ -1,150 +0,0 @@ -import numpy as np -import tensorrt as trt -import torch - -from tensorrt_llm.logger import logger -from tensorrt_llm.network import get_plugin_info - -from .shape_info import get_per_layer_graph -from .utils import get_cache_key, get_trt_network, get_updated_plugin - - -class NvtxProfiler(object): - - def __init__(self, nvtx_name, enable=True): - self.nvtx_name = nvtx_name - self.enable = enable - - def __enter__(self): - if self.enable: - torch.cuda.nvtx.range_push(self.nvtx_name) - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.enable: - torch.cuda.nvtx.range_pop() - - -class LayerProfiler(trt.IProfiler): - - def __init__(self): - trt.IProfiler.__init__(self) - self.layer_count = 0 - self.time = 0 - - def report_layer_time(self, layer_name, ms): - logger.debug(f'{layer_name=}, {self.layer_count=}, time = {ms} ms') - self.time += ms - self.layer_count += 1 - - -class RuntimeProfiler(object): - - def __init__(self): - self.timing_cache = None - - def _profile(self, layer, layer_attrs, shapes, values, io_buffer_mapping): - is_plugin = layer.type == trt.LayerType.PLUGIN_V2 - if is_plugin and len(layer_attrs) > 0: - plugin_info = get_plugin_info( - get_trt_network(layer), - layer.name, - ) - new_plugin, _ = get_updated_plugin(plugin_info, layer_attrs) - layer_attrs = {"plugin": new_plugin} - graph, output_mapping = get_per_layer_graph(layer, shapes, values, - layer_attrs) - graph._io_buffer_mapping = io_buffer_mapping - network = graph.as_trt() - if network.num_outputs > 0 and np.all([ - network.get_output(i).is_shape_tensor - for i in range(network.num_outputs) - ]): - return 0.0 - for proxy_output, output in output_mapping.items(): - shapes[proxy_output] = shapes[output] - if not self.timing_cache: - self.timing_cache = network.builder.create_builder_config( - ).create_timing_cache(b"") - runner = graph.get_runner( - shapes, - values, - timing_cache=self.timing_cache, - ) - context = runner.session.context - context.profiler = LayerProfiler() - runner.run() - profiler_time_first_run = context.profiler.time - runner.run() - return (context.profiler.time - profiler_time_first_run) * 1000.0 - - def runtime_profile(self, layer, layer_attrs, input_values, strategy, - device_mesh): - logger.debug(f"start to profile layer {layer.name}") - shapes = {} - values = {} - dtypes = {} - trt_layer = layer.as_trt() - - sharding_sequences = () - for i in range(layer.num_inputs): - input = trt_layer.get_input(i) - if input is not None: - shapes[input.name] = strategy.sharding_specs[ - f'input{i}'].get_sharded_shape_per_device() - dtypes[input.name] = input.dtype - sharding_sequences += (str( - strategy.sharding_specs[f"input{i}"].sharding_sequence), ) - if i in input_values: - values[input.name] = input_values[i] - else: - value = layer.get_input(i).value - if value is not None: - values[input.name] = value - else: - sharding_sequences += (None, ) - - for i in range(layer.num_outputs): - output = trt_layer.get_output(i) - if f'output{i}' in strategy.communication_actions: - shapes[output.name] = 
strategy.communication_actions[ - f'output{i}'].sharding_spec.get_sharded_shape_per_device() - else: - shapes[output.name] = strategy.sharding_specs[ - f'output{i}'].get_sharded_shape_per_device() - dtypes[output.name] = output.dtype - sharding_sequences += (str( - strategy.sharding_specs[f"output{i}"].sharding_sequence), ) - data_key = get_cache_key( - trt_layer, - shapes, - values, - dtypes=dtypes, - updated_attrs=layer_attrs, - ) - data_key += (sharding_sequences, ) - elapsed_time = device_mesh.prof_database.query( - device_mesh.cluster_key, - data_key, - ) - if elapsed_time: - logger.debug( - f'runtime profiling cache hit {data_key}: {elapsed_time} us') - return elapsed_time - with NvtxProfiler(f'{layer.name}_{data_key}', enable=True): - elapsed_time = self._profile( - layer.as_trt(), - layer_attrs, - shapes, - values, - layer.graph._io_buffer_mapping, - ) - logger.debug( - f'runtime profiling cache miss {data_key}: {elapsed_time} us') - - device_mesh.prof_database.update( - device_mesh.cluster_key, - data_key, - (elapsed_time, strategy.alpha_beta_cost), - ) - - return elapsed_time diff --git a/tensorrt_llm/auto_parallel/shape_info.py b/tensorrt_llm/auto_parallel/shape_info.py deleted file mode 100644 index d8a561b292..0000000000 --- a/tensorrt_llm/auto_parallel/shape_info.py +++ /dev/null @@ -1,362 +0,0 @@ -from dataclasses import dataclass -from enum import Enum -from typing import Dict, List, Set - -import numpy as np -import tensorrt as trt -import torch - -from tensorrt_llm._common import _is_building -from tensorrt_llm._utils import (trt_dtype_to_np, trt_dtype_to_str, - trt_dtype_to_torch) -from tensorrt_llm.logger import logger - -from .pipeline_graph import PipelineGraph -from .utils import (get_builder_flags, get_cache_key, get_sorted_layer_ids, - set_trt_network, to_base_class_layer, to_subclass_layer, - to_trt_weights) - - -class ShapeType(Enum): - MIN = 0 - OPT = 1 - MAX = 2 - - -_trt_to_type_dict = { - trt.int64: int, - trt.bool: bool, -} - - -def get_shape_layers(trt_network): - shape_layers = set() - for i in range(trt_network.num_layers): - layer = trt_network.get_layer(i) - if (layer.num_inputs > 0 and np.all([ - layer.get_input(j).is_shape_tensor - for j in range(layer.num_inputs) - if layer.get_input(j) is not None - ])) or (layer.num_outputs > 0 and np.all([ - layer.get_output(j).is_shape_tensor - for j in range(layer.num_outputs) - ])): - shape_layers.add(layer.name) - return shape_layers - - -def get_layers_in_shape_network(trt_network, shape_layers, sorted_layer_ids): - layers = set() - shape_tensors = set() - for layer_id in reversed(sorted_layer_ids): - layer = trt_network.get_layer(layer_id) - in_shape_network = False - if layer.name in shape_layers: - in_shape_network = True - else: - for j in range(layer.num_outputs): - output = layer.get_output(j) - if output.name in shape_tensors: - in_shape_network = True - break - if in_shape_network: - layers.add(layer.name) - for j in range(layer.num_inputs): - input = layer.get_input(j) - if input is not None: - shape_tensors.add(input.name) - return layers - - -def get_shape_network(trt_network, - shapes, - values, - sorted_layer_ids, - profile=None, - shape_type: ShapeType = ShapeType.OPT): - shape_layers = get_shape_layers(trt_network) - layers_in_shape_network = get_layers_in_shape_network( - trt_network, shape_layers, sorted_layer_ids) - shape_graph = PipelineGraph.create_graph() - shape_network = shape_graph.as_trt() - shape_builder = shape_network.builder - shape_profile = 
shape_builder.create_optimization_profile() - for i in range(trt_network.num_inputs): - input = trt_network.get_input(i) - shapes[input.name] = input.shape - new_input = shape_graph.add_input(input) - if profile is not None: - if -1 in input.shape: - shape = profile.get_shape(input.name) - shape = shape[shape_type.value] - shapes[input.name] = shape - new_input.raw_shape = shape - if input.is_shape_tensor: - shape_values = profile.get_shape_input(input.name) - value = shape_values[shape_type.value] - values[input.name] = value - shape_profile.set_shape_input(input.name, value, value, value) - output_mapping = {} - for layer_id in sorted_layer_ids: - layer = trt_network.get_layer(layer_id) - if layer.name in shape_layers: - new_layer = shape_graph.add_layer(layer) - for i in range(layer.num_outputs): - output = layer.get_output(i) - if output.dtype == trt.DataType.BOOL: - proxy_layer = shape_network.add_cast( - new_layer.as_trt().get_output(i), - trt.DataType.INT32, - ) - proxy_output = proxy_layer.get_output(0) - shape_graph.register_layer(proxy_layer) - shape_graph.add_output_shape(proxy_output) - output_mapping[proxy_output.name] = (output.name, - output.dtype) - else: - shape_graph.add_output_shape(output) - elif layer.name in layers_in_shape_network: - if layer.type == trt.LayerType.CONSTANT: - shape_graph.add_input(layer.get_output(0)) - else: - shape_graph.add_layer(layer) - return shape_network, shape_profile, shape_layers, output_mapping - - -def get_per_layer_graph( - layer, - shapes, - values, - updated_attrs=None, - is_shape_io: bool = None, -): - graph = PipelineGraph.create_graph() - network = graph.as_trt() - is_shape_layer = layer.num_inputs != 0 - for i in range(layer.num_inputs): - input = layer.get_input(i) - if input is not None: - shape = shapes[input.name] - if (values.get(input.name) is not None - and not isinstance(values[input.name], torch.Tensor)): - value = values[input.name] - weights = np.asarray(value, dtype=trt_dtype_to_np(input.dtype)) - weights = to_trt_weights(weights) - input_layer = network.add_constant(shape, weights) - new_input = input_layer.get_output(0) - new_input.name = input.name - graph.register_layer(input_layer) - elif graph.get_input(input.name) is None: - new_input = graph.add_input(input) - new_input.raw_shape = shapes[input.name] - is_shape_layer = False - new_layer = graph.add_layer( - layer, - updated_attrs=updated_attrs, - ) - output_mapping = {} - if layer.type == trt.LayerType.SHAPE: - is_shape_layer = True - if layer.num_inputs == 0: - is_shape_layer = False - if is_shape_io is not None: - is_shape_layer = is_shape_io - for i in range(layer.num_outputs): - output = layer.get_output(i) - value = values.get(output.name) - if value is not None and isinstance(value, torch.Tensor): - is_output_shape = False - elif is_shape_layer: - is_output_shape = True - else: - is_output_shape = False - if is_output_shape: - if output.dtype == trt.DataType.BOOL: - proxy_layer = network.add_cast( - new_layer.as_trt().get_output(i), - trt.DataType.INT32, - ) - proxy_output = proxy_layer.get_output(0) - graph.register_layer(proxy_layer) - output_mapping[proxy_output.name] = (output.name, output.dtype) - output = proxy_output - graph.add_output_shape(output) - else: - graph.add_output(output) - return graph, output_mapping - - -@_is_building -def infer_shapes(network, shapes, values, profile=None): - if network.num_outputs == 0: - return - builder = network.builder - config = builder.create_builder_config() - config.builder_optimization_level = 0 - 
config.flags = get_builder_flags() - profile = profile or builder.create_optimization_profile() - config.add_optimization_profile(profile) - plan = builder.build_serialized_network(network, config) - if plan is None: - raise RuntimeError( - 'Engine building failed when inferring shapes, please check the error log.' - ) - runtime = trt.Runtime(logger.trt_logger) - engine = runtime.deserialize_cuda_engine(plan) - context = engine.create_execution_context() - for i in range(network.num_inputs): - input = network.get_input(i) - if input.is_shape_tensor: - value = values[input.name] - context.set_shape_input(engine[input.name], value) - for i in range(network.num_outputs): - output = network.get_output(i) - shape = context.get_tensor_shape(output.name) - shapes[output.name] = shape - if output.is_shape_tensor: - if shape == [0]: - values[output.name] = [] - else: - if shape == []: - shape = [1] - value = torch.empty( - list(shape), - dtype=trt_dtype_to_torch(output.dtype), - device="cpu", - ) - values[output.name] = value - context.set_tensor_address(output.name, value.data_ptr()) - context.infer_shapes() - assert context.all_binding_shapes_specified - for i in range(network.num_outputs): - output = network.get_output(i) - if isinstance(values.get(output.name), torch.Tensor): - values[output.name] = values[output.name].tolist() - - -@dataclass -class ShapeInfo: - shapes: Dict[str, trt.Dims] - values: Dict[str, List[int]] - shape_layers: Set[str] - max_shapes: Dict[str, trt.Dims] = None - - -def set_constant_value(layer, values): - to_subclass_layer(layer) - output_name = layer.get_output(0).name - weights = layer.weights - if isinstance(weights, trt.Weights): - weights = weights.numpy() - values[output_name] = list(weights) - to_base_class_layer(layer) - - -def infer_per_layer_shapes( - layer: trt.ILayer, - shapes, - values, - cache=None, - is_shape_io=False, -): - if layer.type == trt.LayerType.CONSTANT: - to_subclass_layer(layer) - output_name = layer.get_output(0).name - shape = layer.shape - shapes[output_name] = shape - if is_shape_io: - set_constant_value(layer, values) - to_base_class_layer(layer) - return - elif layer.type == trt.LayerType.SHAPE: - input_name = layer.get_input(0).name - output_name = layer.get_output(0).name - shape = [*shapes[input_name]] - shapes[output_name] = trt.Dims([len(shape)]) - values[output_name] = shape - return - if cache is not None: - cache_key = get_cache_key(layer, shapes, values) - if cache_key in cache: - output_shapes, output_values = cache[cache_key] - for i in range(layer.num_outputs): - output = layer.get_output(i) - shapes[output.name] = output_shapes[i] - if output_values[i] is not None: - values[output.name] = output_values[i] - return - graph, output_mapping = get_per_layer_graph(layer, shapes, values) - dtypes = [ - trt_dtype_to_str(layer.get_input(i).dtype) - for i in range(layer.num_inputs) - ] - layer_info = (f"type={cache_key[0]}, " - f"attrs={dict(cache_key[1])}, " - f"dtypes={dtypes}, " - f"shapes={list(cache_key[2])}, " - f"values={list(cache_key[3])}") - logger.debug(f"infer shapes for layer {layer.name} ({layer_info})") - try: - infer_shapes(graph.as_trt(), shapes, values) - except RuntimeError as e: - raise RuntimeError( - f"infer shapes failed for layer {layer.name} ({layer_info})") from e - for proxy_output, (output, dtype) in output_mapping.items(): - shapes[output] = shapes[proxy_output] - del shapes[proxy_output] - if proxy_output in values: - values[output] = [ - *map(_trt_to_type_dict[dtype], values[proxy_output]) - ] - del 
values[proxy_output] - if cache is not None: - logger.debug( - f"shape inference cache miss, layer: {layer.name}, cache key: {cache_key}" - ) - output_shapes = [] - output_values = [] - for i in range(layer.num_outputs): - output = layer.get_output(i) - output_shapes.append(shapes[output.name]) - output_values.append(values.get(output.name)) - cache[cache_key] = (output_shapes, output_values) - - -def get_shape_info(trt_network, profile, shape_type: ShapeType = ShapeType.OPT): - shapes = {} - values = {} - sorted_layer_ids = get_sorted_layer_ids(trt_network) - infer_shape_layers = False - - shape_network, shape_profile, shape_layers, output_mapping = get_shape_network( - trt_network, - shapes, - values, - sorted_layer_ids, - profile=profile, - shape_type=shape_type) - try: - infer_shapes(shape_network, shapes, values, shape_profile) - for proxy_output, (output, dtype) in output_mapping.items(): - shapes[output] = shapes[proxy_output] - values[output] = [ - *map(_trt_to_type_dict[dtype], values[proxy_output]) - ] - del shapes[proxy_output] - del values[proxy_output] - except RuntimeError: - infer_shape_layers = True - - cache = {} - for layer_id in sorted_layer_ids: - layer = trt_network.get_layer(layer_id) - is_shape_io = layer.name in shape_layers - if is_shape_io and not infer_shape_layers: - continue - set_trt_network(layer, trt_network) - infer_per_layer_shapes(layer, - shapes, - values, - cache, - is_shape_io=is_shape_io) - return ShapeInfo(shapes, values, shape_layers) diff --git a/tensorrt_llm/auto_parallel/simplifier.py b/tensorrt_llm/auto_parallel/simplifier.py deleted file mode 100644 index 9381930c23..0000000000 --- a/tensorrt_llm/auto_parallel/simplifier.py +++ /dev/null @@ -1,837 +0,0 @@ -import math -import re -from dataclasses import dataclass -from enum import Enum -from typing import Dict, List, Tuple - -import numpy as np - -from tensorrt_llm.network import Network - -from .config import AutoParallelConfig -from .device_mesh import PhysicalDeviceMesh -from .pipeline_graph import PipelineGraph -from .shape_info import ShapeInfo, ShapeType, get_shape_info -from .tensor_parallel.p2p_node import P2PType -from .utils import get_cache_key, get_sorted_layer_ids, silent_trt_logger - - -class StageType(Enum): - START = 0 - BLOCK = 1 - END = 2 - - -class BuildingBlock: - - def __init__(self, graph, layer_range) -> None: - self.graph = graph - self.layer_range = layer_range - self.network = graph.as_trt() - self.owned_inputs = {} - self.is_edges_collected = False - self.intra_edges = [] - self.src_inter_edges = [] - self.dst_inter_edges = [] - self.relative_src_inter_edges = [] - self.relative_dst_inter_edges = [] - self.relative_inter_edges = set() - self.edge_hash = None - self.outputs = None - self.type_id = -1 - self.block_id = -1 - self.p2p_type = None - self.is_superset = False - self.is_subset = False - self.sorted_layer_ids = [] - - def collect_edges(self): - if self.is_edges_collected: - return - for layer_index in self.layer_range: - trt_layer = self.network.get_layer(layer_index) - layer = self.graph.get_layer(trt_layer.name) - layer_offset = layer.index - self.layer_range.start - for input_index, input in enumerate(layer.inputs): - if input is not None: - if input.is_graph_input: - is_owned = input.graph_input_index in self.owned_inputs - if not is_owned and np.all([ - layer.index in self.layer_range or np.all([ - output.as_trt().is_shape_tensor - for output in layer.outputs - ]) for layer, _ in input.consumers - ]): - self.owned_inputs[input.graph_input_index] = len( - 
self.owned_inputs) - is_owned = True - if is_owned: - self.intra_edges.append( - (-1, self.owned_inputs[input.graph_input_index], - layer_offset, input_index)) - else: - self.dst_inter_edges.append( - (-1, input.graph_input_index, layer_offset, - input_index)) - else: - src_layer_index = input.producer.index - if src_layer_index < self.layer_range.start or src_layer_index >= self.layer_range.stop: - self.dst_inter_edges.append( - (src_layer_index, input.output_index, - layer_offset, input_index)) - else: - src_layer_offset = src_layer_index - self.layer_range.start - self.intra_edges.append( - (src_layer_offset, input.output_index, - layer_offset, input_index)) - for output_index, output in enumerate(layer.outputs): - for dst_layer, dst_input_index in output.consumers: - dst_layer_index = dst_layer.index - if dst_layer_index < self.layer_range.start or dst_layer_index >= self.layer_range.stop: - self.src_inter_edges.append( - (layer_offset, output_index, dst_layer_index, - dst_input_index)) - self.edge_hash = tuple(self.intra_edges) - self.outputs = sorted( - set((edge[0], edge[1]) for edge in self.src_inter_edges)) - self.is_edges_collected = True - - def collect_relative_inter_edges(self, layer_to_block): - self.collect_edges() - for src_layer_index, src_output_index, dst_layer_index, dst_input_index in self.dst_inter_edges: - if src_layer_index in layer_to_block: - src_block = layer_to_block[src_layer_index] - src_layer_offset = src_layer_index - src_block.layer_range.start - dst = (self.type_id, dst_layer_index, dst_input_index) - self.relative_dst_inter_edges.append( - (src_block.type_id, src_layer_offset, src_output_index, - *dst)) - else: - self.relative_dst_inter_edges.append( - (-1, src_layer_index, src_output_index, self.type_id, - dst_layer_index, dst_input_index)) - self.relative_inter_edges = set(self.relative_dst_inter_edges + - self.outputs) - - def get_input_names(self): - self.collect_edges() - input_tensor_names = [] - for edge in self.dst_inter_edges: - layer_index = edge[0] - output_index = edge[1] - if layer_index == -1: - tensor_name = self.network.get_input(output_index).name - else: - tensor_name = self.network.get_layer(layer_index).get_output( - output_index).name - input_tensor_names.append(tensor_name) - return input_tensor_names - - def get_input_mapping(self, last_blocks): - input_mapping = {} - for tensor_name, relative_edge in zip(self.get_input_names(), - self.relative_dst_inter_edges): - type_id = relative_edge[0] - output_index = relative_edge[2] - if type_id >= 0: - last_block = last_blocks[type_id] - layer_offset = relative_edge[1] - mapped_layer_index = last_block.layer_range.start + layer_offset - mapped_tensor_name = self.network.get_layer( - mapped_layer_index).get_output(output_index).name - input_mapping[tensor_name] = mapped_tensor_name - else: - input_mapping[tensor_name] = tensor_name - return input_mapping - - -@dataclass -class GraphMapping: - layer_mapping: Dict[int, int] = None - block_mapping: Dict[int, int] = None - p2p_types: Dict[int, P2PType] = None - p2p_tensors: Dict[int, List[str]] = None - block_to_stage: Dict[int, int] = None - same_spec_layer_mapping: Dict[str, str] = None - - -@dataclass -class GraphConfig: - num_micro_batches: int = 1 - num_blocks: int = 1 - num_stages: int = 1 - has_cross_device: bool = False - has_cross_host: bool = False - graph_mapping: GraphMapping = None - phy_mesh: PhysicalDeviceMesh = None - stage_phy_meshes: List[PhysicalDeviceMesh] = None - - -class Simplifier: - - def __init__(self, network: Network, 
config: AutoParallelConfig): - self.config = config - self.sharded_io_allowlist = config.sharded_io_allowlist - self.same_buffer_io = config.same_buffer_io - self.same_spec_io = config.same_spec_io.copy() - for key, value in self.same_buffer_io.items(): - if key not in self.same_spec_io: - self.same_spec_io[key] = value - - self.llm_network = network - self.network = network.trt_network - self.module_to_layer_range_map = network._module_call_stack.module_to_layer_range_map - self.graph = self.get_graph() - self.init_layer_hash() - - module_tree = self.get_module_tree() - building_blocks = self.collect_building_blocks(module_tree) - blocks_by_module_hash = self.get_blocks_by_module_hash(building_blocks) - self.blocks_by_edge_hash = self.get_blocks_by_edge_hash( - blocks_by_module_hash) - self.layer_to_block = self.get_layer_to_block() - self.blocks = self.get_all_blocks() - self.backbone_blocks = self.get_backbone_blocks() - self.graph_mapping_for_shape = self.get_graph_mapping_for_shape() - self.graph_for_shape = self.create_simplified_graph_for_shape() - self.shape_info = None - self.num_micro_batches = None - - def infer_shapes(self, num_micro_batches): - if self.num_micro_batches == num_micro_batches: - return - with silent_trt_logger(): - self.shape_info = self.get_full_shape_info(num_micro_batches) - self.graph.assign_shapes(self.shape_info) - self.num_micro_batches = num_micro_batches - - def list_all_num_micro_batches(self): - opt_batch_size = self.get_opt_batch_size() - candidates = [] - for num_micro_batches in range(1, self.get_opt_batch_size() + 1): - if opt_batch_size % num_micro_batches == 0: - candidates.append(num_micro_batches) - return candidates - - def get_graph(self): - graph = PipelineGraph.from_trt(self.network) - graph._unfilled_weights = self.llm_network._unfilled_weights.copy() - graph._io_buffer_mapping - for input in graph.inputs: - input_name = input.name - for pattern, repl in self.same_buffer_io.items(): - if re.match(pattern, input_name): - output_name = re.sub(pattern, repl, input_name) - output = graph.get_output(output_name) - if output is not None: - graph._io_buffer_mapping[output_name] = input_name - return graph - - def get_opt_batch_size(self): - input_tensors = self.llm_network._inputs - num_profiles = len(list(input_tensors.values())[0].profiles) - opt_batch_sizes = [] - for i in range(num_profiles): - for input_tensor in input_tensors.values(): - shape_profile = input_tensor.profiles[i] - opt_shape = shape_profile.opt - for j in range(len(input_tensor.shape)): - name = input_tensor.trt_tensor.get_dimension_name(j) - if name == 'batch_size': - opt_batch_sizes.append(opt_shape[j]) - return min(opt_batch_sizes) - - def get_module_hash(self, layer_range): - module_hash = () - for i in layer_range: - assert i < self.network.num_layers, f"layer index {i} in {layer_range} out of range of {self.network.num_layers}" - layer_name = self.network.get_layer(i).name - layer = self.graph.get_layer(layer_name) - module_hash += (layer.attrs["hash"], ) - return module_hash - - def get_network_hash(self) -> str: - return str(self.get_module_hash(range(self.network.num_layers))) - - def collect_building_blocks(self, module_tree): - building_blocks = {} - queue = [] - for tree in module_tree["children"].values(): - queue.append(tree) - while len(queue) > 0: - while len(queue) > 0: - tree = queue.pop(0) - module_name = tree["name"] - if module_name is None: - for child in tree["children"].values(): - queue.append(child) - continue - layer_range = 
self.module_to_layer_range_map[module_name] - module_hash = self.get_module_hash(layer_range) - if module_hash in building_blocks: - building_blocks[module_hash].append(tree) - else: - building_blocks[module_hash] = [tree] - for module_hash in [*building_blocks.keys()]: - if len(building_blocks[module_hash]) == 1: - tree = building_blocks[module_hash][0] - for child in tree["children"].values(): - queue.append(child) - del building_blocks[module_hash] - blocks_by_module_hash = { - module_hash: [ - BuildingBlock(self.graph, - self.module_to_layer_range_map[tree["name"]]) - for tree in trees - ] - for module_hash, trees in building_blocks.items() - } - building_blocks = [] - for block_list in blocks_by_module_hash.values(): - for block in block_list: - building_blocks.append(block) - building_blocks = sorted(building_blocks, - key=lambda x: x.layer_range.start) - if len(building_blocks) >= 2: - for block, next_block in zip(building_blocks[:-1], - building_blocks[1:]): - block.layer_range = range(block.layer_range.start, - next_block.layer_range.start) - return building_blocks - - def get_all_blocks(self): - building_blocks = [] - for block_list in self.blocks_by_edge_hash.values(): - for block in block_list: - building_blocks.append(block) - building_blocks = sorted(building_blocks, - key=lambda x: x.layer_range.start) - all_blocks = [] - current_layer_index = 0 - block_id = 0 - for block in building_blocks: - assert current_layer_index <= block.layer_range.start - if current_layer_index < block.layer_range.start: - new_block = BuildingBlock( - self.graph, - range(current_layer_index, block.layer_range.start)) - new_block.block_id = block_id - block_id += 1 - all_blocks.append(new_block) - block.block_id = block_id - block_id += 1 - all_blocks.append(block) - current_layer_index = block.layer_range.stop - if current_layer_index < self.graph.num_layers: - new_block = BuildingBlock( - self.graph, range(current_layer_index, self.graph.num_layers)) - new_block.block_id = block_id - all_blocks.append(new_block) - sorted_layer_ids = get_sorted_layer_ids(self.network) - for block in all_blocks: - block.collect_relative_inter_edges(self.layer_to_block) - for layer_id in sorted_layer_ids: - if layer_id in block.layer_range: - block.sorted_layer_ids.append(layer_id) - return all_blocks - - def get_backbone_blocks(self): - sorted_blocks = sorted( - self.blocks_by_edge_hash.values(), - key=lambda blocks: (len(blocks), len(blocks[0].layer_range)), - ) - if len(sorted_blocks) == 0: - return [] - else: - return sorted_blocks[-1] - - def get_blocks_by_module_hash(self, blocks): - blocks_by_module_hash = {} - for block in blocks: - module_hash = self.get_module_hash(block.layer_range) - if module_hash not in blocks_by_module_hash: - blocks_by_module_hash[module_hash] = [] - blocks_by_module_hash[module_hash].append(block) - for module_hash in [*blocks_by_module_hash.keys()]: - if len(blocks_by_module_hash[module_hash]) == 1: - del blocks_by_module_hash[module_hash] - return blocks_by_module_hash - - def get_module_tree(self): - module_tree = {"children": {}, "name": None} - for module_name in self.module_to_layer_range_map.keys(): - full_name = module_name.split('.') - current_tree = module_tree["children"] - for depth, name in enumerate(full_name): - if name not in current_tree: - current_tree[name] = {"children": {}, "name": None} - if depth == len(full_name) - 1: - current_tree[name]["name"] = module_name - else: - current_tree = current_tree[name]["children"] - return module_tree - - def 
get_blocks_by_edge_hash(self, blocks_by_module_hash): - blocks_by_edge_hash = {} - for block_list in blocks_by_module_hash.values(): - for block in block_list: - block.collect_edges() - edge_hash = block.edge_hash - if edge_hash not in blocks_by_edge_hash: - blocks_by_edge_hash[edge_hash] = [] - blocks_by_edge_hash[edge_hash].append(block) - for edge_hash in [*blocks_by_edge_hash.keys()]: - if len(blocks_by_edge_hash[edge_hash]) == 1: - del blocks_by_edge_hash[edge_hash] - else: - block_list = blocks_by_edge_hash[edge_hash] - blocks_by_edge_hash[edge_hash] = sorted( - block_list, key=lambda x: x.layer_range.start) - for type_id, block_list in enumerate(blocks_by_edge_hash.values()): - for block in block_list: - block.type_id = type_id - return blocks_by_edge_hash - - def get_layer_to_block(self): - layer_to_block = {} - for block_list in self.blocks_by_edge_hash.values(): - for block in block_list: - for layer_index in block.layer_range: - layer_to_block[layer_index] = block - return layer_to_block - - def clean_blocks(self): - for block in self.blocks: - block.p2p_type = None - block.is_superset = False - block.is_subset = False - - def mark_p2p_type(self, phy_mesh, stage_phy_meshes, - graph_config: GraphConfig): - if len(self.backbone_blocks) == 0 or len(stage_phy_meshes) == 1: - return - assert len(self.backbone_blocks) % len(stage_phy_meshes) == 0 - block_per_stage = len(self.backbone_blocks) // len(stage_phy_meshes) - - for block in self.backbone_blocks: - block.p2p_type = None - for stage_index, stage_phy_mesh in enumerate(stage_phy_meshes[:-1]): - next_stage_phy_mesh = stage_phy_meshes[stage_index + 1] - last_device_id = stage_phy_mesh.phy_devices_id.flatten()[-1] - next_first_device_id = next_stage_phy_mesh.phy_devices_id.flatten( - )[0] - num_devices_per_host = phy_mesh.num_devices_per_host - next_block = self.backbone_blocks[(stage_index + 1) * - block_per_stage] - if last_device_id // num_devices_per_host != next_first_device_id // num_devices_per_host: - next_block.p2p_type = P2PType.CROSS_HOST - graph_config.has_cross_host = True - else: - next_block.p2p_type = P2PType.CROSS_DEVICE - graph_config.has_cross_device = True - - def get_graph_mapping(self): - layer_mapping = {} - block_mapping = {} - p2p_types = {} - p2p_tensors = {} - for block_list in self.blocks_by_edge_hash.values(): - superset_blocks = [] - superset_block_index = {} - for block in block_list: - block_added = False - for index, superset_block in enumerate(list(superset_blocks)): - if block.p2p_type == superset_block.p2p_type: - if block.relative_inter_edges.issubset( - superset_block.relative_inter_edges): - block.is_subset = True - block.is_superset = False - superset_block_index[id(block)] = index - block_added = True - break - elif superset_block.relative_inter_edges.issubset( - block.relative_inter_edges): - superset_block.is_subset = True - superset_block.is_superset = False - block.is_subset = False - block.is_superset = True - superset_blocks[index] = block - superset_block_index[id(block)] = index - block_added = True - break - if not block_added: - block.is_subset = False - block.is_superset = True - superset_blocks.append(block) - superset_block_index[id(block)] = len(superset_blocks) - 1 - for block in block_list: - assert not (block.is_subset and block.is_superset) - if block.is_subset: - superset_block = superset_blocks[superset_block_index[id( - block)]] - block_mapping[block.block_id] = superset_block.block_id - owned_inputs = map( - lambda x: x[0], - sorted(block.owned_inputs.items(), key=lambda 
x: x[1])) - superset_owned_inputs = map( - lambda x: x[0], - sorted(superset_block.owned_inputs.items(), - key=lambda x: x[1])) - for from_input_id, to_input_id in zip( - owned_inputs, superset_owned_inputs): - from_input_name = self.network.get_input( - from_input_id).name - to_input_name = self.network.get_input(to_input_id).name - layer_mapping[from_input_name] = to_input_name - for from_layer_id, to_layer_id in zip( - block.layer_range, superset_block.layer_range): - from_layer = self.network.get_layer(from_layer_id) - to_layer = self.network.get_layer(to_layer_id) - layer_mapping[from_layer.name] = to_layer.name - for i in range(from_layer.num_outputs): - from_output = from_layer.get_output(i) - if from_output.is_network_output: - to_output = to_layer.get_output(i) - layer_mapping[from_output.name] = to_output.name - if block.p2p_type is not None: - p2p_types[block.block_id] = block.p2p_type - p2p_tensors[block.block_id] = [ - *set(block.get_input_names()) - ] - for from_name, to_name in zip( - block.get_input_names(), - superset_block.get_input_names()): - layer_mapping[ - f"p2p_block{block.block_id}_{from_name}"] = f"p2p_block{superset_block.block_id}_{to_name}" - stage_id = 0 - block_to_stage = {} - for block in self.blocks: - if block.p2p_type is not None: - stage_id += 1 - block_to_stage[block.block_id] = stage_id - return GraphMapping( - layer_mapping, - block_mapping, - p2p_types, - p2p_tensors, - block_to_stage, - ) - - def create_simplified_graph(self, graph_config: GraphConfig): - new_graph = PipelineGraph.create_graph() - new_graph._io_buffer_mapping = self.graph._io_buffer_mapping - layer_mapping = graph_config.graph_mapping.layer_mapping - - for i in range(self.network.num_inputs): - trt_input = self.network.get_input(i) - if trt_input.name not in layer_mapping: - new_graph.add_input(trt_input) - - last_blocks = {} - same_spec_mapping = {} - same_spec_layer_mapping = {} - shape_mapping = {} - building_block_id = 0 - same_spec_ids = {} - same_spec_count = 0 - for block in self.blocks: - if not block.is_subset: - stage_type = None - if not block.is_superset: - if block.block_id == 0: - stage_type = StageType.START - elif block.block_id == len(self.blocks) - 1: - stage_type = StageType.END - input_mapping = block.get_input_mapping(last_blocks) - for from_name, to_name in [*input_mapping.items()]: - if to_name in same_spec_mapping: - input_mapping[from_name] = same_spec_mapping[to_name] - if to_name in layer_mapping: - input_mapping[from_name] = layer_mapping[to_name] - if block.is_superset and block.p2p_type is not None: - for from_name, to_name in [*input_mapping.items()]: - output_tensor = new_graph.get_tensor(to_name) - p2p_layer = new_graph.as_trt().add_identity( - output_tensor.as_trt()) - p2p_layer.name = f"p2p_block{block.block_id}_{from_name}" - p2p_layer.metadata = p2p_layer.name - p2p_tensor = p2p_layer.get_output(0) - p2p_tensor.name = f"{p2p_layer.name}_output" - wrapped_layer = new_graph.register_layer(p2p_layer) - wrapped_layer.attrs[ - "building_block_id"] = building_block_id - wrapped_layer.attrs["p2p_type"] = block.p2p_type - input_mapping[from_name] = p2p_tensor.name - shape_mapping[p2p_tensor.name] = from_name - building_block_id += 1 - for i in block.sorted_layer_ids: - layer = self.network.get_layer(i) - wrapped_layer = new_graph.add_layer( - layer, - input_mapping=input_mapping, - ) - wrapped_layer.attrs["building_block_id"] = building_block_id - wrapped_layer.attrs["stage_type"] = stage_type - if block.is_superset: - last_blocks[block.type_id] = block 
- - if block.type_id in same_spec_ids: - same_spec_id = same_spec_ids[block.type_id] - update_same_spec_count = False - else: - same_spec_id = same_spec_count - same_spec_ids[block.type_id] = same_spec_id - update_same_spec_count = True - count = same_spec_id - for i, (layer_offset, - output_index) in enumerate(block.outputs): - layer = self.network.get_layer(block.layer_range.start + - layer_offset) - tensor_name = layer.get_output(output_index).name - output_tensor = new_graph.get_tensor(tensor_name) - same_spec_layer = new_graph.as_trt().add_identity( - output_tensor.as_trt()) - same_spec_layer.name = f"{tensor_name}_same_spec" - same_spec_layer.metadata = same_spec_layer.name - same_spec_tensor = same_spec_layer.get_output(0) - same_spec_tensor.name = f"{same_spec_layer.name}_output" - wrapped_layer = new_graph.register_layer( - same_spec_layer) - wrapped_layer.attrs[ - "building_block_id"] = building_block_id - wrapped_layer.attrs["same_spec_id"] = count - count += 1 - same_spec_mapping[tensor_name] = same_spec_tensor.name - same_spec_layer_mapping[ - same_spec_layer.name] = layer.name - shape_mapping[same_spec_tensor.name] = tensor_name - for i, graph_input_index in enumerate( - block.owned_inputs.keys()): - input_name = self.network.get_input( - graph_input_index).name - input_tensor = new_graph.get_input(input_name) - input_tensor.attrs["same_spec_id"] = count - count += 1 - if update_same_spec_count: - same_spec_count = count - building_block_id += 1 - graph_config.graph_mapping.same_spec_layer_mapping = same_spec_layer_mapping - - if len(self.backbone_blocks) >= 2: - start_block = self.backbone_blocks[0] - if start_block.is_subset: - start_block = self.blocks[graph_config.graph_mapping. - block_mapping[start_block.block_id]] - for i in start_block.layer_range: - layer_name = self.network.get_layer(i).name - layer = new_graph.get_layer(layer_name) - layer.attrs["in_start_block"] = True - end_block = self.backbone_blocks[-1] - if end_block.is_subset: - end_block = self.blocks[graph_config.graph_mapping. 
- block_mapping[end_block.block_id]] - for i in end_block.layer_range: - layer_name = self.network.get_layer(i).name - layer = new_graph.get_layer(layer_name) - layer.attrs["in_end_block"] = True - slowest_p2p_type = None - if graph_config.has_cross_host: - slowest_p2p_type = P2PType.CROSS_HOST - elif graph_config.has_cross_device: - slowest_p2p_type = P2PType.CROSS_DEVICE - if slowest_p2p_type is not None: - for block in self.blocks: - if block.is_superset and block.p2p_type == slowest_p2p_type: - for i in block.layer_range: - layer_name = self.network.get_layer(i).name - layer = new_graph.get_layer(layer_name) - layer.attrs["in_slowest_block"] = True - - for i in range(self.network.num_outputs): - trt_output = self.network.get_output(i) - output = self.graph.get_output(trt_output.name) - if output.producer is not None and output.producer.index in self.layer_to_block and self.layer_to_block[ - output.producer.index].is_subset: - continue - if trt_output.is_shape_tensor: - new_output = new_graph.add_output_shape(trt_output) - else: - new_output = new_graph.add_output(trt_output) - sharded_io = False - for pattern in self.sharded_io_allowlist: - if re.match(pattern, new_output.name): - sharded_io = True - break - if not sharded_io: - new_output.producer.attrs["is_replicated"] = True - - for input in new_graph.inputs: - input_name = input.name - sharded_io = False - for pattern in self.sharded_io_allowlist: - if re.match(pattern, input_name): - sharded_io = True - break - if not sharded_io: - input.attrs["is_replicated"] = True - for pattern, repl in self.same_spec_io.items(): - if re.match(pattern, input_name): - output_name = re.sub(pattern, repl, input_name) - output = new_graph.get_output(output_name) - if output is not None: - if "same_spec_id" in input.attrs: - same_spec_id = input.attrs["same_spec_id"] - else: - same_spec_id = same_spec_count - same_spec_count += 1 - input.attrs["same_spec_id"] = same_spec_id - output.attrs["same_spec_id"] = same_spec_id - if math.prod(self.graph.get_input( - input_name).shape) < math.prod( - self.graph.get_output(output_name).shape): - input.attrs["no_memory_footprint"] = True - else: - output.attrs["no_memory_footprint"] = True - - return new_graph, shape_mapping - - def enrich_shape_info(self, shape_mapping): - shapes = self.shape_info.shapes.copy() - max_shapes = self.shape_info.max_shapes.copy() - values = self.shape_info.values.copy() - shape_layers = self.shape_info.shape_layers - for from_name, to_name in shape_mapping.items(): - if to_name in shapes: - shapes[from_name] = shapes[to_name] - if to_name in max_shapes: - max_shapes[from_name] = max_shapes[to_name] - if to_name in values: - values[from_name] = values[to_name] - shape_info = ShapeInfo(shapes, values, shape_layers, max_shapes) - return shape_info - - def simplify_graph( - self, phy_mesh: PhysicalDeviceMesh, num_stages: int, - num_devices_per_stage: int) -> Tuple[PipelineGraph, GraphConfig]: - num_blocks = len(self.backbone_blocks) - if num_blocks % num_stages != 0: - return None, None - graph_config = GraphConfig() - graph_config.num_micro_batches = self.num_micro_batches - graph_config.num_blocks = num_blocks - graph_config.num_stages = num_stages - graph_config.phy_mesh = phy_mesh - stage_phy_meshes = phy_mesh.split_pipeline_meshes( - num_stages, num_devices_per_stage) - graph_config.stage_phy_meshes = stage_phy_meshes - with silent_trt_logger(): - self.clean_blocks() - self.mark_p2p_type(phy_mesh, stage_phy_meshes, graph_config) - graph_config.graph_mapping = 
self.get_graph_mapping() - new_graph, shape_mapping = self.create_simplified_graph( - graph_config) - shape_info = self.enrich_shape_info(shape_mapping) - new_graph.assign_shapes(shape_info) - return new_graph, graph_config - - def get_graph_mapping_for_shape(self): - layer_mapping = {} - tensor_mapping = {} - for block_list in self.blocks_by_edge_hash.values(): - head_block = block_list[0] - for block in block_list[1:]: - for from_layer_id, to_layer_id in zip(block.layer_range, - head_block.layer_range): - from_layer = self.network.get_layer(from_layer_id) - to_layer = self.network.get_layer(to_layer_id) - layer_mapping[from_layer.name] = to_layer.name - for i in range(from_layer.num_outputs): - tensor_mapping[from_layer.get_output( - i).name] = to_layer.get_output(i).name - return layer_mapping, tensor_mapping - - def create_simplified_graph_for_shape(self): - new_graph = PipelineGraph.create_graph() - - for i in range(self.network.num_inputs): - trt_input = self.network.get_input(i) - new_graph.add_input(trt_input) - - head_blocks = {} - removed_blocks = set() - removed_layers = set() - for block_list in self.blocks_by_edge_hash.values(): - head_block = block_list[0] - head_blocks[head_block.type_id] = head_block - for block in block_list[1:]: - removed_blocks.add(id(block)) - for layer_index in block.layer_range: - removed_layers.add(layer_index) - - for block in self.blocks: - if not id(block) in removed_blocks: - input_mapping = block.get_input_mapping(head_blocks) - for i in block.sorted_layer_ids: - layer = self.network.get_layer(i) - new_graph.add_layer( - layer, - input_mapping=input_mapping, - ) - - for i in range(self.network.num_outputs): - trt_output = self.network.get_output(i) - output = self.graph.get_output(trt_output.name) - if output.producer is not None and output.producer.index in removed_layers: - continue - if trt_output.is_shape_tensor: - new_graph.add_output_shape(trt_output) - else: - new_graph.add_output(trt_output) - - return new_graph - - def get_full_shape_info(self, num_micro_batches): - layer_mapping, tensor_mapping = self.graph_mapping_for_shape - optimization_profiles = self.llm_network._generate_optimization_profiles( - ) - if len(optimization_profiles) > 0: - optimization_profile = optimization_profiles[-1] - else: - optimization_profile = None - shape_info = get_shape_info(self.graph_for_shape.as_trt(), - optimization_profile) - max_shape_info = get_shape_info(self.graph_for_shape.as_trt(), - optimization_profile, - shape_type=ShapeType.MAX) - shape_info.max_shapes = max_shape_info.shapes - for removed_tensor_name, tensor_name in tensor_mapping.items(): - shape_info.shapes[removed_tensor_name] = shape_info.shapes[ - tensor_name] - shape_info.max_shapes[removed_tensor_name] = shape_info.max_shapes[ - tensor_name] - if tensor_name in shape_info.values: - shape_info.values[removed_tensor_name] = shape_info.values[ - tensor_name] - for removed_layer_name, layer_name in layer_mapping.items(): - if layer_name in shape_info.shape_layers: - shape_info.shape_layers.add(removed_layer_name) - return shape_info - - def init_layer_hash(self): - with silent_trt_logger(): - optimization_profiles = self.llm_network._generate_optimization_profiles( - ) - if len(optimization_profiles) > 0: - optimization_profile = optimization_profiles[-1] - else: - optimization_profile = None - shape_info = get_shape_info(self.network, optimization_profile) - dtypes = {tensor.name: tensor.dtype for tensor in self.graph.tensors} - for layer in self.graph.layers: - layer_hash = 
get_cache_key( - layer.as_trt(), - shape_info.shapes, - shape_info.values, - dtypes, - ) - layer.attrs["hash"] = layer_hash diff --git a/tensorrt_llm/auto_parallel/solver.py b/tensorrt_llm/auto_parallel/solver.py deleted file mode 100644 index 7dd195117d..0000000000 --- a/tensorrt_llm/auto_parallel/solver.py +++ /dev/null @@ -1,641 +0,0 @@ -"""This code is adapted from Alpa https://github.com/alpa-projects/alpa/ with some changes. -""" -import multiprocessing -import time -import warnings -from collections import defaultdict - -import numpy as np -import pulp -from pulp import LpMinimize, LpProblem, LpVariable, lpDot, lpSum - -from ..logger import logger - - -class Solution: - - def __init__(self, leaf_strategies, s_val, e_val, edge_pairs, - node_index_dict, total_cost): - self.leaf_strategies = leaf_strategies - self.nodes = [ - strategies_vector.node for strategies_vector in self.leaf_strategies - ] - self.s_val = s_val - self.e_val = e_val - self.total_cost = total_cost - self.edge_pairs = list(np.reshape(edge_pairs, (-1, 2))) - self.node_index_dict = node_index_dict - self.index_node_dict = {} - for node, index in self.node_index_dict.items(): - self.index_node_dict[index] = node - self.node_best_strategy = {} - self._annotate_strategy() - - def _annotate_strategy(self): - self.node_best_strategy = {} - for index, node in enumerate(self.nodes): - best_strategy_id = self.s_val[index] - best_strategy = self.leaf_strategies[index][best_strategy_id] - self.node_best_strategy[node.node_name] = best_strategy - - for edge_idx, edge_pair in enumerate(self.edge_pairs): - src_node = self.index_node_dict[edge_pair[0]] - dst_node = self.index_node_dict[edge_pair[1]] - src_node_index = self.node_index_dict[src_node] - for dst_pre_node in dst_node.predecessor_nodes: - if dst_pre_node is None: - continue - if src_node.node_name == dst_pre_node.node_name: - self.node_best_strategy[ - dst_node.node_name].best_resharding_cost[ - src_node.node_name] = [ - self.node_best_strategy[dst_node.node_name]. - resharding_costs[src_node.node_name][ - self.s_val[src_node_index]] - ] - - def print_solution(self): - for index, node in enumerate(self.nodes): - best_strategy = self.node_best_strategy[node.node_name] - print(f'\n[{index}]: node_name = {node.node_name}') - best_strategy.print_strategy(best_resharding_cost_only=True) - print(f'solution total cost = {self.total_cost}') - - -class CostGraph: - ''' - A graph data structure to simplify the edge cost graph. It has two main functions: - 1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in - CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list. - 2. To reduce the searching space, we merge computationally-trivial operators, such as - element-wise operators, transpose, and reduction, into their following nodes. The merging information will - be given by the StrategiesVector depending on the type of target node and following nodes. - - Argument: - leaf_strategies(List[StrategiesVector]): It stores StrategiesVector of every nodes on the graph. - simplify(bool, optional): The generated cost graph will be simplified if it is true. 
(default to True) - ''' - - def __init__(self, leaf_strategies): - self.leaf_strategies = leaf_strategies - self.nodes = [ - strategies_vector.node for strategies_vector in leaf_strategies - ] - # stores number of strategies in each node - self.node_strategies_vector = {} - for node, strategies_vector in zip(self.nodes, self.leaf_strategies): - self.node_strategies_vector[node] = strategies_vector - # extra_node_costs will store the extra costs introduced by merging nodes - self.extra_node_costs = {} - self.following_dict = {} - self._build_cost_graph() - - def _remove_invalid_node(self, node, attr_name): - remove_list = [] - target_node_list = getattr(node, attr_name, []) - for target_node in target_node_list: - if target_node not in self.nodes: - remove_list.append(target_node) - for element in remove_list: - target_node_list.remove(element) - - def _build_cost_graph(self): - ''' - This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be - set to node. - ''' - self.edge_costs = {} - for dst_node, strategies_vector in zip(self.nodes, - self.leaf_strategies): - # build edge_cost - for src_node in dst_node.predecessor_nodes: - if src_node is None: - continue - if src_node not in self.nodes: - continue - node_pair = (src_node, dst_node) - edge_cost = {} - for i in range(len(strategies_vector)): - for j in range(len(self.node_strategies_vector[src_node])): - resharding_cost = strategies_vector[i].resharding_costs[ - src_node.node_name][j][-1] - edge_cost[(j, i)] = resharding_cost - self.edge_costs[node_pair] = edge_cost - - def get_edge_cost(self, src_node, dst_node): - return self.edge_costs[(src_node, dst_node)] - - -class Solver: - INFINITY_COST = 1e13 - - def __init__(self, - cost_graph: CostGraph, - memory_budget: float = -1.0, - solution_numbers: int = 1, - memory_increasing_coefficient: float = 1.3, - verbose=False): - ''' - Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph. - Argument: - graph: The computing graph to be optimized. - strategies_constructor: It will provide all the possible strategies for each node in the computing graph. - cost_graph: A graph data structure to simplify the edge cost graph. - graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints. - memory_budget: Memory constraint for the solution. - solution_numbers: If solution_numbers is larger than one, solver will us a serious of solutions based on different memory budget. - memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget. - ''' - self.cost_graph = cost_graph - self.leaf_strategies = cost_graph.leaf_strategies - self.nodes = cost_graph.nodes - self.memory_budget = memory_budget - self.solution_numbers = solution_numbers - if self.solution_numbers > 1: - self.memory_increasing_coefficient = memory_increasing_coefficient - else: - self.memory_increasing_coefficient = 1 - # temporarily we use all nodes as liveness list, we count the backward memory cost together with - # forward memory cost into the node memory cost, and no activation checkpoint is used in this phase. - # self.liveness_list = self.graph_analyser.liveness_analysis() - self.liveness_list = self.nodes - self.node_index_dict = self._generate_node_index_dict() - # The last solution vector of auto sharding. 
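The removed solver (adapted from Alpa) selects one sharding strategy per node with an ILP: a one-hot binary vector per node, plus auxiliary binary edge variables that linearize the quadratic resharding term and are tied back to the node choices by the row/column constraints marked (e)-(g) further down. Below is a minimal PuLP sketch of that formulation for two nodes with two strategies each; every name and cost value is invented for illustration and none of it comes from the removed module.

import pulp

# Hypothetical per-node strategy costs and pairwise resharding costs.
node_costs = [[1.0, 4.0], [2.0, 1.5]]   # node_costs[i][k]: cost of node i using strategy k
resharding = [[0.0, 3.0], [3.0, 0.5]]   # resharding[r][c]: cost if node 0 picks r, node 1 picks c

prob = pulp.LpProblem("strategy_selection", pulp.LpMinimize)

# One-hot binary selector per node.
s = [[pulp.LpVariable(f"s_{i}_{k}", cat="Binary") for k in range(2)] for i in range(2)]
# Edge variables linearize the quadratic resharding term:
# e[r][c] == 1 iff node 0 chose r and node 1 chose c.
e = [[pulp.LpVariable(f"e_{r}_{c}", cat="Binary") for c in range(2)] for r in range(2)]

# Objective: node (compute + communication) cost plus edge resharding cost.
prob += (pulp.lpSum(s[i][k] * node_costs[i][k] for i in range(2) for k in range(2))
         + pulp.lpSum(e[r][c] * resharding[r][c] for r in range(2) for c in range(2)))

for i in range(2):
    prob += pulp.lpSum(s[i]) == 1                              # each node picks exactly one strategy
prob += pulp.lpSum(e[r][c] for r in range(2) for c in range(2)) == 1
for r in range(2):
    prob += pulp.lpSum(e[r][c] for c in range(2)) <= s[0][r]   # edge rows must follow node 0's choice
for c in range(2):
    prob += pulp.lpSum(e[r][c] for r in range(2)) <= s[1][c]   # edge columns must follow node 1's choice

prob.solve(pulp.PULP_CBC_CMD(msg=False))
choice = [max(range(2), key=lambda k: pulp.value(s[i][k])) for i in range(2)]
print(choice, pulp.value(prob.objective))   # prints [0, 0] and 3.0

Solving this toy instance picks strategy 0 for both nodes (total cost 3.0), because the zero resharding cost on the diagonal outweighs node 1's slightly cheaper strategy 1.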
- self.last_s_val = None - # The last objective value of the best ILP solution. - self.last_objective = None - self.verbose = verbose - - def _generate_node_index_dict(self): - node_index_dict = {} - for index, node in enumerate(self.nodes): - node_index_dict[node] = index - return node_index_dict - - def _prepare_data_for_solver(self): - ''' - Extract information from components for solver. - ''' - node_nums = len(self.leaf_strategies) - memory_budget = self.memory_budget - - # prepare strategies_len - strategies_len = [] - for node in self.nodes: - strategies_len.append( - len(self.cost_graph.node_strategies_vector[node])) - strategies_len = np.array(strategies_len) - - # prepare edge_pairs and resharding costs - edge_pairs = [] - resharding_costs = [] - edge_cost_level = [] - edge_resharding_weights = [] - for pairs, edge_cost in self.cost_graph.edge_costs.items(): - src_node = pairs[0] - dst_node = pairs[1] - src_node_index = self.node_index_dict[src_node] - dst_node_index = self.node_index_dict[dst_node] - edge_pairs.append(src_node_index) - edge_pairs.append(dst_node_index) - edge_cost_level.append( - (dst_node.building_block_id, dst_node.cost_level)) - for i in range(strategies_len[src_node_index]): - for j in range(strategies_len[dst_node_index]): - resharding_costs.append(edge_cost[(i, j)]) - edge_resharding_weights.append(dst_node.resharding_weight + - dst_node.pipeline_weight) - edge_pairs = np.array(edge_pairs) - resharding_costs = np.array(resharding_costs) - edge_resharding_weights = np.array(edge_resharding_weights) - # prepare compute_costs, communication_costs and memory_costs - compute_costs = [] - communication_costs = [] - memory_costs = [] - peak_act_memory_costs, constant_memory_costs = [], [] - node_sharding_weights = [] - for node, strategies_vector in zip(self.nodes, self.leaf_strategies): - for index, strategy in enumerate(strategies_vector): - compute_cost = strategy.sharding_cost - origin_communication_cost = strategy.communication_cost - memory_cost = strategy.const_memory_footprint * node.sharding_weight - peak_act_memory = strategy.peak_memory_footprint - # extract the memory cost in float from MemoryCost item and sum them up - compute_costs.append(compute_cost) - # node in extra_node_costs means it has some extra communication - # cost from node merging, so we need to add those extra communication - # cost into - - communication_costs.append(origin_communication_cost) - peak_act_memory_costs.append(peak_act_memory) - constant_memory_costs.append(memory_cost) - node_sharding_weights.append(node.sharding_weight + - node.pipeline_weight) - - compute_costs = np.array(compute_costs) - communication_costs = np.array(communication_costs) - memory_costs = np.array([constant_memory_costs, peak_act_memory_costs]) - node_sharding_weights = np.array(node_sharding_weights) - same_spec_nodes_dict = defaultdict(list) - node_cost_level = [] - for idx, node in enumerate(self.nodes): - if node.same_spec_id >= 0: - same_spec_nodes_dict[node.same_spec_id].append(idx) - node_cost_level.append((node.building_block_id, node.cost_level)) - # omit initial value for nodes - s_init_np = None - following_nodes = [-1 for i in range(node_nums)] - liveness_set = self.nodes - alias_set = [] - alias_convert_costs = None - return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, node_sharding_weights, edge_resharding_weights, same_spec_nodes_dict, node_cost_level, 
edge_cost_level, alias_convert_costs, s_init_np, self.verbose - - def _call_solver_serialized_args(self, - node_nums, - memory_budget, - strategies_len, - following_nodes, - edge_pairs, - alias_set, - liveness_set, - compute_costs, - communication_costs, - memory_costs, - resharding_costs, - node_sharding_weights, - edge_resharding_weights, - same_spec_nodes_dict, - node_cost_level, - edge_cost_level, - alias_convert_costs, - s_init_np=None, - verbose=True): - """ - Call the solver with serialized arguments. - """ - - time.time() - - for x in [ - strategies_len, edge_pairs, compute_costs, communication_costs, - memory_costs, resharding_costs, node_sharding_weights, - edge_resharding_weights - ]: - assert isinstance(x, np.ndarray) - assert len(strategies_len) == node_nums, "strategies_len" - - def get_non_zero_index(binary_vector): - """ - Get the index of non-zero item in a vector. - """ - ct = 0 - ret = None - for i, elem in enumerate(binary_vector): - if pulp.value(elem): - ret = i - ct += 1 - - assert ct == 1 - return ret - - # 0. Unpack flatten numpy arrays - s_follow = following_nodes - s_alias = alias_set - - E = edge_pairs.reshape((-1, 2)) # noqa - r = [] - pt = 0 - edge_set = set() - for (i, j) in E: - prod_length = strategies_len[i] * strategies_len[j] - - if (i, j) in edge_set: - raise ValueError(f"Duplicated edges: {(i, j)}") - - edge_set.add((i, j)) - r.append(resharding_costs[pt:pt + prod_length]) - pt += prod_length - assert pt == len(resharding_costs) - - ###################### - # omit alias set now # - ###################### - - # A = alias_set.reshape((-1, 2)) # noqa - # for (i, j) in A: - # prod_length = strategies_len[i] * strategies_len[j] - # v.append(alias_convert_costs[pt:pt + prod_length]) - # pt += prod_length - # assert pt == len(alias_convert_costs) - - # L = [] # noqa - # pt = node_nums - # for i in range(node_nums): - # length = liveness_set[i] - # L.append(liveness_set[pt:pt + length]) - # pt += length - # assert pt == len(liveness_set) - pt = 0 - - c = [] - d = [] - m = [] - peak_m = [] - pt = 0 - for i in range(node_nums): - length = strategies_len[i] - c.append(compute_costs[pt:pt + length]) - d.append(communication_costs[pt:pt + length]) - m.append(memory_costs[0][pt:pt + length]) - peak_m.append(memory_costs[1][pt:pt + length]) - pt += length - assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}" - assert pt == len( - communication_costs), f"{pt} == {len(communication_costs)}" - assert pt == len(memory_costs[0]), f"{pt} == {len(memory_costs[0])}" - - # 1. 
Create variables - - ############################# - # create variables for node # - ############################# - s = [] - num_nodes = 0 - reverse_follow_backpatch = [] - for i in range(node_nums): - if s_follow[i] < 0: - if strategies_len[i] == 1: - s.append([1]) - else: - if i not in s_alias: - num_nodes += 1 - s.append( - LpVariable.matrix(f"s[{i}]", - (range(strategies_len[i]), ), - cat="Binary")) - else: - s.append(s[s_alias[i]]) - else: - if s_follow[i] < len(s): - s.append(s[s_follow[i]]) - else: - s.append(None) - reverse_follow_backpatch.append(i) - - for i in reverse_follow_backpatch: - s[i] = s[s_follow[i]] - - ############################# - # create variables for edge # - ############################# - e = [] - num_edges = 0 - map_edge_to_idx = {} - for (idx, (i, j)) in enumerate(E): - if len(s[i]) == 1: - e.append(s[j]) - elif len(s[j]) == 1: - e.append(s[i]) - else: - if i in s_alias and j in s_alias and ( - s_alias[i], s_alias[j]) in map_edge_to_idx: - e.append(e[map_edge_to_idx[(s_alias[i], s_alias[j])]]) - else: - num_edges += 1 - e.append( - LpVariable.matrix(f"e[{i},{j}]", - (range(len(s[i]) * len(s[j])), ), - cat="Binary")) - assert len(e[idx]) == len(r[idx]) - map_edge_to_idx[(i, j)] = idx - for element in s: - assert len(element) > 0 - # 2. Set initial value - ###################################### - # set a initial value for warm start # - ###################################### - if s_init_np is not None: - s_init = s_init_np.reshape((-1, 3)) - for (idx, value, fix) in s_init: - for i in range(len(s[idx])): - s[idx][i].setInitialValue(i == value) - if fix: - s[idx][i].fixValue() - - # 3. Objective - prob = LpProblem("myProblem", LpMinimize) - ################################################################### - # computing the node cost(computing cost and communication cost) # - ################################################################### - obj = 0 - block_cost_level_dict = {} - for i in range(node_nums): - assert len(s[i]) == len(c[i]) - assert len(s[i]) == len(d[i]) - obj += (lpDot(s[i], c[i]) + - lpDot(s[i], d[i])) * node_sharding_weights[i] - cost_level = node_cost_level[i] - if -1 != cost_level[1]: - if cost_level in block_cost_level_dict: - block_cost_level_dict[cost_level] += lpDot( - s[i], c[i]) + lpDot(s[i], d[i]) - else: - block_cost_level_dict[cost_level] = lpDot( - s[i], c[i]) + lpDot(s[i], d[i]) - - ############################################# - # computing the edge cost(resharding cost) # - ############################################# - - for i in range(len(E)): - assert len(e[i]) == len(r[i]) - obj += lpDot(e[i], r[i]) * edge_resharding_weights[i] - cost_level = edge_cost_level[i] - if -1 != cost_level[1]: - if cost_level in block_cost_level_dict: - block_cost_level_dict[cost_level] += lpDot(e[i], r[i]) - else: - block_cost_level_dict[cost_level] = lpDot(e[i], r[i]) - prob += obj - if len(block_cost_level_dict) >= 2: - block_cost_levels = [key for key in block_cost_level_dict.keys()] - for i in range(len(block_cost_levels)): - for j in range(i + 1, len(block_cost_levels)): - if block_cost_levels[i][1] > block_cost_levels[j][1]: - prob += block_cost_level_dict[ - block_cost_levels[i]] >= block_cost_level_dict[ - block_cost_levels[j]] + 1e-6 - elif block_cost_levels[i][1] < block_cost_levels[j][1]: - prob += block_cost_level_dict[ - block_cost_levels[j]] >= block_cost_level_dict[ - block_cost_levels[i]] + 1e-6 - # 4. Constraints - # (a). 
specified by `cat="Binary"` - - # (b) - ################################################# - # make sure each node only choose one strategy # - ################################################# - for i in range(node_nums): - if s_follow[i] < 0: - prob += lpSum(s[i]) == 1 - - # (c) - ################################################# - # force to constrain some nodes have the same sharding specs # - ################################################# - for spec_id, same_spec_nodes_id in same_spec_nodes_dict.items(): - num_same_spec_nodes = len(same_spec_nodes_id) - if num_same_spec_nodes >= 2: - src_node_s = s[same_spec_nodes_id[0]] - num_specs = len(src_node_s) - for i in range(1, num_same_spec_nodes): - dst_node_s = s[same_spec_nodes_id[i]] - assert len( - dst_node_s - ) == num_specs, f'unmatched num_specs when force node {same_spec_nodes_id[0]} and {same_spec_nodes_id[i]} the same specs' - for j in range(num_specs): - prob += (src_node_s[j] == dst_node_s[j]) - - # (c) - ################################################# - # compute memory consumption with liveness set # - ################################################# - if memory_budget > 0: - # calculate the constant memory - mem = 0 - for node in liveness_set: - if node not in self.node_index_dict: - continue - node_index = self.node_index_dict[node] - mem += lpSum(s[node_index][j] * m[node_index][j] - for j in range(len(s[node_index]))) - # calculate the peak activation memory - for node in liveness_set: - if node not in self.node_index_dict: - continue - node_index = self.node_index_dict[node] - cur_peak_mem = lpSum(s[node_index][j] * peak_m[node_index][j] - for j in range(len(s[node_index]))) - total_mem = mem + cur_peak_mem - prob += total_mem <= memory_budget - - # (d). specified by `cat="Binary"` - - for (idx, (i, j)) in enumerate(E): - if strategies_len[i] == 1 or strategies_len[j] == 1: - continue - - # (e) - prob += lpSum(e[idx]) == 1 - - # (f) - for row in range(len(s[i])): - C = len(s[j]) # noqa - prob += lpSum(e[idx][row * C + col] - for col in range(0, C)) <= s[i][row] - - # (g) - for col in range(len(s[j])): - R = len(s[i]) # noqa - C = len(s[j]) # noqa - prob += lpSum(e[idx][row * C + col] - for row in range(0, R)) <= s[j][col] - - if prob.objective.isNumericalConstant(): - objective = float(pulp.value(prob.objective)) - status = pulp.LpStatusOptimal - else: - msg = verbose - time_limit = 600 - solver = pulp.PULP_CBC_CMD( - mip=True, - msg=msg, - timeLimit=time_limit, - threads=multiprocessing.cpu_count(), - ) - prob.solve(solver) - - status = prob.status - objective = pulp.value(prob.objective) - objective = float( - objective) if objective is not None else self.INFINITY_COST - - if prob.status in [pulp.LpStatusInfeasible]: - objective = self.INFINITY_COST - - # Get and check results - s_val = np.full((node_nums, ), -1, dtype=np.int32) - for i in range(node_nums): - s_val[i] = get_non_zero_index(s[i]) - - e_val = np.full((len(E), ), -1, dtype=np.int32) - for (idx, (i, j)) in enumerate(E): - e_val[idx] = get_non_zero_index(e[idx]) - i_spec_index = e_val[idx] // len(s[j]) - j_spec_index = e_val[idx] % len(s[j]) - assert i_spec_index == s_val[i], f"e_val[{i}][{j}]" - assert j_spec_index == s_val[j], f"e_val[{i}][{j}]" - if verbose and r[idx][e_val[idx]] > 0: - print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}") - - self.last_s_val = list(s_val) - # self._recover_merged_node_strategy() - self.last_objective = objective - - if objective >= self.INFINITY_COST: - warnings.warn( - f"Cannot find an optimized solution given memory 
budget {self.memory_budget}, please consider\n" + \
-                f"1. increase memory budget if possible\n" + \
-                f"2. enlarge mesh shape if possible\n" + \
-                f"3. decrease the maximum parameters (i.e., max_batch_size, max_seq_len, etc.) in the build config")
-        if memory_budget > 0:
-            # calculate the constant memory
-            mem = 0
-            for node in liveness_set:
-                if node not in self.node_index_dict:
-                    continue
-                node_index = self.node_index_dict[node]
-                j = self.last_s_val[node_index]
-                mem += m[node_index][j]
-            max_peak_mem = 0
-            for node in liveness_set:
-                if node not in self.node_index_dict:
-                    continue
-                node_index = self.node_index_dict[node]
-                j = self.last_s_val[node_index]
-                cur_peak_mem = peak_m[node_index][j]
-                max_peak_mem = max(max_peak_mem, cur_peak_mem)
-            logger.debug(
-                f'constant_mem = {mem}, peak_mem = {max_peak_mem}, memory_budget = {memory_budget}'
-            )
-
-        solution = Solution(self.leaf_strategies, self.last_s_val, e_val,
-                            edge_pairs, self.node_index_dict,
-                            self.last_objective)
-        return status, solution
-
-    def find_solution(self):
-        """
-        Call the solver with serialized arguments and handle Python errors. Additionally,
-        it can produce a series of solutions with different memory budgets.
-        """
-        if self.solution_numbers == 1:
-            args = self._prepare_data_for_solver()
-            ret = self._call_solver_serialized_args(*args)
-
-            return ret
-
-        origin_memory_budget = self.memory_budget
-        memory_budget_list = [
-            origin_memory_budget * self.memory_increasing_coefficient**i
-            for i in range(self.solution_numbers)
-        ]
-        ret_list = []
-        for memory_budget in memory_budget_list:
-            self.memory_budget = memory_budget
-            args = self._prepare_data_for_solver()
-            ret = self._call_solver_serialized_args(*args)
-            ret_list.append(ret)
-
-        return ret_list
diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/__init__.py b/tensorrt_llm/auto_parallel/tensor_parallel/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/activation_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/activation_node.py
deleted file mode 100644
index a5fd513103..0000000000
--- a/tensorrt_llm/auto_parallel/tensor_parallel/activation_node.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import copy
-
-from .node import Node
-from .sharding_strategy import StrategiesVector
-
-
-class Activation(Node):
-
-    def _collect_strategies(self, device_mesh):
-        dim_partition_list = []
-        dim_size = len(self.op_data['input0'].shape)
-        dim_partition_list.append({})
-        dim_partition_list.extend(
-            self._enumerate_all_possible_1d_sharding([0], dim_size))
-        dim_partition_list.extend(
-            self._enumerate_all_possible_1d_sharding([1], dim_size))
-        dim_partition_list.extend(
-            self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
-        dim_partition_list.extend(
-            self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
-        strategies_vector = StrategiesVector(self)
-        for dim_partition_dict in dim_partition_list:
-            in0_partition_dict = dim_partition_dict
-            out_partition_dict = copy.deepcopy(dim_partition_dict)
-            dim_partition_dict_mapping = {
-                "input0": in0_partition_dict,
-                "output0": out_partition_dict,
-            }
-            sharding_spec_mapping = self._to_sharding_spec_mapping(
-                dim_partition_dict_mapping, device_mesh)
-            if 0 == len(sharding_spec_mapping):
-                continue
-            name = '{} = {}'.format(
-                sharding_spec_mapping['output0'].sharding_sequence,
-                sharding_spec_mapping['input0'].sharding_sequence)
-            sharding_strategy = self._get_sharding_strategy(
-                name=name,
-                sharding_spec_mapping=sharding_spec_mapping,
-
communication_action_mapping={}) - strategies_vector.append(sharding_strategy) - return strategies_vector diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/assertion_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/assertion_node.py deleted file mode 100644 index 73e100559b..0000000000 --- a/tensorrt_llm/auto_parallel/tensor_parallel/assertion_node.py +++ /dev/null @@ -1,34 +0,0 @@ -import copy - -from .node import Node -from .sharding_strategy import StrategiesVector - - -class Assertion(Node): - - def _collect_strategies(self, device_mesh): - predecessor = self.predecessor_nodes[0] # one input for softmax node - strategies_vector = StrategiesVector(self) - for idx, strategy in enumerate(predecessor.strategies_vector): - global_input_name = self.op_data[ - 'input0'].name # current node's local name input0 -> global name xxx - prenode_local_name = predecessor.global_to_local_op_name[ - global_input_name] # global name xxx -> pre node local output name - dim_partition_dict = copy.deepcopy( - strategy.sharding_specs[prenode_local_name].dim_partition_dict) - in0_partition_dict = dim_partition_dict - dim_partition_dict_mapping = { - "input0": in0_partition_dict, - } - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if 0 == len(sharding_spec_mapping): - return strategies_vector - name = ' {}'.format( - sharding_spec_mapping['input0'].sharding_sequence) - sharding_strategy = self._get_sharding_strategy( - name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) - strategies_vector.append(sharding_strategy) - return strategies_vector diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/cast_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/cast_node.py deleted file mode 100644 index b58083ea6d..0000000000 --- a/tensorrt_llm/auto_parallel/tensor_parallel/cast_node.py +++ /dev/null @@ -1,45 +0,0 @@ -import copy - -from .node import Node -from .sharding_strategy import StrategiesVector - - -class Cast(Node): - - def _collect_strategies(self, device_mesh): - dim_partition_list = [] - dim_size = len(self.op_data['input0'].shape) - dim_partition_list.append({}) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([0], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([1], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([0, 1], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_2d_sharding([0], [1], dim_size)) - strategies_vector = StrategiesVector(self) - for dim_partition_dict in dim_partition_list: - in0_partition_dict = dim_partition_dict - out_partition_dict = copy.deepcopy(dim_partition_dict) - dim_partition_dict_mapping = { - "input0": in0_partition_dict, - "output0": out_partition_dict, - } - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if 0 == len(sharding_spec_mapping): - continue - name = '{} = {}'.format( - sharding_spec_mapping['output0'].sharding_sequence, - sharding_spec_mapping['input0'].sharding_sequence) - sharding_strategy = self._get_sharding_strategy( - name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) - strategies_vector.append(sharding_strategy) - - return strategies_vector - - def _update_memory_cost(self, strategies): - pass diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/comm_spec.py b/tensorrt_llm/auto_parallel/tensor_parallel/comm_spec.py deleted file mode 100644 
index 3f164ee99d..0000000000 --- a/tensorrt_llm/auto_parallel/tensor_parallel/comm_spec.py +++ /dev/null @@ -1,58 +0,0 @@ -__all__ = [ - 'CommSpec', -] - - -class CommSpec: - - def __init__(self, - comm_pattern, - sharding_spec, - gather_dim=None, - shard_dim=None, - logical_process_axis=None, - mix_gather=False, - forward_only=True): - self.comm_pattern = comm_pattern - self.sharding_spec = sharding_spec - self.gather_dim = gather_dim - self.shard_dim = shard_dim - self.logical_process_axis = logical_process_axis - self.device_mesh = self.sharding_spec.device_mesh - self.mix_gather = mix_gather - self.forward_only = forward_only - if self.gather_dim: - assert len(self.gather_dim) == len( - self.logical_process_axis - ), f'unmatched gather dim {self.gather_dim} and logical process axis {self.logical_process_axis}' - if self.shard_dim: - assert len(self.shard_dim) == len( - self.logical_process_axis - ), f'unmatched shard dim {self.shard_dim} and logical process axis {self.logical_process_axis}' - if self.gather_dim and self.shard_dim: - assert len(self.shard_dim) == len( - self.gather_dim - ), f'unmatched gather dim {self.gather_dim} and shard dim {self.shard_dim}' - - def get_comm_cost(self): - ''' - For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to - compute the communication cost. - For shard operation, it is an on-chip operation, so the communication cost is zero. - ''' - comm_size = self.sharding_spec.get_sharded_size_per_device() - dtype = self.sharding_spec.dtype - - # reduce list_of_list to list - comm_dims = sum(self.logical_process_axis, []) - comm_cost = self.device_mesh.estimate_comm_cost(self.comm_pattern, - comm_dims, comm_size, - dtype) - return comm_cost - - def get_mem_cost(self): - return self.device_mesh.shape_consistency_manager.mem_cost([self]) - - def get_max_mem_cost(self): - return self.device_mesh.shape_consistency_manager.mem_cost( - [self], mem_pattern='max') diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/concatenation_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/concatenation_node.py deleted file mode 100644 index a9225fd40f..0000000000 --- a/tensorrt_llm/auto_parallel/tensor_parallel/concatenation_node.py +++ /dev/null @@ -1,56 +0,0 @@ -import copy - -from .node import Node -from .sharding_strategy import StrategiesVector - - -class Concatenation(Node): - - def __init__(self, layer): - super().__init__(layer) - layer.to_subclass() - batch_dims = [i for i in range(len(self.get_output(0).shape))] - self.axis = layer.as_trt().axis - batch_dims.remove(self.axis) - self._generate_bcast_dims(batch_dims, self.get_output(0).shape) - layer.to_base_class() - - def _collect_strategies(self, device_mesh): - dim_partition_list = [] - dim_size = len(self.op_data['output0'].shape) - dim_partition_list.append({}) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([0], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([1], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([0, 1], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_2d_sharding([0], [1], dim_size)) - - dim_partition_dict_mapping = {} - strategies_vector = StrategiesVector(self) - for dim_partition_dict in dim_partition_list: - if self.axis in dim_partition_dict: - dim_partition_dict.pop(self.axis) - for idx in range(self.num_inputs): - in_partition_dict = copy.deepcopy(dim_partition_dict) - 
dim_partition_dict_mapping[f'input{idx}'] = in_partition_dict - out_partition_dict = dim_partition_dict - dim_partition_dict_mapping['output0'] = out_partition_dict - - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if 0 == len(sharding_spec_mapping): - continue - name = '{} = {}'.format( - sharding_spec_mapping['output0'].sharding_sequence, self.axis, [ - sharding_spec_mapping[f'input{idx}'].sharding_sequence - for idx in range(self.num_inputs) - ]) - sharding_strategy = self._get_sharding_strategy( - name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) - strategies_vector.append(sharding_strategy) - return strategies_vector diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/constant_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/constant_node.py deleted file mode 100644 index ed43331fed..0000000000 --- a/tensorrt_llm/auto_parallel/tensor_parallel/constant_node.py +++ /dev/null @@ -1,45 +0,0 @@ -from .node import Node -from .sharding_strategy import StrategiesVector - - -class Constant(Node): - - def _update_memory_cost(self, strategies): - super()._update_memory_cost(strategies) - for strategy in strategies: - strategy.inout_memory_footprint = 0.0 - strategy.peak_memory_footprint = 0.0 - strategy.const_memory_footprint = strategy.sharding_specs[ - 'output0'].get_max_sharded_size_per_device() - - def _collect_strategies(self, device_mesh): - dim_partition_list = [] - dim_size = len(self.op_data['output0'].shape) - dim_partition_list.append({}) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([0, 1], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_2d_sharding([0], [1], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([0], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([1], dim_size)) - - strategies_vector = StrategiesVector(self) - for dim_partition_dict in dim_partition_list: - dim_partition_dict_mapping = {'output0': dim_partition_dict} - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if 0 == len(sharding_spec_mapping): - continue - sharding_seq = sharding_spec_mapping['output0'].sharding_sequence - sharding_strategy = self._get_sharding_strategy( - name=f'constant-op {sharding_seq}', - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) - strategies_vector.append(sharding_strategy) - - return strategies_vector - - def _profile_sharding_cost(self, strategy, device_mesh): - return 0.0 diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/elementwise_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/elementwise_node.py deleted file mode 100644 index 6c2af5aaa7..0000000000 --- a/tensorrt_llm/auto_parallel/tensor_parallel/elementwise_node.py +++ /dev/null @@ -1,49 +0,0 @@ -from .node import Node -from .sharding_strategy import StrategiesVector - - -class ElementWise(Node): - - def __init__(self, layer): - super().__init__(layer) - batch_dims = [i for i in range(len(self.get_output(0).shape))] - self._generate_bcast_dims(batch_dims, self.get_output(0).shape) - - def _collect_strategies(self, device_mesh): - dim_partition_list = [] - dim_size = len(self.op_data['output0'].shape) - dim_partition_list.append({}) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([0], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([1], dim_size)) - 
dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([0, 1], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_2d_sharding([0], [1], dim_size)) - strategies_vector = StrategiesVector(self) - for dim_partition_dict in dim_partition_list: - in0_partition_dict = self._recover_bcast_partition_dict( - dim_partition_dict, self.op_data['input0']) - in1_partition_dict = self._recover_bcast_partition_dict( - dim_partition_dict, self.op_data['input1']) - out_partition_dict = dim_partition_dict - dim_partition_dict_mapping = { - "input0": in0_partition_dict, - "input1": in1_partition_dict, - "output0": out_partition_dict, - } - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if 0 == len(sharding_spec_mapping): - continue - name = '{} = {} {}'.format( - sharding_spec_mapping['output0'].sharding_sequence, - sharding_spec_mapping['input0'].sharding_sequence, - sharding_spec_mapping['input1'].sharding_sequence) - sharding_strategy = self._get_sharding_strategy( - name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) - strategies_vector.append(sharding_strategy) - return strategies_vector diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/fill_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/fill_node.py deleted file mode 100644 index 3be5f79acd..0000000000 --- a/tensorrt_llm/auto_parallel/tensor_parallel/fill_node.py +++ /dev/null @@ -1,59 +0,0 @@ -import tensorrt as trt - -from .node import Node -from .sharding_strategy import StrategiesVector - - -class Fill(Node): - - def __init__(self, layer): - super().__init__(layer) - layer.to_subclass() - self.operation = layer.as_trt().operation - layer.to_base_class() - - def _collect_strategies(self, device_mesh): - dim_partition_list = [] - dim_size = len(self.op_data['output0'].shape) - dim_partition_list.append({}) - if self.num_inputs == 0 and self.operation != trt.FillOperation.LINSPACE: - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([0], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([1], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([0, 1], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_2d_sharding([0], [1], dim_size)) - - strategies_vector = StrategiesVector(self) - for dim_partition_dict in dim_partition_list: - dim_partition_dict_mapping = {'output0': dim_partition_dict} - for i in range(self.num_inputs): - dim_partition_dict_mapping[f'input{i}'] = {} - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if 0 == len(sharding_spec_mapping): - continue - sharding_seq = sharding_spec_mapping['output0'].sharding_sequence - sharding_strategy = self._get_sharding_strategy( - name=f'fill-op {sharding_seq}', - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) - strategies_vector.append(sharding_strategy) - - return strategies_vector - - def _profile_sharding_cost(self, strategy, device_mesh): - updated_layer_attrs = {} - updated_input_values = {} - shape = strategy.sharding_specs['output0'].get_sharded_shape_per_device( - ) - if self.layer.num_inputs >= 1: - updated_input_values[0] = shape - else: - updated_layer_attrs['shape'] = shape - elapsed_time = self.node_runtime_profiler.runtime_profile( - self.layer, updated_layer_attrs, updated_input_values, strategy, - device_mesh) - return elapsed_time diff --git 
a/tensorrt_llm/auto_parallel/tensor_parallel/gather_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/gather_node.py
deleted file mode 100644
index 8b9a6f0885..0000000000
--- a/tensorrt_llm/auto_parallel/tensor_parallel/gather_node.py
+++ /dev/null
@@ -1,196 +0,0 @@
-import copy
-
-import tensorrt as trt
-
-from .comm_spec import CommSpec
-from .node import Node
-from .sharding_spec import DimSpec
-from .sharding_strategy import StrategiesVector
-
-
-class Gather(Node):
-
-    def __init__(self, layer):
-        super().__init__(layer)
-        layer.to_subclass()
-        self.mode = layer.as_trt().mode
-        self.axis = layer.as_trt().axis
-        self.num_elementwise_dims = layer.as_trt().num_elementwise_dims
-        self.input_id = 0
-        self.indice_id = 1
-        self.support_vocab_tp = False
-        layer.to_base_class()
-
-    def _update_memory_cost(self, strategies):
-        for strategy in strategies:
-            # for the gather node, input0's read size equals output0's write size
-            inout_memory_footprint = (
-                strategy.sharding_specs['output0'].get_sharded_size_per_device(
-                ) * 2 +
-                strategy.sharding_specs['input1'].get_sharded_size_per_device())
-            strategy.inout_memory_footprint = inout_memory_footprint
-            strategy.peak_memory_footprint = (
-                strategy.sharding_specs['output0'].
-                get_max_sharded_size_per_device() + strategy.
-                sharding_specs['input0'].get_max_sharded_size_per_device() +
-                strategy.sharding_specs['input1'].
-                get_max_sharded_size_per_device())
-
-    def _collect_strategies(self, device_mesh):
-        if self.mode == trt.GatherMode.DEFAULT:
-            return self._default_gather_strategies(device_mesh)
-        elif self.mode == trt.GatherMode.ELEMENT:
-            return self._element_gather_strategies(device_mesh)
-        elif self.mode == trt.GatherMode.ND:
-            assert 0, 'unsupported GatherND mode'
-        else:
-            assert 0, f'unsupported gather mode {self.mode}'
-
-    def _element_gather_strategies(self, device_mesh):
-        dim_partition_list = []
-        dim_size = len(self.op_data['output0'].shape)
-        dim_partition_list.append({})
-        dim_partition_list.extend(
-            self._enumerate_all_possible_1d_sharding([0], dim_size))
-        dim_partition_list.extend(
-            self._enumerate_all_possible_1d_sharding([1], dim_size))
-        dim_partition_list.extend(
-            self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
-        dim_partition_list.extend(
-            self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
-
-        strategies_vector = StrategiesVector(self)
-        for dim_partition_dict in dim_partition_list:
-            if self.axis in dim_partition_dict:
-                dim_partition_dict.pop(self.axis)
-
-            dim_partition_dict_mapping = {
-                'input0': dim_partition_dict,
-                'input1': copy.deepcopy(dim_partition_dict),
-                'output0': copy.deepcopy(dim_partition_dict),
-            }
-
-            sharding_spec_mapping = self._to_sharding_spec_mapping(
-                dim_partition_dict_mapping, device_mesh)
-            if 0 == len(sharding_spec_mapping):
-                continue
-            name = '{} = {} {}'.format(
-                sharding_spec_mapping['output0'].sharding_sequence,
-                sharding_spec_mapping['input0'].sharding_sequence, self.axis,
-                sharding_spec_mapping['input1'].sharding_sequence)
-            sharding_strategy = self._get_sharding_strategy(
-                name=name,
-                sharding_spec_mapping=sharding_spec_mapping,
-                communication_action_mapping={})
-            strategies_vector.append(sharding_strategy)
-
-        return strategies_vector
-
-    # for the plugin, indices are input0 and weight is input1, which differs from the gather node
-    def _default_gather_strategies(self, device_mesh):
-
-        def add_sharding_strategy(dim_partition_dict_mapping,
-                                  vocab_tp_dim=None):
-            sharding_spec_mapping = self._to_sharding_spec_mapping(
-                dim_partition_dict_mapping, device_mesh)
-            if
len(sharding_spec_mapping) > 0: - name = '{} = {} {}'.format( - sharding_spec_mapping['output0'].sharding_sequence, - sharding_spec_mapping['input0'].sharding_sequence, - self.axis, self.num_elementwise_dims, - sharding_spec_mapping['input1'].sharding_sequence) - communication_action_mapping = {} - if vocab_tp_dim is not None: - name += f'_allreduce{DimSpec(vocab_tp_dim)}' - output0_comm_action = CommSpec( - comm_pattern='all_reduce', - sharding_spec=sharding_spec_mapping['output0'], - logical_process_axis=[vocab_tp_dim], - ) - communication_action_mapping[ - 'output0'] = output0_comm_action - sharding_strategy = self._get_sharding_strategy( - name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) - strategies_vector.append(sharding_strategy) - - input_id, indice_id = self.input_id, self.indice_id - strategies_vector = StrategiesVector(self) - input_size = len(self.op_data[f'input{input_id}'].shape) - indice_size = len(self.op_data[f'input{indice_id}'].shape) - output_dim = input_size + indice_size - 1 - self.num_elementwise_dims - for strategy in self.predecessor_nodes[input_id].strategies_vector: - # current node's local name input0 -> global name xxx - global_input_name = self.op_data[f'input{input_id}'].name - # global name xxx -> pre node local output name - prenode_local_name = self.predecessor_nodes[ - input_id].global_to_local_op_name[global_input_name] - input_dim_partition_dict = copy.deepcopy( - strategy.sharding_specs[prenode_local_name].dim_partition_dict) - - vocab_tp_dim = input_dim_partition_dict.pop(self.axis, None) - - input_mesh_dims = [] - for dim, mesh_dims in input_dim_partition_dict.items(): - input_mesh_dims += mesh_dims - input_mesh_dims = set(input_mesh_dims) - - for idx_strategy in self.predecessor_nodes[ - indice_id].strategies_vector: - # current node's local name input0 -> global name xxx - global_indice_name = self.op_data[f'input{indice_id}'].name - # global name xxx -> pre node local output name - prenode_local_name = self.predecessor_nodes[ - indice_id].global_to_local_op_name[global_indice_name] - indice_dim_partition_dict = copy.deepcopy( - idx_strategy.sharding_specs[prenode_local_name]. 
- dim_partition_dict) - - for dim, indice_mesh_dims in idx_strategy.sharding_specs[ - prenode_local_name].dim_partition_dict.items(): - for indice_mesh_dim in indice_mesh_dims: - if indice_mesh_dim in input_mesh_dims: - indice_dim_partition_dict.pop(dim) - break - - out_partition_dict = {} - - for dim in range(output_dim): - if dim < self.axis: - if dim in input_dim_partition_dict: - out_partition_dict[dim] = \ - input_dim_partition_dict[dim] - elif dim >= self.axis and dim < self.axis + indice_size - self.num_elementwise_dims: - indice_dim = dim - self.axis + self.num_elementwise_dims - if indice_dim in indice_dim_partition_dict: - out_partition_dict[dim] = \ - indice_dim_partition_dict[indice_dim] - else: - input_dim = dim - (indice_size - - self.num_elementwise_dims) + 1 - if input_dim in input_dim_partition_dict: - out_partition_dict[dim] = \ - input_dim_partition_dict[input_dim] - - dim_partition_dict_mapping = { - f"input{input_id}": input_dim_partition_dict, - f"input{indice_id}": indice_dim_partition_dict, - "output0": out_partition_dict, - } - add_sharding_strategy(dim_partition_dict_mapping) - - if self.support_vocab_tp and vocab_tp_dim is not None: - vocab_tp_dim_partition_dict = { - **input_dim_partition_dict, - self.axis: vocab_tp_dim, - } - dim_partition_dict_mapping = { - f"input{input_id}": vocab_tp_dim_partition_dict, - f"input{indice_id}": indice_dim_partition_dict, - "output0": out_partition_dict, - } - add_sharding_strategy(dim_partition_dict_mapping, - vocab_tp_dim) - - return strategies_vector diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/identity_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/identity_node.py deleted file mode 100644 index 470978b8ec..0000000000 --- a/tensorrt_llm/auto_parallel/tensor_parallel/identity_node.py +++ /dev/null @@ -1,56 +0,0 @@ -import copy - -from .node import Node -from .sharding_strategy import StrategiesVector - - -class Identity(Node): - - def _update_memory_cost(self, strategies): - if not self.is_fake: - super()._update_memory_cost(strategies) - else: - # fake nodes for building block/PP connection - pass - - def _collect_strategies(self, device_mesh): - dim_partition_list = [] - dim_size = len(self.op_data['input0'].shape) - dim_partition_list.append({}) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([0], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([1], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([0, 1], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_2d_sharding([0], [1], dim_size)) - strategies_vector = StrategiesVector(self) - # dim_partition_dict can be the same as previous node if solver's time is a problem - for dim_partition_dict in dim_partition_list: - in0_partition_dict = dim_partition_dict - out_partition_dict = copy.deepcopy(dim_partition_dict) - dim_partition_dict_mapping = { - "input0": in0_partition_dict, - "output0": out_partition_dict, - } - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if 0 == len(sharding_spec_mapping): - continue - name = '{} = {}'.format( - sharding_spec_mapping['output0'].sharding_sequence, - sharding_spec_mapping['input0'].sharding_sequence) - sharding_strategy = self._get_sharding_strategy( - name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) - strategies_vector.append(sharding_strategy) - return strategies_vector - - def _profile_sharding_cost(self, 
strategy, device_mesh): - # if same spec id is not 0, identify node is used as same spec id node - if self.same_spec_id == -1: - return super()._profile_sharding_cost(strategy, device_mesh) - else: - return 0.0 diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/input_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/input_node.py deleted file mode 100644 index f8e24cd499..0000000000 --- a/tensorrt_llm/auto_parallel/tensor_parallel/input_node.py +++ /dev/null @@ -1,79 +0,0 @@ -from .node import Node -from .sharding_strategy import StrategiesVector - - -class InputNode(Node): - - def _update_memory_cost(self, strategies): - for strategy in strategies: - if not self.no_memory_footprint: - strategy.const_memory_footprint = strategy.sharding_specs[ - 'output0'].get_max_sharded_size_per_device() - - def __init__(self, tensor): - self._layer = None - self.is_shape_io = False - self._inputs = [] - self._outputs = [] - self.predecessor_nodes = [] - self.predecessor_nodes_out_index = {} - self.successor_nodes = [] - self.op_data = {} - self.global_to_local_op_name = {} - self.is_replicated = tensor.attrs.get("is_replicated", False) - self.same_spec_id = tensor.attrs.get("same_spec_id", -1) - self.no_memory_footprint = tensor.attrs.get("no_memory_footprint", - False) - self.building_block_id = -1 - self.cost_level = -1 - self.stage_type = None - self.in_start_block = None - self.in_end_block = None - self.in_slowest_block = None - output = tensor.copy() - self._outputs.append(output) - self.op_data['output0'] = output - self.global_to_local_op_name[output.name] = 'output0' - - self.sharding_weight = 1.0 - self.resharding_weight = 1.0 - self.pipeline_weight = 0 - self.node_name = tensor.name - self.node_type = 'input_node' - self.num_inputs = 0 - self.num_outputs = 1 - self.dtype = tensor.dtype - self.strategies_vector = [] - self.node_runtime_profiler = None - - def _collect_strategies(self, device_mesh): - dim_partition_list = [] - dim_size = len(self.op_data['output0'].shape) - dim_partition_list.append({}) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([0], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([1], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([0, 1], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_2d_sharding([0], [1], dim_size)) - - strategies_vector = StrategiesVector(self) - for dim_partition_dict in dim_partition_list: - dim_partition_dict_mapping = {'output0': dim_partition_dict} - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if 0 == len(sharding_spec_mapping): - continue - sharding_seq = sharding_spec_mapping['output0'].sharding_sequence - sharding_strategy = self._get_sharding_strategy( - name=f'input-op {sharding_seq}', - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) - strategies_vector.append(sharding_strategy) - - return strategies_vector - - def _profile_sharding_cost(self, strategy, device_mesh): - return 0.0 diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/matmul_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/matmul_node.py deleted file mode 100644 index cf45e55340..0000000000 --- a/tensorrt_llm/auto_parallel/tensor_parallel/matmul_node.py +++ /dev/null @@ -1,798 +0,0 @@ -import copy -import operator -from functools import reduce - -import tensorrt as trt - -from ..device_mesh import LogicalDeviceMesh -from ..utils import get_builder_flags -from 
.comm_spec import CommSpec -from .node import Node -from .sharding_spec import DimSpec -from .sharding_strategy import StrategiesVector - - -class MatrixMultiply(Node): - - def __init__(self, layer): - super().__init__(layer) - layer.to_subclass() - batch_dims = [i for i in range(len(self.get_output(0).shape))][:-2] - self._generate_bcast_dims(batch_dims, self.get_output(0).shape) - self.op0_transpose = layer.as_trt().op0 == trt.MatrixOperation.TRANSPOSE - self.op1_transpose = layer.as_trt().op1 == trt.MatrixOperation.TRANSPOSE - self.num_out_dims = len(self.get_output(0).shape) - dtypes_str = [ - self.get_input(0).dtype_str, - self.get_input(1).dtype_str, - self.get_output(0).dtype_str - ] - dtypes_size = [ - self.get_input(0).dtype_size, - self.get_input(1).dtype_size, - self.get_output(0).dtype_size - ] - min_idx = dtypes_size.index(min(dtypes_size)) - self.dtype = dtypes_str[min_idx] - layer.to_base_class() - - def _split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1, device_mesh): - in0_split_dim = -1 if self.op0_transpose else -2 - in1_split_dim = -2 if self.op1_transpose else -1 - name = (f'{DimSpec(mesh_dim_0)}{DimSpec(mesh_dim_1)} = ' - f'{DimSpec(mesh_dim_0)}R x R{DimSpec(mesh_dim_1)}') - dim_partition_dict_mapping = { - "input0": { - in0_split_dim: mesh_dim_0 - }, - "input1": { - in1_split_dim: mesh_dim_1 - }, - "output0": { - -2: mesh_dim_0, - -1: mesh_dim_1 - }, - } - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if len(sharding_spec_mapping) == 0: - return None - strategy = self._get_sharding_strategy(name = name, \ - sharding_spec_mapping = sharding_spec_mapping, \ - communication_action_mapping = {}) - return strategy - - def _split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1, - device_mesh): - # handle the case SR = SS x SR - name = ( - f'{DimSpec(mesh_dim_0)}R = ' - f'{DimSpec(mesh_dim_0)}{DimSpec(mesh_dim_1)} x {DimSpec(mesh_dim_1)}R' - f'_allreduce{DimSpec(mesh_dim_1)}') - in0_split_dim = [-1, -2] if self.op0_transpose else [-2, -1] - in1_split_dim = -1 if self.op1_transpose else -2 - # get sharding spec mapping - dim_partition_dict_mapping = { - "input0": { - in0_split_dim[0]: mesh_dim_0, - in0_split_dim[1]: mesh_dim_1 - }, - "input1": { - in1_split_dim: mesh_dim_1 - }, - "output0": { - -2: mesh_dim_0 - }, - } - - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if len(sharding_spec_mapping) == 0: - return None - # get communication action mapping - communication_action_mapping = {} - output0_comm_action = CommSpec( - comm_pattern='all_reduce', - sharding_spec=sharding_spec_mapping['output0'], - logical_process_axis=[mesh_dim_1], - ) - communication_action_mapping['output0'] = output0_comm_action - return self._get_sharding_strategy( - name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) - - def _split_both_contract_rs(self, name, rs_dim, rs_mesh_dim, src_spec, - dim_partition_dict_mapping, device_mesh): - output0_comm_action = CommSpec( - comm_pattern='reduce_scatter', - sharding_spec=src_spec, - shard_dim=[rs_dim], - logical_process_axis=[rs_mesh_dim], - ) - rs_out_partition_dict_mapping = copy.deepcopy( - dim_partition_dict_mapping) - rs_out_partition_dict_mapping["output0"][rs_dim] = rs_mesh_dim - rs_out_sharding_spec_mapping = self._to_sharding_spec_mapping( - rs_out_partition_dict_mapping, device_mesh) - if len(rs_out_sharding_spec_mapping) == 0: - return None - - 
communication_action_mapping = {} - communication_action_mapping['output0'] = output0_comm_action - return self._get_sharding_strategy( - name=name, - sharding_spec_mapping=rs_out_sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) - - def _split_lhs_space_both_contract_rs(self, mesh_dim_0, mesh_dim_1, - device_mesh): - # handle the case SS = SS x SR -> reduce_scatter - in0_split_dim = [-1, -2] if self.op0_transpose else [-2, -1] - in1_split_dim = -1 if self.op1_transpose else -2 - # get sharding spec mapping - dim_partition_dict_mapping = { - "input0": { - in0_split_dim[0]: mesh_dim_0, - in0_split_dim[1]: mesh_dim_1 - }, - "input1": { - in1_split_dim: mesh_dim_1 - }, - "output0": { - -2: mesh_dim_0, - }, - } - mm_out_sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if len(mm_out_sharding_spec_mapping) == 0: - return [] - strategies = [] - for rs_dim in range(self.num_out_dims): - if rs_dim != self.num_out_dims - 2: - name_in0, name_in1, name_out0 = ['R'] * self.num_out_dims, [ - 'R' - ] * self.num_out_dims, ['R'] * self.num_out_dims - name_in0[-2], name_in0[-1] = str(DimSpec(mesh_dim_0)), str( - DimSpec(mesh_dim_1)) - name_in1[-2] = str(DimSpec(mesh_dim_1)) - name_out0[-2], name_out0[rs_dim] = str( - DimSpec(mesh_dim_0)), str(DimSpec(mesh_dim_1)) - name_in0, name_in1, name_out0 = ', '.join(name_in0), ', '.join( - name_in1), ', '.join(name_out0) - name = (f'[{name_out0}] = [{name_in0}] x [{name_in1}]' - f'_reducescatter{(rs_dim, DimSpec(mesh_dim_1))}') - ret = self._split_both_contract_rs( - name, rs_dim, mesh_dim_1, - mm_out_sharding_spec_mapping['output0'], - dim_partition_dict_mapping, device_mesh) - if ret: - strategies.append(ret) - return strategies - - def _split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1, - device_mesh): - name = ( - f'R{DimSpec(mesh_dim_1)} = ' - f'R{DimSpec(mesh_dim_0)} x {DimSpec(mesh_dim_0)}{DimSpec(mesh_dim_1)}' - f'_allreduce{DimSpec(mesh_dim_0)}') - in0_split_dim = -2 if self.op0_transpose else -1 - in1_split_dim = [-1, -2] if self.op1_transpose else [-2, -1] - # get sharding specs - dim_partition_dict_mapping = { - "input0": { - in0_split_dim: mesh_dim_0 - }, - "input1": { - in1_split_dim[0]: mesh_dim_0, - in1_split_dim[1]: mesh_dim_1 - }, - "output0": { - -1: mesh_dim_1 - }, - } - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if len(sharding_spec_mapping) == 0: - return None - # get communication actions - communication_action_mapping = {} - output0_comm_action = CommSpec( - comm_pattern='all_reduce', - sharding_spec=sharding_spec_mapping['output0'], - logical_process_axis=[mesh_dim_0], - ) - communication_action_mapping['output0'] = output0_comm_action - return self._get_sharding_strategy( - name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) - - def _split_rhs_space_both_contract_rs(self, mesh_dim_0, mesh_dim_1, - device_mesh): - in0_split_dim = -2 if self.op0_transpose else -1 - in1_split_dim = [-1, -2] if self.op1_transpose else [-2, -1] - # get sharding specs - dim_partition_dict_mapping = { - "input0": { - in0_split_dim: mesh_dim_0 - }, - "input1": { - in1_split_dim[0]: mesh_dim_0, - in1_split_dim[1]: mesh_dim_1 - }, - "output0": { - -1: mesh_dim_1 - }, - } - mm_out_sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if len(mm_out_sharding_spec_mapping) == 0: - return [] - strategies = [] - 
for rs_dim in range(self.num_out_dims): - if rs_dim != self.num_out_dims - 1: - name_in0, name_in1, name_out0 = ['R'] * self.num_out_dims, [ - 'R' - ] * self.num_out_dims, ['R'] * self.num_out_dims - name_in1[-2], name_in1[-1] = str(DimSpec(mesh_dim_0)), str( - DimSpec(mesh_dim_1)) - name_in0[-1] = str(DimSpec(mesh_dim_0)) - name_out0[-1], name_out0[rs_dim] = str( - DimSpec(mesh_dim_1)), str(DimSpec(mesh_dim_0)) - name_in0, name_in1, name_out0 = ', '.join(name_in0), ', '.join( - name_in1), ', '.join(name_out0) - name = (f'[{name_out0}] = [{name_in0}] x [{name_in1}]' - f'_reducescatter{(rs_dim, DimSpec(mesh_dim_0))}') - ret = self._split_both_contract_rs( - name, rs_dim, mesh_dim_0, - mm_out_sharding_spec_mapping['output0'], - dim_partition_dict_mapping, device_mesh) - if ret: - strategies.append(ret) - return strategies - - def _recompute_split_both_contract(self, mesh_dim, device_mesh): - name = (f'RR = R{DimSpec(mesh_dim)} x {DimSpec(mesh_dim)}R' - f'_allreduce{DimSpec(mesh_dim)}') - in0_split_dim = -2 if self.op0_transpose else -1 - in1_split_dim = -1 if self.op1_transpose else -2 - dim_partition_dict_mapping = { - "input0": { - in0_split_dim: mesh_dim - }, - "input1": { - in1_split_dim: mesh_dim - }, - "output0": {}, - } - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if len(sharding_spec_mapping) == 0: - return None - - # get communication action - communication_action_mapping = {} - output0_comm_action = CommSpec( - comm_pattern='all_reduce', - sharding_spec=sharding_spec_mapping['output0'], - logical_process_axis=[mesh_dim], - ) - communication_action_mapping['output0'] = output0_comm_action - return self._get_sharding_strategy( - name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) - - def _recompute_split_both_contract_rs(self, mesh_dim, device_mesh): - name = (f'{DimSpec(mesh_dim)}R = ' - f'R{DimSpec(mesh_dim)} x {DimSpec(mesh_dim)}R' - f'_reducescatter0_{DimSpec(mesh_dim)}') - in0_split_dim = -2 if self.op0_transpose else -1 - in1_split_dim = -1 if self.op1_transpose else -2 - dim_partition_dict_mapping = { - "input0": { - in0_split_dim: mesh_dim - }, - "input1": { - in1_split_dim: mesh_dim - }, - "output0": {}, - } - mm_out_sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if len(mm_out_sharding_spec_mapping) == 0: - return [] - - strategies = [] - for rs_dim in range(self.num_out_dims): - name_in0, name_in1, name_out0 = ['R'] * self.num_out_dims, [ - 'R' - ] * self.num_out_dims, ['R'] * self.num_out_dims - name_in0[-1], name_in1[-2], name_out0[rs_dim] = str( - DimSpec(mesh_dim)), str(DimSpec(mesh_dim)), str( - DimSpec(mesh_dim)) - name_in0, name_in1, name_out0 = ', '.join(name_in0), ', '.join( - name_in1), ', '.join(name_out0) - name = f'[{name_out0}] = [{name_in0}] x [{name_in1}]_reducescatter{(rs_dim, DimSpec(mesh_dim))}' - ret = self._split_both_contract_rs( - name, rs_dim, mesh_dim, mm_out_sharding_spec_mapping['output0'], - dim_partition_dict_mapping, device_mesh) - if ret: - strategies.append(ret) - return strategies - - def _split_rhs_space_only(self, mesh_dim, device_mesh): - name = f'R{DimSpec(mesh_dim)} = RR x R{DimSpec(mesh_dim)}' - in1_split_dim = -2 if self.op1_transpose else -1 - # get sharding spec - dim_partition_dict_mapping = { - "input0": {}, - "input1": { - in1_split_dim: mesh_dim - }, - "output0": { - -1: mesh_dim - }, - } - # We don't have to do anything special for bias here, because - 
# the bias is already the same sharding spec as the output0. - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if len(sharding_spec_mapping) == 0: - return None - return self._get_sharding_strategy( - name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) - - def _split_lhs_space_only(self, mesh_dim, device_mesh): - name = f'{DimSpec(mesh_dim)}R = {DimSpec(mesh_dim)}R x RR' - in0_split_dim = -1 if self.op0_transpose else -2 - # get sharding spec - dim_partition_dict_mapping = { - "input0": { - in0_split_dim: mesh_dim - }, - "input1": {}, - "output0": { - -2: mesh_dim - }, - } - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if len(sharding_spec_mapping) == 0: - return None - return self._get_sharding_strategy( - name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) - - def _non_split(self, device_mesh): - name = 'RR = RR x RR' - # get sharding spec - dim_partition_dict_mapping = { - "input0": {}, - "input1": {}, - "output0": {}, - } - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if len(sharding_spec_mapping) == 0: - return None - return self._get_sharding_strategy( - name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) - - def _split_one_batch_dim(self, batch_dim, mesh_dim, device_mesh): - name = ( - f'{DimSpec(mesh_dim)}b{batch_dim}RR = ' - f'{DimSpec(mesh_dim)}b{batch_dim}RR x {DimSpec(mesh_dim)}b{batch_dim}RR' - ) - in0_data = self.op_data['input0'] - in1_data = self.op_data['input1'] - - batch_partition_dict = {batch_dim: mesh_dim} - in0_parition_dict = self._recover_bcast_partition_dict( - batch_partition_dict, in0_data) - in1_parition_dict = self._recover_bcast_partition_dict( - batch_partition_dict, in1_data) - out_partition_dict = {batch_dim: mesh_dim} - # TODO: Double check if MatrixMultiplication's output has bcast in dim - dim_partition_dict_mapping = { - "input0": in0_parition_dict, - "input1": in1_parition_dict, - "output0": out_partition_dict, - } - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if len(sharding_spec_mapping) == 0: - return None - return self._get_sharding_strategy( - name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) - - def _split_two_batch_dims(self, batch_dim0, batch_dim1, mesh_dim0, - mesh_dim1, device_mesh): - name = ( - f'{DimSpec(mesh_dim0)}b{batch_dim0}{DimSpec(mesh_dim1)}b{batch_dim1}RR = ' - f'{DimSpec(mesh_dim0)}b{batch_dim0}RR x {DimSpec(mesh_dim1)}b{batch_dim1}RR' - ) - in0_data = self.op_data['input0'] - in1_data = self.op_data['input1'] - - in0_parition_dict = {} - if batch_dim0 not in in0_data.attrs["broadcast_dims"]: - in0_parition_dict[batch_dim0] = mesh_dim0 - if batch_dim1 not in in0_data.attrs["broadcast_dims"]: - in0_parition_dict[batch_dim1] = mesh_dim1 - - in1_parition_dict = {} - if batch_dim0 not in in1_data.attrs["broadcast_dims"]: - in1_parition_dict[batch_dim0] = mesh_dim0 - if batch_dim1 not in in1_data.attrs["broadcast_dims"]: - in1_parition_dict[batch_dim1] = mesh_dim1 - - batch_partition_dict = {batch_dim0: mesh_dim0, batch_dim1: mesh_dim1} - in0_parition_dict = self._recover_bcast_partition_dict( - batch_partition_dict, in0_data) - in1_parition_dict = self._recover_bcast_partition_dict( - batch_partition_dict, in1_data) - out_partition_dict = {batch_dim0: 
mesh_dim0, batch_dim1: mesh_dim1} - dim_partition_dict_mapping = { - "input0": in0_parition_dict, - "input1": in1_parition_dict, - "output0": out_partition_dict, - } - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if len(sharding_spec_mapping) == 0: - return None - return self._get_sharding_strategy( - name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) - - def _split_batch_dim_lhs_space(self, batch_dim, mesh_dim0, mesh_dim1, - device_mesh): - - name = ( - f'{DimSpec(mesh_dim0)}b{batch_dim}{DimSpec(mesh_dim1)}R = ' - f'{DimSpec(mesh_dim0)}b{batch_dim}{DimSpec(mesh_dim1)}R x {DimSpec(mesh_dim0)}b{batch_dim}RR' - ) - in0_data = self.op_data['input0'] - in1_data = self.op_data['input1'] - in0_parition_dict = {batch_dim: mesh_dim0} - in1_parition_dict = {batch_dim: mesh_dim0} - in0_lhs_split_dim = -1 if self.op0_transpose else -2 - in0_parition_dict[in0_lhs_split_dim] = mesh_dim1 - - in0_parition_dict = self._recover_bcast_partition_dict( - in0_parition_dict, in0_data) - in1_parition_dict = self._recover_bcast_partition_dict( - in1_parition_dict, in1_data) - out_partition_dict = {batch_dim: mesh_dim0, -2: mesh_dim1} - - dim_partition_dict_mapping = { - "input0": in0_parition_dict, - "input1": in1_parition_dict, - "output0": out_partition_dict, - } - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if len(sharding_spec_mapping) == 0: - return None - return self._get_sharding_strategy( - name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) - - def _split_batch_dim_rhs_space(self, batch_dim, mesh_dim0, mesh_dim1, - device_mesh): - - name = ( - f'{DimSpec(mesh_dim0)}b{batch_dim}R{DimSpec(mesh_dim1)} = ' - f'{DimSpec(mesh_dim0)}b{batch_dim}RR x {DimSpec(mesh_dim0)}b{batch_dim}R{DimSpec(mesh_dim1)}' - ) - in0_data = self.op_data['input0'] - in1_data = self.op_data['input1'] - in0_parition_dict = {batch_dim: mesh_dim0} - in1_parition_dict = {batch_dim: mesh_dim0} - - in1_rhs_split_dim = -2 if self.op1_transpose else -1 - in1_parition_dict[in1_rhs_split_dim] = mesh_dim1 - - in0_parition_dict = self._recover_bcast_partition_dict( - in0_parition_dict, in0_data) - in1_parition_dict = self._recover_bcast_partition_dict( - in1_parition_dict, in1_data) - out_partition_dict = {batch_dim: mesh_dim0, -1: mesh_dim1} - dim_partition_dict_mapping = { - "input0": in0_parition_dict, - "input1": in1_parition_dict, - "output0": out_partition_dict, - } - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if len(sharding_spec_mapping) == 0: - return None - return self._get_sharding_strategy( - name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) - - def _split_batch_dim_both_contract(self, batch_dim, mesh_dim0, mesh_dim1, - device_mesh): - - name = ( - f'{DimSpec(mesh_dim0)}b{batch_dim}RR = ' - f'{DimSpec(mesh_dim0)}b{batch_dim}R{DimSpec(mesh_dim1)} x ' - f'{DimSpec(mesh_dim0)}b{batch_dim}{DimSpec(mesh_dim1)}R_AR{mesh_dim1}' - ) - in0_data = self.op_data['input0'] - in1_data = self.op_data['input1'] - in0_parition_dict = {batch_dim: mesh_dim0} - in1_parition_dict = {batch_dim: mesh_dim0} - - in0_contract_dim = -2 if self.op0_transpose else -1 - in1_contract_dim = -1 if self.op1_transpose else -2 - in0_parition_dict[in0_contract_dim] = mesh_dim1 - in1_parition_dict[in1_contract_dim] = mesh_dim1 - - in0_parition_dict = 
self._recover_bcast_partition_dict( - in0_parition_dict, in0_data) - in1_parition_dict = self._recover_bcast_partition_dict( - in1_parition_dict, in1_data) - out_partition_dict = {batch_dim: mesh_dim0} - dim_partition_dict_mapping = { - "input0": in0_parition_dict, - "input1": in1_parition_dict, - "output0": out_partition_dict, - } - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if len(sharding_spec_mapping) == 0: - return None - - # get communication actions - communication_action_mapping = {} - output0_comm_action = CommSpec( - comm_pattern='all_reduce', - sharding_spec=sharding_spec_mapping['output0'], - logical_process_axis=[mesh_dim1], - ) - communication_action_mapping['output0'] = output0_comm_action - return self._get_sharding_strategy( - name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) - - def _split_batch_dim_both_contract_rs(self, batch_dim, mesh_dim0, mesh_dim1, - device_mesh): - - name = ( - f'{DimSpec(mesh_dim0)}b{batch_dim}RR = ' - f'{DimSpec(mesh_dim0)}b{batch_dim}R{DimSpec(mesh_dim1)} x ' - f'{DimSpec(mesh_dim0)}b{batch_dim}{DimSpec(mesh_dim1)}R_AR{mesh_dim1}' - ) - in0_data = self.op_data['input0'] - in1_data = self.op_data['input1'] - in0_parition_dict = {batch_dim: mesh_dim0} - in1_parition_dict = {batch_dim: mesh_dim0} - - in0_contract_dim = -2 if self.op0_transpose else -1 - in1_contract_dim = -1 if self.op1_transpose else -2 - in0_parition_dict[in0_contract_dim] = mesh_dim1 - in1_parition_dict[in1_contract_dim] = mesh_dim1 - - in0_parition_dict = self._recover_bcast_partition_dict( - in0_parition_dict, in0_data) - in1_parition_dict = self._recover_bcast_partition_dict( - in1_parition_dict, in1_data) - out_partition_dict = {batch_dim: mesh_dim0} - dim_partition_dict_mapping = { - "input0": in0_parition_dict, - "input1": in1_parition_dict, - "output0": out_partition_dict, - } - mm_out_sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if len(mm_out_sharding_spec_mapping) == 0: - return [] - - strategies = [] - for rs_dim in range(self.num_out_dims): - if rs_dim != batch_dim: - name_in0, name_in1, name_out0 = ['R'] * self.num_out_dims, [ - 'R' - ] * self.num_out_dims, ['R'] * self.num_out_dims - name_in0[batch_dim], name_in0[-1] = str( - DimSpec(mesh_dim0)), str(DimSpec(mesh_dim1)) - name_in1[batch_dim], name_in1[-2] = str( - DimSpec(mesh_dim0)), str(DimSpec(mesh_dim1)) - name_in1[batch_dim], name_out0[rs_dim] = str( - DimSpec(mesh_dim0)), str(DimSpec(mesh_dim1)) - name_in0, name_in1, name_out0 = ', '.join(name_in0), ', '.join( - name_in1), ', '.join(name_out0) - name = f'[{name_out0}] = [{name_in0}] x [{name_in1}]_reducescatter{(rs_dim, DimSpec(mesh_dim1))}' - ret = self._split_both_contract_rs( - name, rs_dim, mesh_dim1, - mm_out_sharding_spec_mapping['output0'], - dim_partition_dict_mapping, device_mesh) - if ret: - strategies.append(ret) - return strategies - - def _dp_strategies(self, device_mesh): - strategies = [] - # S0R = S0R x RR - strategies.append(self._split_lhs_space_only([0], device_mesh)) - # S1R = S1R x RR - strategies.append(self._split_lhs_space_only([1], device_mesh)) - # S01R = S01R x RR - strategies.append(self._split_lhs_space_only([0, 1], device_mesh)) - return strategies - - def _tp_strategies(self, device_mesh: LogicalDeviceMesh): - strategies = [] - # RR = RS x SR _ AR - strategies.append(self._recompute_split_both_contract([0], device_mesh)) - 
strategies.append(self._recompute_split_both_contract([1], device_mesh)) - strategies.append( - self._recompute_split_both_contract([0, 1], device_mesh)) - - if device_mesh.config.enable_reduce_scatter: - # RS x SR _ reduce scatter - strategies.extend( - self._recompute_split_both_contract_rs([0], device_mesh)) - strategies.extend( - self._recompute_split_both_contract_rs([1], device_mesh)) - strategies.extend( - self._recompute_split_both_contract_rs([0, 1], device_mesh)) - - # RS = RR x RS - strategies.append(self._split_rhs_space_only([0], device_mesh)) - strategies.append(self._split_rhs_space_only([1], device_mesh)) - strategies.append(self._split_rhs_space_only([0, 1], device_mesh)) - - # RS = RS x SS _ AR - strategies.append( - self._split_rhs_space_both_contract([0], [1], device_mesh)) - strategies.append( - self._split_rhs_space_both_contract([1], [0], device_mesh)) - - if device_mesh.config.enable_reduce_scatter: - # RS x SS _ reduce scatter - strategies.extend( - self._split_rhs_space_both_contract_rs([0], [1], device_mesh)) - strategies.extend( - self._split_rhs_space_both_contract_rs([1], [0], device_mesh)) - - return strategies - - def _mix_strategies(self, device_mesh): - strategies = [] - - # SR = SS x SR_AR - strategies.append( - self._split_lhs_space_both_contract([0], [1], device_mesh)) - strategies.append( - self._split_lhs_space_both_contract([1], [0], device_mesh)) - if device_mesh.config.enable_reduce_scatter: - # RS x SS _ reduce scatter - strategies.extend( - self._split_lhs_space_both_contract_rs([0], [1], device_mesh)) - strategies.extend( - self._split_lhs_space_both_contract_rs([1], [0], device_mesh)) - # SS = SR x RS - strategies.append(self._split_lhs_space_rhs_space([0], [1], - device_mesh)) - strategies.append(self._split_lhs_space_rhs_space([0], [1], - device_mesh)) - - # RR = RR x RR - strategies.append(self._non_split(device_mesh)) - return strategies - - def _bmm_strategies(self, device_mesh: LogicalDeviceMesh): - strategies = [] - bmm_dim = len(self.op_data['output0'].shape) - if bmm_dim >= 3: - for batch_dim in range(0, bmm_dim - 2): - strategies.append( - self._split_one_batch_dim(batch_dim, [0], device_mesh)) - strategies.append( - self._split_one_batch_dim(batch_dim, [1], device_mesh)) - strategies.append( - self._split_one_batch_dim(batch_dim, [0, 1], device_mesh)) - - strategies.append( - self._split_batch_dim_lhs_space(batch_dim, [0], [1], - device_mesh)) - strategies.append( - self._split_batch_dim_lhs_space(batch_dim, [1], [0], - device_mesh)) - - strategies.append( - self._split_batch_dim_rhs_space(batch_dim, [0], [1], - device_mesh)) - strategies.append( - self._split_batch_dim_rhs_space(batch_dim, [1], [0], - device_mesh)) - - strategies.append( - self._split_batch_dim_both_contract(batch_dim, [0], [1], - device_mesh)) - strategies.append( - self._split_batch_dim_both_contract(batch_dim, [1], [0], - device_mesh)) - if device_mesh.config.enable_reduce_scatter: - strategies.extend( - self._split_batch_dim_both_contract_rs( - batch_dim, [0], [1], device_mesh)) - strategies.extend( - self._split_batch_dim_both_contract_rs( - batch_dim, [1], [0], device_mesh)) - if bmm_dim >= 4: - for batch_dim0 in range(0, bmm_dim - 2): - for batch_dim1 in range(0, bmm_dim - 2): - if batch_dim0 != batch_dim1: - strategies.append( - self._split_two_batch_dims( - batch_dim0, batch_dim1, [0], [1], - device_mesh)) - - return strategies - - def _collect_strategies(self, device_mesh): - strategies_vector = StrategiesVector(self) - dp_strategies = 
self._dp_strategies(device_mesh) - tp_strategies = self._tp_strategies(device_mesh) - mix_strategies = self._mix_strategies(device_mesh) - bmm_strategies = self._bmm_strategies(device_mesh) - strategies_vector.extend(dp_strategies) - strategies_vector.extend(tp_strategies) - strategies_vector.extend(mix_strategies) - strategies_vector.extend(bmm_strategies) - return strategies_vector - - def is_fp16(self): - builder_flags = get_builder_flags() - return builder_flags & (1 << int(trt.BuilderFlag.FP16)) != 0 - - def _get_math_time(self, strategy, device_mesh): - shape_in0 = strategy.sharding_specs[ - 'input0'].get_sharded_shape_per_device() - shape_out = strategy.sharding_specs[ - 'output0'].get_sharded_shape_per_device() - m, n = shape_out[-2], shape_out[-1] - batches = shape_out[:-2] - k = shape_in0[-2] if self.op0_transpose else shape_in0[-1] - macs_shape = batches + [m, n, k] - macs = reduce(operator.mul, macs_shape, 1) * 2 - config = device_mesh.config - cluster_info = device_mesh.cluster_info - dtype = self.dtype - # For fp16 matmul ops that use_fp32_acc=True. - # They are mistaken for fp32 ops since all of their IO tensors use fp32 dtype. - if self.is_fp16() and self.dtype == "float32": - dtype = "float16" - math_throughput_tflops = getattr(cluster_info.math_throughput, dtype) - assert math_throughput_tflops != 0, \ - "Undefined {} math throughput of cluster {}".format(dtype, config.cluster_key) - math_time = macs / math_throughput_tflops * 1e-6 * cluster_info.math_efficiency - return math_time - - def _update_memory_cost(self, strategies): - super()._update_memory_cost(strategies) - # For fp16 matmul ops that use_fp32_acc=True. - # Their memory footprints are calculated based on fp32 IO tensors. - # Actually they will use fp16 IO tensors after fused. - # So we divide all the memory footprints by 2. 
- if self.is_fp16() and self.dtype == "float32": - for strategy in strategies: - strategy.inout_memory_footprint /= 2 - strategy.peak_memory_footprint /= 2 - strategy.comm_buff_memory_footprint /= 2 diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/node.py b/tensorrt_llm/auto_parallel/tensor_parallel/node.py deleted file mode 100644 index 887c520589..0000000000 --- a/tensorrt_llm/auto_parallel/tensor_parallel/node.py +++ /dev/null @@ -1,376 +0,0 @@ -from abc import ABC - -from ..config import CostModel -from ..device_mesh import LogicalDeviceMesh -from .comm_spec import CommSpec -from .sharding_spec import ShardingSpec -from .sharding_strategy import ShardingStrategy, StrategiesVector - - -class Node(ABC): - - def __init__(self, layer): - self._layer = layer - self.is_shape_io = self._layer.is_shape_io - self._inputs = [] - self._outputs = [] - self.predecessor_nodes = [] - self.predecessor_nodes_out_index = {} - self.successor_nodes = [] - self.op_data = {} - self.global_to_local_op_name = {} - self.num_inputs = 0 - self.is_replicated = layer.attrs.get("is_replicated", False) - self.same_spec_id = layer.attrs.get("same_spec_id", -1) - self.is_fake = self.same_spec_id != -1 - self.building_block_id = layer.attrs.get("building_block_id", -1) - self.cost_level = -1 - self.stage_type = layer.attrs.get("stage_type", None) - self.in_start_block = layer.attrs.get("in_start_block", False) - self.in_end_block = layer.attrs.get("in_end_block", False) - self.in_slowest_block = layer.attrs.get("in_slowest_block", False) - for i, input in enumerate(layer.inputs): - if input is None: - self._inputs.append(None) - self.op_data[f'input{i}'] = None - continue - input = input.copy() - input.attrs["broadcast_dims"] = [] - self._inputs.append(input) - self.op_data[f'input{i}'] = input - self.global_to_local_op_name[input.name] = f'input{i}' - - for i, output in enumerate(layer.outputs): - output = output.copy() - output.attrs["broadcast_dims"] = [] - self._outputs.append(output) - self.op_data[f'output{i}'] = output - self.global_to_local_op_name[output.name] = f'output{i}' - - self.sharding_weight = 1.0 - self.resharding_weight = 1.0 - self.pipeline_weight = 0 - self.node_name = layer.name - self.node_type = 'normal_node' - self.num_inputs = layer.num_inputs - self.num_outputs = layer.num_outputs - self.dtype = layer.as_trt().precision - self.strategies_vector = [] - self.node_runtime_profiler = None - - def post_init(self, graph): - for input in self.inputs: - if input is None: - self.predecessor_nodes.append(None) - continue - if input.producer is None: - predecessor_node = graph.get_node(input.name) - self.predecessor_nodes.append(predecessor_node) - self.predecessor_nodes_out_index[predecessor_node] = 0 - predecessor_node.successor_nodes.append(self) - else: - predecessor_node = graph.get_node(input.producer.name) - self.predecessor_nodes.append(predecessor_node) - self.predecessor_nodes_out_index[ - predecessor_node] = input.output_index - predecessor_node.successor_nodes.append(self) - - @property - def layer(self): - return self._layer - - def get_input(self, index): - return self._inputs[index] - - @property - def inputs(self): - return self._inputs - - def get_output(self, index): - return self._outputs[index] - - @property - def outputs(self): - return self._outputs - - def collect_strategies(self, device_mesh): - strategies_vector = self._collect_strategies(device_mesh) - strategies_vector = self._post_process(strategies_vector) - self._update_sharding_cost(strategies_vector, device_mesh) - 
self.strategies_vector = strategies_vector - return self.strategies_vector - - def _set_strategy(self, strategy, device_mesh): - strategies_vector = StrategiesVector(self) - if strategy is None: - dim_partition_dict_mapping = {} - for i in range(self.num_inputs): - dim_partition_dict_mapping[f'input{i}'] = {} - for i in range(self.num_outputs): - dim_partition_dict_mapping[f'output{i}'] = {} - - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - assert 0 != len( - sharding_spec_mapping - ), f'failed to set default(all Replicate) strategy for node {self.node_name}' - name = 'RRs' - sharding_strategy = self._get_sharding_strategy( - name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) - strategies_vector.append(sharding_strategy) - - else: - sharding_specs_map = strategy.sharding_specs - comm_specs_map = strategy.communication_actions - dim_partition_dict_mapping = {} - for op_name, sharding_spec in sharding_specs_map.items(): - dim_partition_dict_mapping[ - op_name] = sharding_spec.dim_partition_dict - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - assert 0 != len( - sharding_spec_mapping - ), f'failed to set strategy for node {self.node_name}' - comm_specs_mapping = {} - if len(comm_specs_map) > 0: - for op_name, comm_spec in comm_specs_map.items(): - comm_specs_mapping[op_name] = CommSpec( - comm_pattern=comm_spec.comm_pattern, - sharding_spec=sharding_spec_mapping[op_name], - logical_process_axis=comm_spec.logical_process_axis, - ) - strategies_vector.append( - self._get_sharding_strategy( - name=strategy.name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=comm_specs_mapping)) - return strategies_vector - - def set_strategy(self, strategy, device_mesh): - strategies_vector = self._set_strategy(strategy, device_mesh) - strategies_vector = self._post_process(strategies_vector) - self._update_sharding_cost(strategies_vector, device_mesh) - self.strategies_vector = strategies_vector - return self.strategies_vector - - def update_resharding_cost(self): - self._update_resharding_cost(self.strategies_vector) - return self.strategies_vector - - def _to_sharding_spec_mapping(self, dim_partition_dict_mapping, - device_mesh): - results = {} - for op_data_name, dim_partition_dict in dim_partition_dict_mapping.items( - ): - if op_data_name in self.op_data: - op_data = self.op_data[op_data_name] - - def _to_sharding_spec(op_data, dim_partition_dict): - sharding_spec = ShardingSpec( - device_mesh, - op_data.dtype_str_size, [*op_data.shape], - [*op_data.max_shape], [*op_data.raw_shape], - dim_partition_dict=dim_partition_dict) - if sharding_spec.sanity_check(): - return sharding_spec - else: - return None - - sharding_spec = _to_sharding_spec(op_data, dim_partition_dict) - if sharding_spec: - results[op_data_name] = sharding_spec - else: - return {} - return results - - def _get_sharding_strategy(self, name, sharding_spec_mapping, - communication_action_mapping): - return ShardingStrategy( - name=name, - sharding_specs=sharding_spec_mapping, - communication_actions=communication_action_mapping, - ) - - def _remove_duplicated_strategy(self, strategies_vector): - name_checklist = [] - remove_list = [] - for strategy in strategies_vector: - if strategy.name not in name_checklist: - name_checklist.append(strategy.name) - else: - remove_list.append(strategy) - for strategy in remove_list: - strategies_vector.remove(strategy) - - def 
_post_process(self, strategies_vector): - # TODO: deal with transpose and dimension 1 problem in ClossalAI, which have been processed before - for i in range(len(strategies_vector) - 1, -1, -1): - if strategies_vector[i] is None: - strategies_vector.pop(i) - - self._remove_duplicated_strategy(strategies_vector) - return strategies_vector - - def _profile_sharding_cost(self, strategy, device_mesh: LogicalDeviceMesh): - elapsed_time = self.node_runtime_profiler.runtime_profile( - self.layer, {}, {}, strategy, device_mesh) - return elapsed_time - - def _model_sharding_cost_from_s_curve(self, strategy, - device_mesh: LogicalDeviceMesh): - ''' - [ToDo] preprofile the s_curve - ''' - sharding_cost = 0.0 - return sharding_cost - - # this method might be overwritten by some Ops - def _get_math_time(self, strategy, device_mesh: LogicalDeviceMesh): - return 0.0 - - # this method might be overwritten by some Ops - def _get_memory_time(self, strategy, device_mesh: LogicalDeviceMesh): - memory_time = (strategy.inout_memory_footprint / - device_mesh.cluster_info.memory_bw * 1e-3 * - device_mesh.cluster_info.memory_efficiency) - return memory_time - - def _model_sharding_cost_from_alpha_beta(self, strategy, - device_mesh: LogicalDeviceMesh): - math_time = self._get_math_time(strategy, device_mesh) - mem_time = self._get_memory_time(strategy, device_mesh) - return max(math_time, mem_time) - - def _get_communication_cost(self, strategy): - total_comm_cost = 0.0 - for op_data_name, comm_spec in strategy.communication_actions.items(): - comm_cost = comm_spec.get_comm_cost() - total_comm_cost = total_comm_cost + comm_cost - return total_comm_cost - - def _update_sharding_cost(self, strategies, device_mesh: LogicalDeviceMesh): - self._update_memory_cost(strategies) - - if device_mesh.config.sharding_cost_model == CostModel.ALPHA_BETA: - for strategy in strategies: - strategy.sharding_cost = self._model_sharding_cost_from_alpha_beta( - strategy, device_mesh) - elif device_mesh.config.sharding_cost_model == CostModel.S_CURVE: - for strategy in strategies: - strategy.sharding_cost = self._model_sharding_cost_from_s_curve( - strategy, device_mesh) - elif device_mesh.config.sharding_cost_model == CostModel.PROFILE: - for strategy in strategies: - strategy.alpha_beta_cost = self._model_sharding_cost_from_alpha_beta( - strategy, device_mesh) - if self.is_shape_io: - strategy.sharding_cost = strategy.alpha_beta_cost - else: - strategy.sharding_cost = self._profile_sharding_cost( - strategy, device_mesh) - elif device_mesh.config.sharding_cost_model == CostModel.ZERO: - for strategy in strategies: - strategy.sharding_cost = 0.0 - else: - assert False, 'unsupport sharding cost model option: {}'.format( - device_mesh.config.sharding_cost_model) - - for strategy in strategies: - strategy.communication_cost = self._get_communication_cost(strategy) - - def _compute_resharding_cost(self, pre_sharding_sepc, cur_sharding_spec, - op_data): - transform_path, comm_action_sequence, resharding_cost = cur_sharding_spec.device_mesh.shape_consistency_manager.shape_consistency( - pre_sharding_sepc, cur_sharding_spec) - return (transform_path, comm_action_sequence, resharding_cost) - - def _update_resharding_cost(self, strategies): - for strategy in strategies: - resharding_costs = {} - for pre_node, out_index in self.predecessor_nodes_out_index.items(): - if pre_node is None: - continue - pre_node_out_data_name = pre_node.get_output(out_index).name - pre_node_out_data_lname = pre_node.global_to_local_op_name[ - pre_node_out_data_name] 
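                # Descriptive note: the predecessor's global tensor name is translated back to
                # this node's local operand name (e.g. 'input0') so the matching sharding spec
                # in the current strategy can be looked up; tensors that do not feed this node
                # are skipped just below.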
- if pre_node_out_data_name not in self.global_to_local_op_name: - print(f"pre_node_out_data_name = {pre_node_out_data_name}") - continue - cur_node_inp_data_lname = self.global_to_local_op_name[ - pre_node_out_data_name] - cur_sharding_spec = strategy.sharding_specs[ - cur_node_inp_data_lname] - - pre_node_out_sharding_specs = [] - for pre_strategy in pre_node.strategies_vector: - pre_node_out_sharding_specs.append( - pre_strategy.sharding_specs[pre_node_out_data_lname]) - - if pre_node not in resharding_costs: - resharding_costs[pre_node.node_name] = [] - for prev_sharding_spec in pre_node_out_sharding_specs: - resharding_cost = self._compute_resharding_cost( - prev_sharding_spec, cur_sharding_spec, - self.op_data[cur_node_inp_data_lname]) - resharding_costs[pre_node.node_name].append(resharding_cost) - strategy.resharding_costs = resharding_costs - - def _enumerate_all_possible_1d_sharding(self, mesh_dim, dim_size): - dim_partition_list = [] - for i in range(dim_size): - dim_partition_list.append({i: mesh_dim}) - return dim_partition_list - - def _enumerate_all_possible_2d_sharding(self, mesh_dim0, mesh_dim1, - dim_size): - dim_partition_list = [] - for i in range(dim_size): - for j in range(dim_size): - if i != j: - dim_partition_list.append({i: mesh_dim0, j: mesh_dim1}) - return dim_partition_list - - def _update_memory_cost(self, strategies): - for strategy in strategies: - inout_memory_footprint, max_inout_memory_footprint = 0.0, 0.0 - for spec in strategy.sharding_specs.values(): - inout_memory_footprint += spec.get_sharded_size_per_device() - max_inout_memory_footprint += spec.get_max_sharded_size_per_device( - ) - - # the communication happens - comm_buffer_footprint, max_comm_buffer_footprint = 0.0, 0.0 - for comm_spec in strategy.communication_actions.values(): - comm_buffer_footprint += comm_spec.get_mem_cost() - max_comm_buffer_footprint += comm_spec.get_max_mem_cost() - - # when doing the output0 comm action, the input buffer should be released, the buffer is used to estimate the memory time - # rather than memory usage - strategy.inout_memory_footprint = inout_memory_footprint - - strategy.comm_buff_memory_footprint = comm_buffer_footprint - strategy.peak_memory_footprint = max(max_inout_memory_footprint, - max_comm_buffer_footprint) - - # The const memory (weight) is recorded in constant layers and should be accumulated - strategy.const_memory_footprint = 0.0 - - def _generate_bcast_dims(self, batch_dims, out_data_shape): - for output in self.outputs: - if output.broadcast_across_batch: - for bs in batch_dims: - if output.shape[bs] == 1 and output.shape[ - bs] != out_data_shape[bs]: - output.attrs["broadcast_dims"].append(bs) - - def _recover_bcast_partition_dict(self, partition_dict, op_data): - ret = {} - for data_dim, mesh_dim in partition_dict.items(): - if data_dim not in op_data.attrs[ - "broadcast_dims"] and data_dim + len( - op_data.shape) not in op_data.attrs[ - "broadcast_dims"] and op_data.shape[data_dim] != 1: - ret[data_dim] = mesh_dim - return ret diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/normalization_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/normalization_node.py deleted file mode 100644 index 3cfed50232..0000000000 --- a/tensorrt_llm/auto_parallel/tensor_parallel/normalization_node.py +++ /dev/null @@ -1,60 +0,0 @@ -from .node import Node -from .sharding_strategy import StrategiesVector - - -class Normalization(Node): - - def __init__(self, layer): - super().__init__(layer) - layer.to_subclass() - self.axes = layer.as_trt().axes - 
self.weight_bias_dim_base = 0 - layer.to_base_class() - - def _collect_strategies(self, device_mesh): - dim_partition_list = [] - dim_size = len(self.op_data['input0'].shape) - dim_partition_list.append({}) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([0], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([1], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([0, 1], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_2d_sharding([0], [1], dim_size)) - strategies_vector = StrategiesVector(self) - for dim_partition_dict in dim_partition_list: - shard_reduction_axes = False - for dim in range(len(self.get_input(0).shape)): - if (self.axes & (1 << dim)) and dim in dim_partition_dict: - shard_reduction_axes = True - break - if shard_reduction_axes: - continue - dim_partition_dict_mapping = { - "input0": dim_partition_dict, - "output0": dim_partition_dict, - } - if self.num_inputs >= 2: - dim_partition_dict_mapping['input1'] = {} - if self.num_inputs >= 3: - dim_partition_dict_mapping['input2'] = {} - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if 0 == len(sharding_spec_mapping): - continue - name = '{} = {} scale {}, bias {}'.format( - sharding_spec_mapping['output0'].sharding_sequence, - sharding_spec_mapping['input0'].sharding_sequence, - sharding_spec_mapping['input1'].sharding_sequence - if self.num_inputs >= 2 else 'None', - sharding_spec_mapping['input2'].sharding_sequence - if self.num_inputs >= 3 else 'None', - ) - sharding_strategy = self._get_sharding_strategy( - name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) - strategies_vector.append(sharding_strategy) - return strategies_vector diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/output_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/output_node.py deleted file mode 100644 index 09d500797e..0000000000 --- a/tensorrt_llm/auto_parallel/tensor_parallel/output_node.py +++ /dev/null @@ -1,79 +0,0 @@ -from .node import Node -from .sharding_strategy import StrategiesVector - - -class OuputNode(Node): - - def _update_memory_cost(self, strategies): - for strategy in strategies: - if not self.no_memory_footprint: - strategy.const_memory_footprint = strategy.sharding_specs[ - 'input0'].get_max_sharded_size_per_device() - - def __init__(self, tensor): - self._layer = None - self.is_shape_io = False - self._inputs = [] - self._outputs = [] - self.predecessor_nodes = [] - self.predecessor_nodes_out_index = {} - self.successor_nodes = [] - self.op_data = {} - self.global_to_local_op_name = {} - self.is_replicated = tensor.attrs.get("is_replicated", False) - self.same_spec_id = tensor.attrs.get("same_spec_id", -1) - self.no_memory_footprint = tensor.attrs.get("no_memory_footprint", - False) - self.building_block_id = -1 - self.cost_level = -1 - self.stage_type = None - self.in_start_block = None - self.in_end_block = None - self.in_slowest_block = None - input = tensor.copy() - self._inputs.append(input) - self.op_data['input0'] = input - self.global_to_local_op_name[input.name] = 'input0' - - self.sharding_weight = 1.0 - self.resharding_weight = 1.0 - self.pipeline_weight = 0 - self.node_name = tensor.name - self.node_type = 'output_node' - self.num_inputs = 0 - self.num_outputs = 1 - self.dtype = tensor.dtype - self.strategies_vector = [] - self.node_runtime_profiler = None - - def _collect_strategies(self, device_mesh): - 
dim_partition_list = [] - dim_size = len(self.op_data['input0'].shape) - dim_partition_list.append({}) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([0], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([1], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([0, 1], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_2d_sharding([0], [1], dim_size)) - - strategies_vector = StrategiesVector(self) - for dim_partition_dict in dim_partition_list: - dim_partition_dict_mapping = {'input0': dim_partition_dict} - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if 0 == len(sharding_spec_mapping): - continue - sharding_seq = sharding_spec_mapping['input0'].sharding_sequence - sharding_strategy = self._get_sharding_strategy( - name=f'output-op {sharding_seq}', - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) - strategies_vector.append(sharding_strategy) - - return strategies_vector - - def _profile_sharding_cost(self, strategy, device_mesh): - return 0.0 diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/p2p_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/p2p_node.py deleted file mode 100644 index 8042a89d41..0000000000 --- a/tensorrt_llm/auto_parallel/tensor_parallel/p2p_node.py +++ /dev/null @@ -1,67 +0,0 @@ -import copy -from enum import Enum - -from .comm_spec import CommSpec -from .identity_node import Identity -from .sharding_strategy import StrategiesVector - - -class P2PType(Enum): - CROSS_DEVICE = 0 - CROSS_HOST = 1 - - -class P2PNode(Identity): - - def __init__(self, layer): - super().__init__(layer) - self.p2p_type = layer.attrs["p2p_type"] - self.is_fake = True - - def _collect_strategies(self, device_mesh): - # one input for softmax node - predecessor = self.predecessor_nodes[0] - strategies_vector = StrategiesVector(self) - for idx, strategy in enumerate(predecessor.strategies_vector): - # current node's local name input0 -> global name xxx - global_input_name = self.op_data['input0'].name - # global name xxx -> pre node local output name - prenode_local_name = predecessor.global_to_local_op_name[ - global_input_name] - dim_partition_dict = copy.deepcopy( - strategy.sharding_specs[prenode_local_name].dim_partition_dict) - in0_partition_dict = dim_partition_dict - out_partition_dict = copy.deepcopy(dim_partition_dict) - dim_partition_dict_mapping = { - "input0": in0_partition_dict, - "output0": out_partition_dict, - } - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if 0 == len(sharding_spec_mapping): - continue - - logical_process_axis = [ - ['p2p_cross_device'] - ] if self.p2p_type == P2PType.CROSS_DEVICE else [['p2p_cross_host']] - # get communication action mapping - communication_action_mapping = {} - output0_comm_action = CommSpec( - comm_pattern='peer_to_peer', - sharding_spec=sharding_spec_mapping['output0'], - logical_process_axis=logical_process_axis, - ) - communication_action_mapping['output0'] = output0_comm_action - - name = '{} = {}'.format( - sharding_spec_mapping['output0'].sharding_sequence, - sharding_spec_mapping['input0'].sharding_sequence) - sharding_strategy = self._get_sharding_strategy( - name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) - strategies_vector.append(sharding_strategy) - return strategies_vector - - def 
_profile_sharding_cost(self, strategy, device_mesh): - return 0.0 diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/plugin_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/plugin_node.py deleted file mode 100644 index 419308c7c4..0000000000 --- a/tensorrt_llm/auto_parallel/tensor_parallel/plugin_node.py +++ /dev/null @@ -1,40 +0,0 @@ -from tensorrt_llm.network import PluginInfo, get_plugin_info - -from .node import Node -from .sharding_strategy import StrategiesVector - - -class PluginNode(Node): - - def __init__(self, layer): - super().__init__(layer) - layer.to_subclass() - self.plugin = layer.as_trt().plugin - self.plugin_type: str = self.plugin.plugin_type - self.plugin_info: PluginInfo = get_plugin_info(layer.graph.as_trt(), - layer.name) - layer.to_base_class() - - def _collect_strategies(self, device_mesh): - raise NotImplementedError( - f"Auto parallel does not support {self.plugin_type} plugin right now." - ) - - def _default_strategy(self, device_mesh): - strategies_vector = StrategiesVector(self) - dim_partition_dict_mapping = {} - for idx in range(self.num_inputs): - dim_partition_dict_mapping[f'input{idx}'] = {} - for idx in range(self.num_outputs): - dim_partition_dict_mapping[f'output{idx}'] = {} - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if 0 == len(sharding_spec_mapping): - return strategies_vector - name = '{}_all_replicate'.format(self.plugin_type) - sharding_strategy = self._get_sharding_strategy( - name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) - strategies_vector.append(sharding_strategy) - return strategies_vector diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/__init__.py b/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gemm_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gemm_node.py deleted file mode 100644 index fc52372b97..0000000000 --- a/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gemm_node.py +++ /dev/null @@ -1,27 +0,0 @@ -import tensorrt as trt - -from tensorrt_llm._utils import trt_dtype_to_str - -from ..matmul_node import MatrixMultiply -from ..plugin_node import PluginNode - - -class GemmPlugin(MatrixMultiply, PluginNode): - - def __init__(self, layer): - PluginNode.__init__(self, layer) - batch_dims = [i for i in range(len(self.get_output(0).shape))][:-2] - self._generate_bcast_dims(batch_dims, self.get_output(0).shape) - pfc_as_list = self.plugin_info.pfc_as_list - self.op0_transpose = (pfc_as_list['transa'][0] == 1) - self.op1_transpose = (pfc_as_list['transb'][0] == 1) - self.num_out_dims = len(self.get_output(0).shape) - self.dtype = trt_dtype_to_str(trt.DataType(pfc_as_list['type_id'][0])) - - def _collect_strategies(self, device_mesh): - strategies_vector = MatrixMultiply._collect_strategies( - self, device_mesh) - return strategies_vector - - def _get_math_time(self, strategy, device_mesh): - return MatrixMultiply._get_math_time(self, strategy, device_mesh) diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gpt_attention_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gpt_attention_node.py deleted file mode 100644 index 1008e64c8b..0000000000 --- a/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gpt_attention_node.py +++ /dev/null @@ -1,437 +0,0 @@ -from enum import Enum, auto - -import numpy as np 
-import torch - -from tensorrt_llm.functional import AttentionMaskType, PositionEmbeddingType -from tensorrt_llm.quantization import QuantMode - -from ..plugin_node import PluginNode -from ..sharding_strategy import StrategiesVector - - -# WARNING: Must in sync with IdxEntry in cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h -class IdxEntry(Enum): - QKV_TENSOR = auto() - K_TENSOR = auto() - V_TENSOR = auto() - CONTEXT_FMHA_CUSTOM_MASK = auto() - SEQUENCE_LENGTH = auto() - HOST_PAST_KEY_VALUE_LENGTHS = auto() - HOST_MAX_ATTENTION_WINDOW = auto() - HOST_SINK_TOKEN_LENGTH = auto() - CONTEXT_LENGTHS = auto() - CACHE_INDIR = auto() - REQUEST_TYPES = auto() - KV_CACHE_BLOCK_OFFSETS = auto() - HOST_KV_CACHE_BLOCK_OFFSETS = auto() - HOST_KV_CACHE_POOL_POINTERS = auto() - HOST_KV_CACHE_POOL_MAPPING = auto() - PAST_KEY_VALUE = auto() - KV_CACHE_QUANTIZATION_SCALE = auto() - KV_CACHE_DEQUANTIZATION_SCALE = auto() - ATTENTION_OUTPUT_QUANTIZATION_SCALE = auto() - ATTENTION_OUTPUT_SF_SCALE = auto() - ROTARY_INV_FREQ = auto() - ROTARY_COS_SIN = auto() - ALIBI_SLOPES = auto() - RELATIVE_ATTENTION_BIAS = auto() - CROSS_KV = auto() - CROSS_KV_LENGTH = auto() - ENCODER_INPUT_LENGTH = auto() - HOST_CONTEXT_LENGTH = auto() - QKV_BIAS_TENSOR = auto() - SPEC_DECODING_GENERATION_LENGTHS = auto() - SPEC_DECODING_PACKED_MASK = auto() - SPEC_DECODING_POSITION_OFFSETS = auto() - MROPE_ROTARY_COS_SIN = auto() - MROPE_POSITION_DELTAS = auto() - HOST_RUNTIME_PERF_KNOBS = auto() - HOST_CONTEXT_PROGRESS = auto() - MLA_FUSED_Q_PROJ_TENSOR = auto() - MLA_Q_B_PROJ_TENSOR = auto() - MLA_KV_B_PROJ_TENSOR = auto() - LOGN_SCALING = auto() - - -class IdxEntryParser: - - def __init__(self, plugin_info): - self.num_kv_heads = plugin_info.pfc_as_list['num_kv_heads'][0] - self.unfuse_qkv_gemm = bool( - plugin_info.pfc_as_list['unfuse_qkv_gemm'][0]) - self.use_fp8_context_fmha = bool( - plugin_info.pfc_as_list['use_fp8_context_fmha'][0]) - self.fuse_fp4_quant = bool(plugin_info.pfc_as_list['fuse_fp4_quant'][0]) - self.mask_type = AttentionMaskType( - plugin_info.pfc_as_list['mask_type'][0]) - self.use_cache = bool(plugin_info.pfc_as_list['use_cache'][0]) - self.paged_kv_cache = bool(plugin_info.pfc_as_list['paged_kv_cache'][0]) - self.do_cross_attention = bool( - plugin_info.pfc_as_list['do_cross_attention'][0]) - self.remove_input_padding = bool( - plugin_info.pfc_as_list['remove_input_padding'][0]) - self.qkv_bias_enabled = bool( - plugin_info.pfc_as_list['qkv_bias_enabled'][0]) - self.kv_cache_quant_mode = QuantMode( - plugin_info.pfc_as_list['kv_cache_quant_mode'][0]) - self.position_embedding_type = PositionEmbeddingType( - plugin_info.pfc_as_list['position_embedding_type'][0]) - self.is_spec_decoding_enabled = bool( - plugin_info.pfc_as_list['is_spec_decoding_enabled'][0]) - self.is_mla_enabled = bool(plugin_info.pfc_as_list['is_mla_enabled'][0]) - self.use_logn_scaling = bool( - plugin_info.pfc_as_list['use_logn_scaling'][0]) - self.init_entry_to_index() - - # WARNING: Must in sync with GPTAttentionPlugin::isEntryUsed in cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp - def is_entry_used(self, entry: IdxEntry) -> bool: - if entry == IdxEntry.QKV_TENSOR: - return True - elif entry == IdxEntry.K_TENSOR: - return self.unfuse_qkv_gemm - elif entry == IdxEntry.V_TENSOR: - return self.unfuse_qkv_gemm - elif entry == IdxEntry.CONTEXT_FMHA_CUSTOM_MASK: - return self.mask_type == AttentionMaskType.custom_mask - elif entry == IdxEntry.SEQUENCE_LENGTH: - return self.use_cache - elif entry == 
IdxEntry.HOST_PAST_KEY_VALUE_LENGTHS: - return self.use_cache - elif entry == IdxEntry.HOST_MAX_ATTENTION_WINDOW: - return True - elif entry == IdxEntry.HOST_SINK_TOKEN_LENGTH: - return True - elif entry == IdxEntry.CONTEXT_LENGTHS: - return True - elif entry == IdxEntry.CACHE_INDIR: - return self.use_cache - elif entry == IdxEntry.REQUEST_TYPES: - return True - elif entry == IdxEntry.KV_CACHE_BLOCK_OFFSETS: - return self.use_cache and self.paged_kv_cache - elif entry == IdxEntry.HOST_KV_CACHE_BLOCK_OFFSETS: - return self.use_cache and self.paged_kv_cache - elif entry == IdxEntry.HOST_KV_CACHE_POOL_POINTERS: - return self.use_cache and self.paged_kv_cache - elif entry == IdxEntry.HOST_KV_CACHE_POOL_MAPPING: - return self.use_cache and self.paged_kv_cache - elif entry == IdxEntry.PAST_KEY_VALUE: - return self.use_cache and not self.paged_kv_cache - elif entry == IdxEntry.KV_CACHE_QUANTIZATION_SCALE: - return self.use_cache and self.kv_cache_quant_mode.has_kv_cache_quant( - ) - elif entry == IdxEntry.KV_CACHE_DEQUANTIZATION_SCALE: - return self.use_cache and self.kv_cache_quant_mode.has_kv_cache_quant( - ) - elif entry == IdxEntry.ATTENTION_OUTPUT_QUANTIZATION_SCALE: - return self.use_fp8_context_fmha and self.kv_cache_quant_mode.has_fp8_qdp( - ) - elif entry == IdxEntry.ATTENTION_OUTPUT_SF_SCALE: - return self.fuse_fp4_quant - elif entry == IdxEntry.ROTARY_INV_FREQ: - return self.position_embedding_type.is_rope() - elif entry == IdxEntry.ROTARY_COS_SIN: - return self.position_embedding_type.is_rope() - elif entry == IdxEntry.ALIBI_SLOPES: - return self.position_embedding_type.is_alibi() - elif entry == IdxEntry.RELATIVE_ATTENTION_BIAS: - return self.position_embedding_type == PositionEmbeddingType.relative - elif entry == IdxEntry.CROSS_KV: - return self.do_cross_attention - elif entry == IdxEntry.CROSS_KV_LENGTH: - return self.do_cross_attention - elif entry == IdxEntry.ENCODER_INPUT_LENGTH: - return self.do_cross_attention - elif entry == IdxEntry.HOST_CONTEXT_LENGTH: - return self.remove_input_padding - elif entry == IdxEntry.QKV_BIAS_TENSOR: - return self.qkv_bias_enabled - elif entry == IdxEntry.SPEC_DECODING_PACKED_MASK: - return self.is_spec_decoding_enabled - elif entry == IdxEntry.SPEC_DECODING_POSITION_OFFSETS: - return self.is_spec_decoding_enabled - elif entry == IdxEntry.SPEC_DECODING_GENERATION_LENGTHS: - return self.is_spec_decoding_enabled - elif entry == IdxEntry.MROPE_ROTARY_COS_SIN: - return self.position_embedding_type.is_mrope() - elif entry == IdxEntry.MROPE_POSITION_DELTAS: - return self.position_embedding_type.is_mrope() - elif entry == IdxEntry.HOST_RUNTIME_PERF_KNOBS: - return True - elif entry == IdxEntry.HOST_CONTEXT_PROGRESS: - return True - elif entry == IdxEntry.MLA_FUSED_Q_PROJ_TENSOR: - return self.is_mla_enabled - elif entry == IdxEntry.MLA_Q_B_PROJ_TENSOR: - return self.is_mla_enabled - elif entry == IdxEntry.MLA_KV_B_PROJ_TENSOR: - return self.is_mla_enabled - elif entry == IdxEntry.LOGN_SCALING: - return self.use_logn_scaling - else: - return False - - def init_entry_to_index(self): - self.entry_to_index = {} - index = 0 - for entry in IdxEntry: - if self.is_entry_used(entry): - self.entry_to_index[entry] = index - index += 1 - - def get_index(self, entry: IdxEntry) -> int: - if entry not in self.entry_to_index: - raise Exception( - f"Entry {entry} is not existed in gpt attention plugin layer {self.layer.name}" - ) - return self.entry_to_index[entry] - - -def get_partition(device_dim, device_ids): - if device_dim == [0]: - partition = device_ids.shape[0] 
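    # Note (annotation, assuming device_ids is the 2-D array of physical device ids backing
    # the logical mesh): the partition count is the mesh extent along the sharded axis, or
    # the full mesh size when both axes are sharded, as handled in the final branch.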
- elif device_dim == [1]: - partition = device_ids.shape[1] - else: - assert device_dim == [0, 1] or device_dim == [1, 0] - partition = device_ids.size - return partition - - -class GPTAttentionPlugin(PluginNode): - - def __init__(self, layer): - super().__init__(layer) - self.parser = IdxEntryParser(self.plugin_info) - assert self.num_inputs == len( - self.parser.entry_to_index - ), f'the number of plugin inputs ({self.num_inputs}) is invalid' - assert self.num_outputs == ( - 2 if self.parser.is_entry_used(IdxEntry.PAST_KEY_VALUE) else 1 - ), f'the number of plugin outputs ({self.num_outputs}) has been changed' - - def _tp_strategy(self, device_mesh): - strategies_vector = StrategiesVector(self) - head_dim = 1 if self.parser.remove_input_padding else 2 - # TODO: allow mesh_dim = [0] or [1] - # for mesh_dim in ([0], [1], [0, 1]): - for mesh_dim in ([0, 1], ): - if self.parser.num_kv_heads != 1: - # MHA or GQA - # TODO: allow to duplicate kv when #kv_head < #partition - q_pdict = { - head_dim: mesh_dim - } # split in heads/hidden dimension - k_pdict = { - head_dim: mesh_dim - } # split in heads/hidden dimension - v_pdict = { - head_dim: mesh_dim - } # split in heads/hidden dimension - pastkv_pdict = {2: mesh_dim} # split in heads dimension - present_kv_pdict = {2: mesh_dim} # split in heads dimension - else: - # MQA - q_pdict = { - head_dim: mesh_dim - } # split in heads/hidden dimension - k_pdict = {} # RR - v_pdict = {} # RR - pastkv_pdict = {} # RR - present_kv_pdict = {} # RR - - out0_pdict = {head_dim: mesh_dim} - - dim_partition_dict_mapping = { - f'input{self.parser.get_index(IdxEntry.QKV_TENSOR)}': q_pdict, - f'input{self.parser.get_index(IdxEntry.K_TENSOR)}': k_pdict, - f'input{self.parser.get_index(IdxEntry.V_TENSOR)}': v_pdict, - 'output0': out0_pdict, - } - if self.parser.is_entry_used(IdxEntry.PAST_KEY_VALUE): - dim_partition_dict_mapping[ - f'input{self.parser.get_index(IdxEntry.PAST_KEY_VALUE)}'] = pastkv_pdict - dim_partition_dict_mapping['output1'] = present_kv_pdict - for i in range(self.num_inputs): - if f'input{i}' not in dim_partition_dict_mapping: - dim_partition_dict_mapping[f'input{i}'] = {} - for i in range(self.num_outputs): - if f'output{i}' not in dim_partition_dict_mapping: - dim_partition_dict_mapping[f'output{i}'] = {} - - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if 0 == len(sharding_spec_mapping): - continue - name = 'gptAttentionPlugin_tp_strategy' - sharding_strategy = self._get_sharding_strategy( - name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) - strategies_vector.append(sharding_strategy) - return strategies_vector - - def _dp_strategy(self, device_mesh): - strategies_vector = StrategiesVector(self) - for mesh_dim in ([0], [1], [0, 1]): - dim_partition_dict_mapping = {} - for i in range(self.num_inputs): - dim_partition_dict_mapping[f'input{i}'] = {0: mesh_dim} - for i in range(self.num_outputs): - dim_partition_dict_mapping[f'output{i}'] = {0: mesh_dim} - - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if 0 == len(sharding_spec_mapping): - continue - name = 'gptAttentionPlugin_dp_strategy' - sharding_strategy = self._get_sharding_strategy( - name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) - strategies_vector.append(sharding_strategy) - return strategies_vector - - def _collect_strategies(self, device_mesh): - if device_mesh.size == 1: - 
default_strategies = self._default_strategy(device_mesh) - else: - # Avoid to use all-replicate strategy for mesh size > 1 - # since the CPP runtime does not support it for gpt attention plugin - default_strategies = StrategiesVector(self) - for idx, strategy in enumerate(default_strategies): - strategy.name = 'gptAttentionPlugin_' + strategy.name + f'{idx}' - if self.parser.unfuse_qkv_gemm: - tp_strategies = self._tp_strategy(device_mesh) - default_strategies.extend(tp_strategies) - # if we don't split the batch dim, it should be default strategis - # elif we split the batch dim, it should be dp_strategies - # we can use above information to distinguish the two kinds of strategy - if not self.parser.remove_input_padding: - dp_strategies = self._dp_strategy(device_mesh) - default_strategies.extend(dp_strategies) - return default_strategies - - @staticmethod - def parameter_generator(sharding_specs, plugin_info): - - def get_shape(entry): - return sharding_specs[ - f'input{parser.get_index(entry)}'].get_sharded_shape_per_device( - ) - - parser = IdxEntryParser(plugin_info) - updated_input_values = {} - batch_size = get_shape(IdxEntry.CONTEXT_LENGTHS)[0] - if parser.use_cache: - beams_width = get_shape(IdxEntry.CACHE_INDIR)[1] - max_seq_length = get_shape(IdxEntry.CACHE_INDIR)[2] - elif not parser.remove_input_padding: - max_seq_length = get_shape(IdxEntry.QKV_BIAS_TENSOR)[1] - else: - max_seq_length = 1 - host_request_types = torch.full( - (batch_size, ), - 1, - dtype=torch.int32, - device='cpu', - ) - updated_input_values[parser.get_index( - IdxEntry.REQUEST_TYPES)] = host_request_types - context_lengths = torch.full( - (batch_size, ), - max_seq_length - 1, - dtype=torch.int32, - device=torch.cuda.current_device(), - ) - updated_input_values[parser.get_index( - IdxEntry.CONTEXT_LENGTHS)] = context_lengths - host_max_attention_window_sizes = torch.tensor( - [max_seq_length], - dtype=torch.int32, - device='cpu', - ) - updated_input_values[parser.get_index( - IdxEntry.HOST_MAX_ATTENTION_WINDOW - )] = host_max_attention_window_sizes - host_sink_token_length = torch.tensor( - [0], - dtype=torch.int32, - device='cpu', - ) - updated_input_values[parser.get_index( - IdxEntry.HOST_SINK_TOKEN_LENGTH)] = host_sink_token_length - if parser.use_cache: - sequence_length = torch.full((batch_size, ), - max_seq_length, - dtype=torch.int32, - device=torch.cuda.current_device()) - updated_input_values[parser.get_index( - IdxEntry.SEQUENCE_LENGTH)] = sequence_length - host_past_key_value_length = torch.full((batch_size, ), - max_seq_length - 1, - dtype=torch.int32, - device='cpu') - updated_input_values[parser.get_index( - IdxEntry.HOST_PAST_KEY_VALUE_LENGTHS - )] = host_past_key_value_length - cache_indirections = torch.full( - (batch_size, beams_width, max_seq_length), - 0, - dtype=torch.int32, - device=torch.cuda.current_device()) - updated_input_values[parser.get_index( - IdxEntry.CACHE_INDIR)] = cache_indirections - if parser.remove_input_padding: - host_context_lengths = torch.full(get_shape( - IdxEntry.HOST_CONTEXT_LENGTH), - max_seq_length - 1, - dtype=torch.int32, - device='cpu') - updated_input_values[parser.get_index( - IdxEntry.HOST_CONTEXT_LENGTH)] = host_context_lengths - return updated_input_values - - def _profile_sharding_cost(self, strategy, device_mesh): - sharding_spec = strategy.sharding_specs[ - f"input{self.parser.get_index(IdxEntry.QKV_TENSOR)}"] - shard_dims = sharding_spec.dim_partition_dict - device_ids = device_mesh.phy_ids - if 2 in shard_dims: - device_dim = shard_dims[2] - 
partition = get_partition(device_dim, device_ids) - else: - partition = 1 - if self.parser.is_entry_used(IdxEntry.K_TENSOR): - kv_sharding_spec = strategy.sharding_specs[ - f"input{self.parser.get_index(IdxEntry.K_TENSOR)}"] - kv_shard_dims = kv_sharding_spec.dim_partition_dict - if 2 in kv_shard_dims: - kv_device_dim = kv_shard_dims[2] - kv_partition = get_partition(kv_device_dim, device_ids) - else: - kv_partition = 1 - else: - kv_partition = 1 - num_heads = self.plugin_info.pfc_as_ndarray["num_heads"].copy() - num_kv_heads = self.plugin_info.pfc_as_ndarray["num_kv_heads"].copy() - tp_size = self.plugin_info.pfc_as_ndarray["tp_size"].copy() - tp_rank = self.plugin_info.pfc_as_ndarray["tp_rank"].copy() - num_kv_heads = np.maximum(num_kv_heads // kv_partition, 1) - num_heads = np.maximum(num_heads // partition, 1) - tp_size[0] = partition - tp_rank[0] = 0 - - updated_layer_attrs = { - 'tp_size': tp_size, - 'tp_rank': tp_rank, - 'num_heads': num_heads, - 'num_kv_heads': num_kv_heads - } - updated_input_values = self.parameter_generator(strategy.sharding_specs, - self.plugin_info) - elapsed_time = self.node_runtime_profiler.runtime_profile( - self.layer, updated_layer_attrs, updated_input_values, strategy, - device_mesh) - return elapsed_time diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/identity_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/identity_node.py deleted file mode 100644 index c94b49347d..0000000000 --- a/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/identity_node.py +++ /dev/null @@ -1,11 +0,0 @@ -from ..identity_node import Identity -from ..plugin_node import PluginNode - - -class IdentityPlugin(Identity, PluginNode): - - def __init__(self, layer): - PluginNode.__init__(self, layer) - - def _collect_strategies(self, device_mesh): - return Identity._collect_strategies(self, device_mesh) diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/look_up_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/look_up_node.py deleted file mode 100644 index 38fc9dd199..0000000000 --- a/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/look_up_node.py +++ /dev/null @@ -1,19 +0,0 @@ -import tensorrt as trt - -from ..gather_node import Gather -from ..plugin_node import PluginNode - - -class LookupPlugin(Gather, PluginNode): - - def __init__(self, layer): - PluginNode.__init__(self, layer) - self.mode = trt.GatherMode.DEFAULT - self.axis = 0 - self.num_elementwise_dims = 0 - self.input_id = 1 - self.indice_id = 0 - self.support_vocab_tp = True - - def _collect_strategies(self, device_mesh): - return Gather._collect_strategies(self, device_mesh) diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/normalization_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/normalization_node.py deleted file mode 100644 index 88e0b08f93..0000000000 --- a/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/normalization_node.py +++ /dev/null @@ -1,28 +0,0 @@ -from ..normalization_node import Normalization -from ..plugin_node import PluginNode - - -class LayernormPlugin(Normalization, PluginNode): - - def __init__(self, layer): - PluginNode.__init__(self, layer) - # the is only true for llm model, because layer norm is only effect on hidden dim - hidden_dim = len(self.op_data['input0'].shape) - 1 - self.axes = 1 << hidden_dim - self.weight_bias_dim_base = hidden_dim - - def _collect_strategies(self, device_mesh): - return Normalization._collect_strategies(self, device_mesh) - - -class 
RMSnormPlugin(Normalization, PluginNode): - - def __init__(self, layer): - PluginNode.__init__(self, layer) - # the is only true for llm model, because rms norm is only effect on hidden dim - hidden_dim = len(self.op_data['input0'].shape) - 1 - self.axes = 1 << hidden_dim - self.weight_bias_dim_base = hidden_dim - - def _collect_strategies(self, device_mesh): - return Normalization._collect_strategies(self, device_mesh) diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/reduce_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/reduce_node.py deleted file mode 100644 index 22f708faa2..0000000000 --- a/tensorrt_llm/auto_parallel/tensor_parallel/reduce_node.py +++ /dev/null @@ -1,73 +0,0 @@ -from tensorrt_llm._utils import trt_axes_to_dim - -from .node import Node -from .sharding_strategy import StrategiesVector - - -class Reduce(Node): - - def __init__(self, layer): - super().__init__(layer) - layer.to_subclass() - self.reduce_dims = trt_axes_to_dim(layer.as_trt().axes) - self.sum_mapping_dict = {} - num_input_dims = len(self.get_input(0).shape) - if layer.as_trt().keep_dims: - for i in range(num_input_dims): - self.sum_mapping_dict[i] = i - else: - output_index = 0 - for i in range(num_input_dims): - if i not in self.reduce_dims: - self.sum_mapping_dict[i] = output_index - output_index += 1 - assert output_index == len(self.get_output(0).shape) - layer.to_base_class() - - def _collect_strategies(self, device_mesh): - dim_partition_list = [] - dim_size = len(self.op_data['input0'].shape) - dim_partition_list.append({}) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([0], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([1], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([0, 1], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_2d_sharding([0], [1], dim_size)) - strategies_vector = StrategiesVector(self) - for dim_partition_dict in dim_partition_list: - recover_dims = [] - out_partition_dict = {} - for dim in dim_partition_dict.keys(): - if dim in self.reduce_dims: - recover_dims.append(dim) - elif dim in self.sum_mapping_dict: - out_partition_dict[ - self.sum_mapping_dict[dim]] = dim_partition_dict[dim] - else: - assert 0, f'dim {dim} is not in sum_dims or sum_mapping_dict' - - for dim in recover_dims: - dim_partition_dict.pop(dim) - - in0_parition_dict = dim_partition_dict - dim_partition_dict_mapping = { - "input0": in0_parition_dict, - "output0": out_partition_dict, - } - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if 0 == len(sharding_spec_mapping): - continue - name = '{} = {}'.format( - sharding_spec_mapping['output0'].sharding_sequence, - self.reduce_dims, - sharding_spec_mapping['input0'].sharding_sequence) - sharding_strategy = self._get_sharding_strategy( - name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) - strategies_vector.append(sharding_strategy) - return strategies_vector diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/select_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/select_node.py deleted file mode 100644 index ce74a1295f..0000000000 --- a/tensorrt_llm/auto_parallel/tensor_parallel/select_node.py +++ /dev/null @@ -1,56 +0,0 @@ -from .node import Node -from .sharding_strategy import StrategiesVector - - -class Select(Node): - - def __init__(self, layer): - super().__init__(layer) - batch_dims = [i for i in 
range(len(self.get_output(0).shape))] - self._generate_bcast_dims(batch_dims, self.get_output(0).shape) - - def _collect_strategies(self, device_mesh): - dim_partition_list = [] - dim_size = len(self.op_data['output0'].shape) - dim_partition_list.append({}) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([0], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([1], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_1d_sharding([0, 1], dim_size)) - dim_partition_list.extend( - self._enumerate_all_possible_2d_sharding([0], [1], dim_size)) - - strategies_vector = StrategiesVector(self) - for dim_partition_dict in dim_partition_list: - # the three inputs are condition, true tensor and false tensor - in0_partition_dict = self._recover_bcast_partition_dict( - dim_partition_dict, self.op_data['input0']) - in1_partition_dict = self._recover_bcast_partition_dict( - dim_partition_dict, self.op_data['input1']) - in2_partition_dict = self._recover_bcast_partition_dict( - dim_partition_dict, self.op_data['input2']) - out_partition_dict = dim_partition_dict - dim_partition_dict_mapping = { - "input0": in0_partition_dict, - "input1": in1_partition_dict, - "input2": in2_partition_dict, - "output0": out_partition_dict, - } - sharding_spec_mapping = self._to_sharding_spec_mapping( - dim_partition_dict_mapping, device_mesh) - if 0 == len(sharding_spec_mapping): - continue - name = '{} =