mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
613 lines
27 KiB
Python
613 lines
27 KiB
Python
import os
|
|
import re
|
|
from abc import ABC, abstractmethod
|
|
from typing import List
|
|
|
|
import h5py
|
|
import numpy as np
|
|
from filelock import FileLock
|
|
|
|
from .config import AutoParallelConfig, CostModel
|
|
from .tensor_parallel.shape_consistency import ShapeConsistencyManager
|
|
|
|
|
|
class ProfileDB(ABC):
    """Abstract store of profiling results for multiple device-mesh shapes.

    Concrete back ends map a (cluster_key, data_key) pair to a previously
    recorded mesh profiling result.
    """

    @abstractmethod
    def query(self, cluster_key, data_key):
        """Look up the stored result for the pair; return None when absent."""

    @abstractmethod
    def update(self, cluster_key, data_key, mesh_result):
        """Record mesh_result under (cluster_key, data_key)."""

    def close(self):
        """Release any resources held by the back end (no-op by default)."""
|
|
|
|
|
|
class MemDB(ProfileDB):
    """In-process, in-memory ProfileDB keyed by (cluster_key, data_key)."""

    def __init__(self):
        # (cluster_key, data_key) -> stored mesh_result container.
        self.data = {}

    def query(self, cluster_key, data_key):
        """Return the first element of the stored entry, or None if absent."""
        entry = self.data.get((cluster_key, data_key))
        return None if entry is None else entry[0]

    def update(self, cluster_key, data_key, mesh_result):
        """Store mesh_result under (cluster_key, data_key)."""
        self.data[(cluster_key, data_key)] = mesh_result
|
|
|
|
|
|
class Hdf5DB(ProfileDB):
    """HDF5-file backed ProfileDB shared between processes via a file lock.

    Lock protocol (NOTE(review), intentional but fragile): ``query``
    acquires the inter-process lock and releases it only on a cache hit.
    On a miss the lock is kept held so the caller can profile and call
    ``update`` without racing other processes; ``close`` force-releases
    whatever is still held. Callers must pair a miss with update+close.
    """

    def __init__(self, name):
        # `name` is the cache path without the '.hdf5' suffix.
        self.name = name
        lock_name = self.name + ".lock"
        # thread_local=False so the lock may be released from a different
        # thread than the one that acquired it.
        self.lock = FileLock(lock_name, thread_local=False)

    def query(self, cluster_key, data_key):
        """Return the cached entry for (cluster_key, data_key), or None.

        On a miss the lock stays held -- see the class docstring.
        """
        file_name = f"{self.name}.hdf5"
        # Keys are stringified tuples; must match the format used by update().
        key = str((cluster_key, data_key))
        self.lock.acquire()
        mesh_result = None
        with h5py.File(file_name, 'a') as f:
            if key in f:
                self.lock.release()
                mesh_result = f[key]
                # Read while the file is still open; f[key] is an h5py
                # dataset and [0] extracts the stored payload's first row
                # (presumably the array written by update -- TODO confirm).
                return mesh_result[0]
            else:
                # Lock deliberately NOT released here (see class docstring).
                return None

    def update(self, cluster_key, data_key, mesh_result):
        """Persist mesh_result; assumes the lock is held from a query miss."""
        key = str((cluster_key, data_key))
        file_name = f"{self.name}.hdf5"
        with h5py.File(file_name, 'a') as f:
            f[key] = mesh_result

    def close(self):
        # Force release: the lock may still be held from a query() miss,
        # possibly by another thread.
        self.lock.release(force=True)
|
|
|
|
|
|
class LogicalDeviceMesh(object):
    """A 2D logical device mesh laid over (part of) the physical mesh.

    Precomputes per-collective algorithmic alpha/beta (bandwidth/latency)
    terms for each logical mesh dimension and exposes
    ``estimate_comm_cost`` to the auto-parallel planner, plus nccl-test
    based profiling helpers that populate the S-curve database.
    """

    def __init__(self,
                 phy_mesh_shape,
                 mesh_shape,
                 phy_ids,
                 config: AutoParallelConfig,
                 alpha,
                 beta,
                 sharp,
                 prof_database=None,
                 shape_consistency_manager=None,
                 host_ips=None):
        """
        Args:
            phy_mesh_shape: (num_hosts, num_devices_per_host) of the
                underlying physical mesh.
            mesh_shape: 2D logical mesh shape (may differ from
                phy_mesh_shape).
            phy_ids: array of physical device ids arranged as mesh_shape.
            config: auto-parallel configuration.
            alpha: per-level hardware bandwidth terms; one entry when only
                one communication level exists, otherwise indexed
                [inter_node, intra_node] (see callers in
                PhysicalDeviceMesh).
            beta: per-level hardware latency terms, same layout as alpha.
            sharp: per-level flags for in-network (SHARP) reduction.
            prof_database: optional ProfileDB with S-curve data.
            shape_consistency_manager: shared ShapeConsistencyManager.
            host_ips: IPs of participating hosts, if known.
        """
        self.phy_mesh_shape = phy_mesh_shape
        self.mesh_shape = mesh_shape
        self.phy_ids = phy_ids
        self.host_ips = host_ips
        # Unique key for this (cluster, mesh shape) pair; indexes the
        # profiling database.
        self.cluster_key = config.cluster_key + '_mesh_shape{}'.format('_'.join(
            [str(i) for i in mesh_shape]))
        # Message-size sweep for profiling: 1 byte .. 16 GiB.
        self.prof_min_max_size = [1, 2**34]
        self.prof_comm_dtypes = [
            "int8", "uint8", "int32", "uint32", "int64", "uint64", "float16",
            "float32", "float64", "bfloat16"
        ]
        # mesh dim(s) -> [device-id layout, NCCL_TESTS_SPLIT_MASK value]
        # used when launching nccl-tests for that communication group.
        self.devices_group = {
            (0, ): [self.phy_ids.transpose(), self.mesh_shape[1] - 1],
            (1, ): [self.phy_ids, self.mesh_shape[1]],
            (0, 1): [self.phy_ids.reshape([1, self.phy_ids.size]), 0]
        }
        self.prof_database = prof_database
        self.shape_consistency_manager = shape_consistency_manager
        self.config = config
        self.cluster_info = config.get_cluster_info()
        self.hw_alpha = alpha
        self.hw_beta = beta
        self.hw_sharp = sharp
        # Derived per-collective [alpha dict, beta dict] keyed by mesh dims.
        self.algo_alpha_beta = self._estimate_algo_alpha_beta()
        self.comm_op_to_nccl_test_func_name = {
            'all_reduce': 'all_reduce_perf_mpi',
            'all_gather': 'all_gather_perf_mpi',
            'all_to_all': 'alltoall_perf_mpi',
            'reduce_scatter': 'reduce_scatter_perf_mpi',
            'split': 'split',
        }

    @property
    def size(self) -> int:
        """Total number of devices in the logical mesh."""
        return self.phy_ids.size

    def _estimate_algo_alpha_beta(self):
        """Derive per-collective algorithmic bandwidth/latency per mesh dim.

        Applies ring-model corrections: all_gather/reduce_scatter/all_to_all
        move (n-1)/n of the data (hence ``* n / (n - 1)``), and a non-SHARP
        all_reduce costs twice an all_gather (hence the additional ``/ 2``).

        Returns:
            dict: comm op -> [alpha dict, beta dict] keyed by mesh-dims
            tuple, plus 'p2p_cross_device'/'p2p_cross_host' as
            [bandwidth, latency] pairs.
        """
        ret = {}
        ar_alpha, ar_beta = {}, {}
        ag_alpha, ag_beta = {}, {}
        rs_alpha, rs_beta = {}, {}
        a2a_alpha, a2a_beta = {}, {}
        phy_num_hosts, phy_num_devices_per_host = self.phy_mesh_shape
        # Single communication level: only hw_alpha[0]/hw_beta[0] exist.
        if phy_num_hosts == 1 or phy_num_devices_per_host == 1:
            for dims in [(0, ), (1, ), (0, 1), (1, 0)]:
                num_devices = 1
                for dim in dims:
                    num_devices = self.mesh_shape[dim] * num_devices
                if num_devices != 1:
                    ar_alpha[dims] = self.hw_alpha[0] if self.hw_sharp[
                        0] else self.hw_alpha[0] * num_devices / 2 / (
                            num_devices - 1)
                    ar_beta[dims] = self.hw_beta[0]
                    ag_alpha[dims] = self.hw_alpha[0] * num_devices / (
                        num_devices - 1)
                    ag_beta[dims] = self.hw_beta[0]
                    rs_alpha[dims] = self.hw_alpha[0] * num_devices / (
                        num_devices - 1)
                    rs_beta[dims] = self.hw_beta[0]
                    a2a_alpha[dims] = self.hw_alpha[0] * num_devices / (
                        num_devices - 1)
                    a2a_beta[dims] = self.hw_beta[0]
        # phy and logical have the same mesh shape if num_hosts > 1 and num_devices_per_host > 1
        else:
            for dims in [(0, ), (1, ), (0, 1), (1, 0)]:
                num_devices = 1
                for dim in dims:
                    num_devices = self.mesh_shape[dim] * num_devices
                if num_devices != 1:
                    if len(dims) == 1:
                        # One-level collective along a single mesh axis;
                        # dim 0 is inter-node, dim 1 intra-node here.
                        dim = dims[0]
                        ar_alpha[dims] = self.hw_alpha[dim] if self.hw_sharp[
                            dim] else self.hw_alpha[dim] * num_devices / 2 / (
                                num_devices - 1)
                        ar_beta[dims] = self.hw_beta[dim]
                        ag_alpha[dims] = self.hw_alpha[dim] * num_devices / (
                            num_devices - 1)
                        ag_beta[dims] = self.hw_beta[dim]
                        rs_alpha[dims] = self.hw_alpha[dim] * num_devices / (
                            num_devices - 1)
                        rs_beta[dims] = self.hw_beta[dim]
                        a2a_alpha[dims] = self.hw_alpha[dim] * num_devices / (
                            num_devices - 1)
                        a2a_beta[dims] = self.hw_beta[dim]
                    elif len(dims) == 2:  # two level communication
                        # Hierarchical collective: the slower of the
                        # inter-node and intra-node stages bounds the
                        # effective bandwidth.
                        num_hosts, num_devices_per_host = phy_num_hosts, phy_num_devices_per_host
                        inter_node_col_alpha = self.hw_alpha[
                            0] * num_devices_per_host
                        inter_node_ar_alpha = inter_node_col_alpha if self.hw_sharp[
                            0] else inter_node_col_alpha * num_hosts / 2 / (
                                num_hosts - 1)
                        intra_node_ar_alpha = self.hw_alpha[1]
                        intra_node_ar_alpha = intra_node_ar_alpha if self.hw_sharp[
                            1] else intra_node_ar_alpha * num_devices_per_host / 2 / (
                                num_devices_per_host - 1)
                        ar_alpha[dims] = min(inter_node_ar_alpha,
                                             intra_node_ar_alpha)
                        ar_beta[dims] = max(self.hw_beta)
                        ag_alpha[dims] = min(
                            inter_node_col_alpha * num_hosts / (num_hosts - 1),
                            self.hw_alpha[1] * num_devices_per_host /
                            (num_devices_per_host - 1))
                        ag_beta[dims] = max(self.hw_beta)
                        rs_alpha[dims] = ag_alpha[dims]
                        rs_beta[dims] = ag_beta[dims]
                        a2a_alpha[dims] = min(
                            num_hosts * self.hw_alpha[0] / (num_hosts - 1),
                            self.hw_alpha[1] * num_hosts)
                        a2a_beta[dims] = max(self.hw_beta)
                    else:
                        pass
        ret['all_to_all'] = [a2a_alpha, a2a_beta]
        ret['all_reduce'] = [ar_alpha, ar_beta]
        ret['all_gather'] = [ag_alpha, ag_beta]
        ret['reduce_scatter'] = [rs_alpha, rs_beta]
        ret['p2p_cross_device'] = [
            self.cluster_info.intra_node_bw_per_device,
            self.cluster_info.intra_node_latency
        ]
        ret['p2p_cross_host'] = [
            self.cluster_info.inter_node_bw_per_device,
            self.cluster_info.inter_node_latency
        ]
        return ret

    #[ToDo] stub functions here
    def _profile_split(self, min_max_comm_size):
        """Model 'split' analytically (copy in + out of device memory).

        Sweeps sizes from min to max in powers of two; returns a 2xN array
        of [sizes, times].
        """
        comm_size, elapsed_time = [], []
        size = min_max_comm_size[0]
        while size <= min_max_comm_size[1]:
            # Read + write, hence the factor of 2 over memory bandwidth.
            time = size * 2 / self.cluster_info.memory_bw
            comm_size.append(size)
            elapsed_time.append(time)
            size = size * 2
        return np.array([comm_size, elapsed_time])

    def _prase_nccl_test_results(self, f_nccl_test_out_log):
        """Parse an nccl-tests output log into ([sizes], [times]).

        [ToDo] Some dtypes may not be supported by nccl-tests; the default
        dtype (float) is used for those.
        """
        start_parse = False
        comm_size, elapsed_time = [], []
        try:
            with open(f_nccl_test_out_log, 'r') as lines:
                for line in lines:
                    if start_parse:
                        prof_data = re.split(r"[ ]+", line.strip())
                        # Data rows of the nccl-tests table have exactly
                        # 13 whitespace-separated columns.
                        if len(prof_data) != 13:
                            continue
                        # Column 0: message size; column 5: time
                        # (presumably bytes and microseconds -- TODO
                        # confirm against the nccl-tests output format).
                        comm_size.append(float(prof_data[0]))
                        elapsed_time.append(float(prof_data[5]))
                    # The header row containing the units marks the start
                    # of the results table.
                    if 'GB/s' in line and 'us' in line:
                        start_parse = True
        except Exception:
            print(f'failed to parse {f_nccl_test_out_log}')
        return comm_size, elapsed_time

    def _profile_with_nccl_test(self, min_max_comm_size, dtype, device_group,
                                func_name, step, workload_key):
        """Two-phase nccl-tests profiling for one collective/dtype.

        step 1: emit a Slurm batch script for the workload and append its
        submission to ./preprofiling_step1.sh; returns None.
        step 2: parse the job's output log and return a 2xN array of
        [sizes, times]. 'split' is modeled analytically instead.
        """
        if func_name == 'split':
            if 2 == step:
                return self._profile_split(min_max_comm_size)
            else:
                return None
        # NOTE(review): config is indexed like a mapping here but accessed
        # by attribute elsewhere -- presumably AutoParallelConfig supports
        # both; verify.
        workspace_dir = self.config['profiling_workspace'] + f'/{workload_key}'
        os.makedirs(workspace_dir, exist_ok=True)
        outfile, errfile = workspace_dir + '/profile.out', workspace_dir + '/profile.err'
        if 1 == step:
            num_nodes = len(self.host_ips)
            num_gpus = self.mesh_shape[0] * self.mesh_shape[1]
            ntasks_per_node = num_gpus // num_nodes
            # NCCL_TESTS_SPLIT_MASK selects the communicator split for this
            # device group; -f 2 doubles the message size each step.
            nccl_test_command = '"export NCCL_TESTS_SPLIT_MASK={} && export NCCL_COLLNET_ENABLE=1 && {} -b {} -e {} -g 1 -d {} -f {}"'.format(
                device_group[1], func_name, min_max_comm_size[0],
                min_max_comm_size[1], dtype, 2)
            sbatch_command = '#!/bin/bash\n'
            sbatch_command += '#SBATCH -p {}\n'.format(self.config['partition'])
            sbatch_command += '#SBATCH -A {}\n'.format(self.config['account'])
            sbatch_command += '#SBATCH -J {}\n'.format(self.config['jobname'])
            sbatch_command += '#SBATCH -N {}\n'.format(num_nodes)
            sbatch_command += '#SBATCH -t {}\n'.format(self.config['time'])
            sbatch_command += '#SBATCH --ntasks-per-node={}\n'.format(
                ntasks_per_node)
            sbatch_command += '#SBATCH --exclusive\n'
            sbatch_command += '#SBATCH --mem=0\n'
            sbatch_command += '#SBATCH --network=sharp\n'
            sbatch_command += '#SBATCH --mail-type=FAIL\n'
            srun_command = 'srun --nodes={} --mpi=pmix --ntasks-per-node={} --network=sharp -o {} -e {} --container-image={} bash -c '.format(
                num_nodes, ntasks_per_node, outfile, errfile,
                self.config['container'])
            command = sbatch_command + srun_command + nccl_test_command
            with open(workspace_dir + '/workload.sub', 'w') as f:
                f.write(command)
            with open('./preprofiling_step1.sh', 'a') as f:
                f.write(f'sbatch {workspace_dir}/workload.sub\n')
            return None

        else:
            comm_size, elapsed_time = self._prase_nccl_test_results(outfile)
            if len(comm_size) < 2:
                assert 0, 'the profiling for {} was failed at step1, please try again'.format(
                    workload_key)
            else:
                print(workload_key, comm_size, elapsed_time)
            return np.array([comm_size, elapsed_time])

    def _profile_single_comm_perf(self, device_group, comm_op, step, data_key):
        """Profile one collective over all dtypes; returns dtype -> data."""
        results = {}
        func_name = self.comm_op_to_nccl_test_func_name[comm_op]
        for dtype in self.prof_comm_dtypes:
            size_time = self._profile_with_nccl_test(
                self.prof_min_max_size, dtype, device_group, func_name, step,
                data_key + f'_dtype{dtype}')
            results[dtype] = size_time
        return results

    def profile_all_comms_perf(self, step):
        """Profile every collective over every mesh-dim group.

        Returns cached results when present; otherwise runs the two-phase
        profiling and, at step 2, persists the results to the database.
        Returns a nested dict comm_op -> dims -> dtype -> data, or None
        for the trivial (1, 1) mesh.
        """
        if self.mesh_shape == (1, 1):
            return None
        mesh_results = self.prof_database.query(self.cluster_key,
                                                self.mesh_shape)
        if mesh_results:
            return mesh_results

        mesh_results = {}
        data_key = self.cluster_key + f'_mesh_shape{self.mesh_shape[0]}x{self.mesh_shape[1]}'
        for comm_op in [
                'all_reduce', 'all_to_all', 'all_gather', 'reduce_scatter',
                'split'
        ]:
            comm_perf = {}
            for dim, device_group in self.devices_group.items():
                # don't need to profile for mesh dim == 1
                if len(dim) == 1 and self.mesh_shape[dim[0]] == 1:
                    continue

                comm_perf[dim] = self._profile_single_comm_perf(
                    device_group, comm_op, step, data_key +
                    '_comm_op{}_dim{}'.format(comm_op, ''.join(map(str, dim))))
            mesh_results[comm_op] = comm_perf
        if 2 == step:
            self.prof_database.update(self.cluster_key, self.mesh_shape,
                                      mesh_results)

        return mesh_results

    def _model_comm_cost_from_s_curve(self, size_time_array, realsize):
        """Interpolate the measured [sizes, times] curve at realsize."""
        assert size_time_array[0][0] <= realsize <= size_time_array[0][-1],\
            'the comm_size: {} is not in the profile range: [{}{}]'\
            .format(realsize, size_time_array[0][0], size_time_array[0][-1])
        return np.interp(realsize, size_time_array[0], size_time_array[1])

    def _model_comm_cost_from_alpha_beta(self, comm_op, dim_key, size_in_bytes):
        """Cost a collective with the alpha-beta model.

        The 1e-3 factor converts to the planner's time unit assuming alpha
        is a bandwidth in GB/s and beta a latency -- TODO confirm units.
        """
        elapsed_time = 0.0
        if 'split' == comm_op:
            # Split is a local copy: read + write over memory bandwidth.
            elapsed_time = size_in_bytes * 2 / (
                self.cluster_info.memory_bw *
                self.cluster_info.memory_efficiency) * 1e-3
        else:
            dict_alpha, dict_beta = self.algo_alpha_beta[comm_op]
            alpha, beta = dict_alpha[dim_key], dict_beta[dim_key]
            elapsed_time = (size_in_bytes /
                            (alpha * self.cluster_info.communication_efficiency)
                            * 1e-3) + beta
        return elapsed_time

    def _input_size_to_comm_size(self, comm_op, dims, input_size):
        """Scale per-device input size to the size moved on the wire.

        Only all_gather grows the payload (by the group size along each
        participating mesh dim); other collectives use input_size as-is.
        """
        ret = input_size
        if 'all_gather' == comm_op:
            for dim in dims:
                ret = ret * self.mesh_shape[dim]
        return ret

    def estimate_comm_cost(self, comm_op, dim, input_size, dtype):
        """Estimate the elapsed time (us) of comm_op over mesh dim(s).

        Dispatches on config.comm_cost_model: S_CURVE interpolates profiled
        data (requires this mesh to be in the database), ALPHA_BETA uses
        the analytic model, ZERO returns 0, PROFILE is unsupported.
        """
        size = self._input_size_to_comm_size(comm_op, dim, input_size)
        if self.config.comm_cost_model == CostModel.S_CURVE:
            mesh_perf = self.prof_database.query(self.cluster_key,
                                                 self.mesh_shape)
            assert mesh_perf is not None, 'the mesh is not profiled, mesh_shape = {}'.format(
                self.mesh_shape)
            comm_op_perf = mesh_perf.get(comm_op, None)
            assert comm_op_perf is not None, '{} is not profiled'.format(
                comm_op)
            elapsed_time = self._model_comm_cost_from_s_curve(
                comm_op_perf[tuple(dim)][dtype], size)
            return elapsed_time
        elif self.config.comm_cost_model == CostModel.ALPHA_BETA:
            elapsed_time = self._model_comm_cost_from_alpha_beta(
                comm_op, tuple(dim), size)
        elif self.config.comm_cost_model == CostModel.PROFILE:
            assert False, 'Unsupported profile based communication cost model now'
        elif self.config.comm_cost_model == CostModel.ZERO:
            elapsed_time = 0.0

        return elapsed_time  # us
|
|
|
|
|
|
class PhysicalDeviceMesh(object):
    """The physical (num_hosts x num_devices_per_host) device mesh.

    Validates the configured cost models, owns the profiling database and
    the ShapeConsistencyManager, and factors itself into logical meshes or
    pipeline-stage sub-meshes for the auto-parallel planner.
    """

    def __init__(self,
                 phy_devices_id,
                 config: AutoParallelConfig,
                 prof_database=None,
                 shape_consistency_manager=None,
                 host_ips=None):
        """
        Args:
            phy_devices_id: 2D array-like of device ids shaped
                (num_hosts, num_devices_per_host).
            config: auto-parallel configuration.
            prof_database: optional ProfileDB; created on demand when a
                profile-based cost model is selected.
            shape_consistency_manager: optional shared manager; a fresh
                one is created when not given.
            host_ips: per-host IPs; defaults to empty strings.

        Raises:
            ValueError: if config names an unknown cost model.
        """
        self.phy_devices_id = np.array(phy_devices_id)
        self.num_hosts, self.num_devices_per_host = self.phy_devices_id.shape
        self.host_ips = host_ips
        if host_ips is None:
            self.host_ips = [''] * self.num_hosts
        self.config = config
        self.cluster_info = config.get_cluster_info()
        self.prof_database: ProfileDB = prof_database
        self.shape_consistency_manager = shape_consistency_manager
        if self.config.comm_cost_model not in CostModel:
            raise ValueError(
                f'unsupported communication cost model: {self.config.comm_cost_model}'
            )
        if self.config.sharding_cost_model not in CostModel:
            raise ValueError(
                f'unsupported sharding cost model: {self.config.sharding_cost_model}'
            )
        # Profile-based cost models need a database: in-memory by default,
        # an HDF5 cache when config.profile_cache is set.
        if self.config.comm_cost_model == CostModel.S_CURVE or self.config.sharding_cost_model == CostModel.PROFILE:
            if self.prof_database is None:
                profile_cache = config.profile_cache
                if profile_cache is None:
                    self.prof_database = MemDB()
                else:
                    self.prof_database = Hdf5DB(profile_cache)
        elif self.config.comm_cost_model == CostModel.ALPHA_BETA:
            assert self.cluster_info.intra_node_bw_per_device > 0, 'intra_node_bw_per_device is needed for alpha_beta method'
            assert self.cluster_info.inter_node_bw_per_device > 0, 'inter_node_bw_per_device is needed for alpha_beta method'
        if self.config.sharding_cost_model == CostModel.ALPHA_BETA:
            assert self.cluster_info.memory_bw > 0, 'memory_bw is needed for alpha_beta method'

        if not shape_consistency_manager:
            self.shape_consistency_manager = ShapeConsistencyManager()

    @property
    def size(self) -> int:
        """Total number of physical devices in the mesh."""
        return self.phy_devices_id.size

    def close(self):
        """Release the profiling database (and its file lock, if any)."""
        if self.prof_database is not None:
            self.prof_database.close()

    def split_pipeline_meshes(
            self, num_stages,
            num_devices_per_stage) -> List["PhysicalDeviceMesh"]:
        """Split the mesh into one sub-mesh per pipeline stage.

        When a stage fits within one host, each host is cut into
        contiguous column groups; otherwise each stage spans whole hosts.
        Stages beyond the number of distinct clusters reuse them
        round-robin.

        Raises:
            AssertionError: when the mesh does not divide evenly.
        """
        sub_meshes = []
        if num_devices_per_stage <= self.num_devices_per_host:
            # Stage fits inside a single host: slice each host row.
            assert self.num_devices_per_host % num_devices_per_stage == 0, \
                "num_devices_per_host ({}) % num_devices_per_stage ({}) != 0"\
                .format(self.num_devices_per_host, num_devices_per_stage)
            num_clusters_per_host = self.num_devices_per_host // num_devices_per_stage
            num_clusters = self.num_hosts * num_clusters_per_host
            assert num_stages % num_clusters == 0, \
                "num_stages({}) % num_clusters({}) !=0".format(num_stages, num_clusters)
            for mesh_id in range(num_stages):
                cluster_id = mesh_id % num_clusters
                cluster_col = cluster_id % num_clusters_per_host
                cluster_row = cluster_id // num_clusters_per_host
                sub_devices_id = [
                    self.phy_devices_id[cluster_row][cluster_col *
                                                     num_devices_per_stage:(
                                                         (cluster_col + 1) *
                                                         num_devices_per_stage)]
                ]
                sub_meshes.append(
                    PhysicalDeviceMesh(sub_devices_id, self.config,
                                       self.prof_database,
                                       self.shape_consistency_manager,
                                       [self.host_ips[cluster_row]]))
        else:
            # Stage spans several whole hosts: slice by host rows.
            assert num_devices_per_stage % self.num_devices_per_host == 0, \
                "num_devices_per_stage ({}) % num_devices_per_host ({}) != 0"\
                .format(num_devices_per_stage, self.num_devices_per_host)
            num_host_per_cluster = num_devices_per_stage // self.num_devices_per_host
            assert self.num_hosts % num_host_per_cluster == 0, \
                "num_hosts ({}) % num_host_per_cluster({}) != 0".format(self.num_hosts, num_host_per_cluster)
            num_clusters = self.num_hosts // num_host_per_cluster
            for mesh_id in range(num_stages):
                cluster_id = mesh_id % num_clusters
                cluster_row = cluster_id * num_host_per_cluster
                sub_devices_id = self.phy_devices_id[cluster_row:(
                    cluster_row + num_host_per_cluster)]
                host_ips = self.host_ips[cluster_row:(cluster_row +
                                                      num_host_per_cluster)]
                sub_meshes.append(
                    PhysicalDeviceMesh(sub_devices_id, self.config,
                                       self.prof_database,
                                       self.shape_consistency_manager,
                                       host_ips))
        return sub_meshes

    def _profile_logical_meshes(self, logical_meshes, step):
        """Run (or schedule) communication profiling for each logical mesh."""
        for lmesh in logical_meshes:
            lmesh.profile_all_comms_perf(step)

    def as_logical_mesh(self) -> LogicalDeviceMesh:
        """View this physical mesh unchanged as a logical mesh."""
        alpha = [
            self.cluster_info.inter_node_bw_per_device,
            self.cluster_info.intra_node_bw_per_device
        ]
        beta = [
            self.cluster_info.inter_node_latency,
            self.cluster_info.intra_node_latency
        ]
        sharp = [
            self.cluster_info.inter_node_sharp,
            self.cluster_info.intra_node_sharp
        ]
        return LogicalDeviceMesh(
            self.phy_devices_id.shape,
            self.phy_devices_id.shape,
            self.phy_devices_id,
            self.config,
            alpha,
            beta,
            sharp,
            self.prof_database,
            self.shape_consistency_manager,
            self.host_ips,
        )

    def get_logical_meshes(self):
        """Enumerate candidate 2D logical meshes over this physical mesh.

        For a single-host (or single-device-per-host) mesh, every
        factorization (i, n // i) with 2 <= i <= sqrt(n) is produced;
        otherwise the physical shape itself is the only candidate.
        """
        logical_meshes = []
        # (1, 2) -> (1, 2)
        # (1, 4) -> (2, 2)
        # (1, 8) -> (2, 4)
        # (1, 16) -> (2, 8), (4, 4)
        # (1, 32) -> (2, 16), (4, 8)
        # (1, 48) -> (2, 24), (3, 16), (4, 12), (6, 8)
        # (1, 64) -> (2, 32), (4, 16), (8, 8)
        # we will traverse logical shape's axis in sharding spec, thus (2, 8) contains (8, 2)
        # we will merge logical shapes' axis, thus (2, 8) contains (1, 16) and (16, 1)
        if self.num_hosts == 1:
            alpha = [self.cluster_info.intra_node_bw_per_device]
            beta = [self.cluster_info.intra_node_latency]
            sharp = [self.cluster_info.intra_node_sharp]
            for i in range(2, self.num_devices_per_host):
                if self.num_devices_per_host % i == 0 and i * i <= self.num_devices_per_host:
                    lmesh_shape = (i, self.num_devices_per_host // i)
                    lmesh_phy_ids = self.phy_devices_id.reshape(lmesh_shape)
                    logical_meshes.append(
                        LogicalDeviceMesh(self.phy_devices_id.shape,
                                          lmesh_shape, lmesh_phy_ids,
                                          self.config, alpha, beta, sharp,
                                          self.prof_database,
                                          self.shape_consistency_manager,
                                          self.host_ips))
        # (8, 1) -> (2, 4)
        # (16, 1) -> (2, 8), (4, 4)
        elif self.num_devices_per_host == 1:
            alpha = [self.cluster_info.inter_node_bw_per_device]
            beta = [self.cluster_info.inter_node_latency]
            sharp = [self.cluster_info.inter_node_sharp]
            for i in range(2, self.num_hosts):
                if self.num_hosts % i == 0 and i * i <= self.num_hosts:
                    lmesh_shape = (i, self.num_hosts // i)
                    lmesh_phy_ids = self.phy_devices_id.reshape(lmesh_shape)
                    # Bug fix: lmesh_shape was previously omitted here,
                    # shifting every following positional argument of
                    # LogicalDeviceMesh by one slot.
                    logical_meshes.append(
                        LogicalDeviceMesh(self.phy_devices_id.shape,
                                          lmesh_shape, lmesh_phy_ids,
                                          self.config, alpha, beta, sharp,
                                          self.prof_database,
                                          self.shape_consistency_manager,
                                          self.host_ips))
        # (2, 1) -> (2, 1)
        # (2, 8) -> (2, 8)
        # (1, 2) -> (1, 2)
        # (1, 3) -> (1, 3)
        # (1, 5) -> (1, 5)
        if 0 == len(logical_meshes):
            logical_meshes.append(self.as_logical_mesh())
        return logical_meshes

    def _list_all_sub_meshes(self):
        """List one representative sub-mesh per possible stage size.

        We assume we can evenly split the pipeline and deviceMesh.
        """
        sub_meshes = []
        # Stage sizes that fit within a single host.
        for num_devices_per_stage in range(1, self.num_devices_per_host + 1):
            if self.num_devices_per_host % num_devices_per_stage == 0:
                num_stages = self.num_hosts * self.num_devices_per_host // num_devices_per_stage
                sub_meshes.append(
                    self.split_pipeline_meshes(num_stages,
                                               num_devices_per_stage)[0])
        # Stage sizes spanning multiple whole hosts.
        for num_hosts_per_stage in range(2, self.num_hosts + 1):
            if self.num_hosts % num_hosts_per_stage == 0:
                num_stages = self.num_hosts // num_hosts_per_stage
                sub_meshes.append(
                    self.split_pipeline_meshes(
                        num_stages,
                        num_hosts_per_stage * self.num_devices_per_host)[0])
        return sub_meshes

    def list_all_pipeline_configs(self):
        """Enumerate all valid (num_stages, num_devices_per_stage) pairs."""
        configs = []
        for num_devices_per_stage in range(1, self.num_devices_per_host + 1):
            if self.num_devices_per_host % num_devices_per_stage == 0:
                num_stages = self.num_hosts * self.num_devices_per_host // num_devices_per_stage
                configs.append((num_stages, num_devices_per_stage))
        for num_hosts_per_stage in range(2, self.num_hosts + 1):
            if self.num_hosts % num_hosts_per_stage == 0:
                num_stages = self.num_hosts // num_hosts_per_stage
                configs.append(
                    (num_stages,
                     num_hosts_per_stage * self.num_devices_per_host))
        return configs

    def profile_s_curve(self, step):
        """Profile every logical mesh of every sub-mesh (two-phase).

        Step 1 only generates job scripts; step 2 parses results and
        persists the database.
        """
        sub_phy_device_meshes = self._list_all_sub_meshes()
        for phy_mesh in sub_phy_device_meshes:
            lmeshes = phy_mesh.get_logical_meshes()
            self._profile_logical_meshes(lmeshes, step)
        if 2 == step:
            # NOTE(review): save_profile_database is not defined in this
            # view of the class -- presumably declared elsewhere; verify.
            self.save_profile_database()

    def profile_alpha_beta(self):
        # [ToDo] stub: returns fixed [inter, intra] bandwidth and latency
        # estimates instead of measured values.
        alpha = [250, 25]
        beta = [100, 100]
        return alpha, beta
|