# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import IntEnum
from typing import List

import torch
from torch.distributed import ProcessGroup

from tensorrt_llm._torch.device_mesh import DeviceMeshTopologyImpl
from tensorrt_llm._utils import mpi_disabled


class CpType(IntEnum):
    # CP type for ulysses parallelism
    ULYSSES = 0
    # CP type for star attention
    STAR = 1
    # CP type for ring attention
    RING = 2
    # CP type for helix parallelism
    HELIX = 3


class MappingBase:
    """Base class for distributed mapping configurations"""

    tp_rank: int
    pp_rank: int
    cp_rank: int

    def __init__(
            self,
            world_size=1,
            rank=0,
            gpus_per_node=8,
            *,
            cp_size=1,
            cp_config=None,
            tp_size=1,
            pp_size=1,
            pp_partition=None,
            moe_cluster_size=-1,  # -1 means no moe
            moe_tp_size=-1,  # -1 means no moe
            moe_ep_size=-1,  # -1 means no moe
            attn_tp_size=-1,
            attn_cp_size=-1,
            enable_attention_dp=False,
            enable_lm_head_tp_in_adp=False):
        # set default values for non-moe cases
        # or where only one MOE parallelism size is specified
        if moe_cluster_size == -1:
            moe_cluster_size = 1

        # Set default cp_type to ULYSSES.
        cp_type = CpType.ULYSSES

        # Convert cp_type to CpType enum if it is a string.
        if cp_config is not None:
            if "cp_type" in cp_config and isinstance(cp_config["cp_type"], str):
                try:
                    cp_config["cp_type"] = CpType[cp_config["cp_type"].upper()]
                except KeyError:
                    raise ValueError(
                        f"Invalid cp_type: {cp_config['cp_type']}. "
                        f"Must be one of: {', '.join([t.name for t in CpType])}")
            cp_type = cp_config.get("cp_type", CpType.ULYSSES)

        moe_world_size = tp_size if cp_type == CpType.ULYSSES else tp_size * cp_size

        if moe_tp_size == -1 and moe_ep_size == -1:
            moe_tp_size = moe_world_size // moe_cluster_size
            moe_ep_size = 1

        elif moe_tp_size == -1:
            moe_tp_size = moe_world_size // (moe_ep_size * moe_cluster_size)

        elif moe_ep_size == -1:
            moe_ep_size = moe_world_size // (moe_tp_size * moe_cluster_size)

        if attn_tp_size == -1 and attn_cp_size == -1:
            if cp_type == CpType.ULYSSES:
                # fallback to ulysses
                attn_tp_size = tp_size * cp_size
                attn_cp_size = 1
            else:
                # fallback to helix
                attn_tp_size = tp_size
                attn_cp_size = cp_size

        elif attn_tp_size == -1:
            attn_tp_size = (tp_size * cp_size) // attn_cp_size

        elif attn_cp_size == -1:
            attn_cp_size = (tp_size * cp_size) // attn_tp_size

        if attn_cp_size != 1 and cp_type == CpType.ULYSSES:
            raise ValueError(
                f"attn_cp_size must be 1 for now for ulysses, but got "
                f"attn_tp_size={attn_tp_size} and attn_cp_size={attn_cp_size}."
            )

        if tp_size * pp_size * cp_size != world_size:
            raise ValueError(
                "world_size must equal tp_size * pp_size * cp_size, "
                f"but got {world_size} != {tp_size} * {pp_size} * {cp_size}.")

        moe_tp_ep_size = moe_tp_size * moe_ep_size
        self.moe_tp_cluster_ep_size = moe_tp_ep_size * moe_cluster_size
        if self.moe_tp_cluster_ep_size != moe_world_size:
            raise ValueError(
                "moe_tp_size * moe_ep_size * moe_cluster_size must equal moe_world_size, "
                f"but got {self.moe_tp_cluster_ep_size} != {moe_world_size}")

        attn_tp_cp_size = attn_tp_size * attn_cp_size
        if attn_tp_cp_size != tp_size * cp_size:
            raise ValueError(
                "tp_size * cp_size must equal attn_tp_size * attn_cp_size, "
                f"but got {tp_size} * {cp_size} != {attn_tp_size} * {attn_cp_size}"
            )

        if moe_ep_size != 1 and cp_size > 1 and cp_type != CpType.HELIX:
            raise NotImplementedError(
                f"CP {cp_type} doesn't support MoE tp/ep yet")

        if moe_cluster_size > 1:
            assert moe_ep_size == 1
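
        # Worked example of the defaults above (illustrative only): with
        # world_size=8, tp_size=4, pp_size=2, cp_size=1 and all moe_*/attn_*
        # sizes left at -1, cp_type falls back to ULYSSES, moe_world_size is
        # tp_size=4, so moe_tp_size=4, moe_ep_size=1, moe_cluster_size=1,
        # attn_tp_size=tp_size*cp_size=4 and attn_cp_size=1.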

        self.tp_size = tp_size
        self.cp_size = cp_size
        self.cp_config = cp_config if cp_config is not None else {}
        self.pp_size = pp_size
        self.pp_partition = pp_partition
        self.moe_tp_size = moe_tp_size
        self.moe_ep_size = moe_ep_size
        self.moe_cluster_size = moe_cluster_size
        self.attn_tp_size = attn_tp_size
        self.attn_cp_size = attn_cp_size
        self.world_size = world_size
        self.enable_attention_dp = enable_attention_dp
        if enable_lm_head_tp_in_adp:
            assert enable_attention_dp, "enable_lm_head_tp_in_adp requires enable_attention_dp"
        self.enable_lm_head_tp_in_adp = enable_lm_head_tp_in_adp
        self.rank = rank
        self.gpus_per_node = gpus_per_node

        self.pp_groups = []
        self.cp_groups = []
        self.tp_groups = []
        self.moe_cluster_groups = []
        self.moe_tp_groups = []
        self.moe_ep_groups = []

    def __eq__(self, other):
        if not isinstance(other, MappingBase):
            return NotImplemented

        return (self.world_size == other.world_size and self.rank == other.rank
                and self.gpus_per_node == other.gpus_per_node
                and self.cp_size == other.cp_size
                and self.tp_size == other.tp_size
                and self.moe_cluster_size == other.moe_cluster_size
                and self.pp_size == other.pp_size
                and self.pp_partition == other.pp_partition
                and self.moe_tp_size == other.moe_tp_size
                and self.moe_ep_size == other.moe_ep_size
                and self.attn_tp_size == other.attn_tp_size
                and self.attn_cp_size == other.attn_cp_size
                and self.cp_config == other.cp_config)

    def __hash__(self):
        return hash((
            self.world_size,
            self.rank,
            self.gpus_per_node,
            self.cp_size,
            self.tp_size,
            self.pp_size,
            self.moe_tp_size,
            self.moe_cluster_size,
            self.moe_ep_size,
            self.attn_tp_size,
            self.attn_cp_size,
            # note: we do not allow updating cp_config after initialization
            tuple(sorted(self.cp_config.items())),
            tuple(self.pp_partition) if self.pp_partition is not None else (),
        ))

    @property
    def rank(self):
        return self._rank

    @rank.setter
    def rank(self, rank: int):
        # TODO(qijun): skip check for enable_attention_dp temporarily, will support attention_dp_size
        if not self.enable_attention_dp:
            if not isinstance(rank, int) or rank < 0 or rank >= self.world_size:
                raise ValueError(
                    f"Rank should be an integer between 0 and {self.world_size-1}, but got {rank}."
                )
        self._rank = rank

    @property
    def moe_tp_rank(self):
        return self.tp_rank // (self.moe_ep_size * self.moe_cluster_size)

    @property
    def moe_cluster_rank(self):
        return self.tp_rank % self.moe_cluster_size

    @property
    def moe_ep_rank(self):
        return self.tp_rank % self.moe_ep_size
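
    # Illustrative example (not executed): with tp_size=8, moe_tp_size=2,
    # moe_ep_size=4 and moe_cluster_size=1, the rank with tp_rank=5 gets
    # moe_tp_rank = 5 // (4 * 1) = 1, moe_ep_rank = 5 % 4 = 1 and
    # moe_cluster_rank = 5 % 1 = 0, i.e. it falls into moe_tp group [1, 5]
    # and moe_ep group [4, 5, 6, 7] (cf. the examples in the Mapping
    # docstring below).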

    @property
    def moe_cluster_group(self):
        return self.moe_cluster_groups[self.pp_rank * self.moe_tp_size +
                                       self.moe_tp_rank]

    @property
    def node_rank(self):
        return self.rank // self.gpus_per_node

    @property
    def local_rank(self):
        return self.rank % self.gpus_per_node

    @property
    def dp_size(self):
        return self.tp_size if self.enable_attention_dp else 1

    def has_cp_ulysses(self):
        return self.cp_size > 1 and self.cp_config.get(
            "cp_type") == CpType.ULYSSES

    def has_cp_helix(self):
        return self.cp_size > 1 and self.cp_config.get(
            "cp_type") == CpType.HELIX

    def get_node_rank(self, rank: int):
        return rank // self.gpus_per_node

    def get_local_rank(self, rank: int):
        return rank % self.gpus_per_node

    def is_multi_node(self):
        return self.world_size > self.gpus_per_node

    def has_tp(self):
        return self.tp_size > 1

    def is_last_pp_rank(self):
        return self.pp_rank == self.pp_size - 1

    def is_second_last_pp_rank(self):
        return self.pp_rank == self.pp_size - 2

    def is_first_pp_rank(self):
        return self.pp_rank == 0

    def has_pp(self):
        return self.pp_size > 1

    def prev_pp_rank(self):
        p = self.rank - self.tp_size * self.cp_size
        if p < 0:
            p = p + self.world_size
        return p

    def next_pp_rank(self):
        p = self.rank + self.tp_size * self.cp_size
        if p >= self.world_size:
            p = p - self.world_size
        return p

    def is_last_cp_rank(self):
        return self.cp_rank == self.cp_size - 1

    def is_first_cp_rank(self):
        return self.cp_rank == 0

    def has_cp(self):
        return self.cp_size > 1

    def prev_cp_rank(self):
        # cp ranks are consecutive, so prev is rank - 1 with wraparound within cp group.
        if self.cp_rank == 0:
            return self.rank + self.cp_size - 1
        return self.rank - 1

    def next_cp_rank(self):
        # cp ranks are consecutive, so next is rank + 1 with wraparound within cp group.
        if self.cp_rank == self.cp_size - 1:
            return self.rank - self.cp_size + 1
        return self.rank + 1

    def has_moe_cluster(self):
        return self.moe_cluster_size > 1

    def has_moe_tp(self):
        return self.moe_tp_size > 1

    def has_moe_ep(self):
        return self.moe_ep_size > 1

    def pp_layers(self, num_layers: int) -> List[int]:
        if self.pp_partition is not None:
            if len(self.pp_partition) != self.pp_size:
                raise ValueError(
                    f"{len(self.pp_partition)=} does not match {self.pp_size=}."
                )
            if sum(self.pp_partition) != num_layers:
                raise ValueError(
                    f"{sum(self.pp_partition)=} does not match {num_layers=}.")
            return torch.arange(num_layers).split(
                self.pp_partition)[self.pp_rank].tolist()
        else:
            # If num_layers % pp_size = n != 0, first n ranks get one extra layer
            return torch.tensor_split(torch.arange(num_layers),
                                      self.pp_size)[self.pp_rank].tolist()
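
    # Illustrative example (not executed): with num_layers=10 and pp_size=4,
    # torch.tensor_split yields chunks of sizes 3, 3, 2, 2, so pp_rank 1 owns
    # layers [3, 4, 5]; with an explicit pp_partition=[4, 2, 2, 2], pp_rank 0
    # owns layers [0, 1, 2, 3] instead.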

    def pp_rank_of_layer(self, layer_idx: int, num_layers: int) -> int:
        """Return pipeline-parallel rank that owns `layer_idx` for a model with `num_layers` layers.

        Mirrors the partitioning behavior in `pp_layers()`.
        """
        if layer_idx < 0 or layer_idx >= num_layers:
            raise ValueError(f"{layer_idx=} is out of range for {num_layers=}.")
        if not self.has_pp():
            return 0

        if self.pp_partition is not None:
            if len(self.pp_partition) != self.pp_size:
                raise ValueError(
                    f"{len(self.pp_partition)=} does not match {self.pp_size=}."
                )
            if sum(self.pp_partition) != num_layers:
                raise ValueError(
                    f"{sum(self.pp_partition)=} does not match {num_layers=}.")
            end = 0
            for pp_rank, n in enumerate(self.pp_partition):
                end += n
                if layer_idx < end:
                    return pp_rank
            raise RuntimeError("Unreachable: invalid pp_partition.")

        base, rem = divmod(num_layers, self.pp_size)
        if base == 0:
            # Matches torch.tensor_split: first `num_layers` ranks get one layer.
            return layer_idx
        cutoff = (base + 1) * rem
        if layer_idx < cutoff:
            return layer_idx // (base + 1)
        return rem + (layer_idx - cutoff) // base
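
    # Consistency sketch (illustrative only): pp_rank_of_layer(i, n) returns
    # the pp_rank whose pp_layers(n) contains layer i. E.g. num_layers=10,
    # pp_size=4: base=2, rem=2, cutoff=6, so layer 5 -> 5 // 3 = 1 and
    # layer 7 -> 2 + (7 - 6) // 2 = 2, matching the 3/3/2/2 split above.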

    def ep_experts(self, num_experts: int) -> List[int]:
        assert self.cp_size == 1
        experts_per_rank = num_experts // self.moe_ep_size
        experts_range = range(self.moe_ep_rank * experts_per_rank,
                              (self.moe_ep_rank + 1) * experts_per_rank)
        return list(experts_range)
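
    # Illustrative example (not executed): num_experts=8 with moe_ep_size=4
    # gives 2 experts per rank, so moe_ep_rank 2 owns experts [4, 5].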

    @classmethod
    def from_dict(cls, mapping: dict):
        return cls(**mapping)

    def to_dict(self):
        return {
            'world_size': self.world_size,
            'rank': self.rank,
            'gpus_per_node': self.gpus_per_node,
            'cp_size': self.cp_size,
            'tp_size': self.tp_size,
            'pp_size': self.pp_size,
            'moe_tp_size': self.moe_tp_size,
            'moe_cluster_size': self.moe_cluster_size,
            'moe_ep_size': self.moe_ep_size,
            'attn_tp_size': self.attn_tp_size,
            'attn_cp_size': self.attn_cp_size,
            'cp_config': self.cp_config,
            'enable_attention_dp': self.enable_attention_dp,
            'enable_lm_head_tp_in_adp': self.enable_lm_head_tp_in_adp,
        }


class Mapping(MappingBase):
    """
    A node with 8 GPUs, tp_size = 4, cp_size = 1, pp_size = 2

    2 tp groups:

    - [0, 1, 2, 3]
    - [4, 5, 6, 7]

    4 pp groups:

    - [0, 4]
    - [1, 5]
    - [2, 6]
    - [3, 7]

    A node with 8 GPUs, tp_size = 4, cp_size = 2, pp_size = 1

    4 cp groups:

    - [0, 1]
    - [2, 3]
    - [4, 5]
    - [6, 7]

    2 tp groups:

    - [0, 2, 4, 6]
    - [1, 3, 5, 7]

    A node with 8 GPUs, moe_tp_size = 2, moe_ep_size = 4

    4 moe_tp groups:

    - [0, 4]
    - [1, 5]
    - [2, 6]
    - [3, 7]

    2 moe_ep groups:

    - [0, 1, 2, 3]
    - [4, 5, 6, 7]

    2 nodes with 16 GPUs, moe_tp_size = 2, moe_ep_size = 4, pp_size = 2

    8 moe_tp groups:

    - [0, 4]
    - [1, 5]
    - [2, 6]
    - [3, 7]
    - [8, 12]
    - [9, 13]
    - [10, 14]
    - [11, 15]

    4 moe_ep groups:

    - [0, 1, 2, 3]
    - [4, 5, 6, 7]
    - [8, 9, 10, 11]
    - [12, 13, 14, 15]

    8 pp groups:

    - [0, 8]
    - [1, 9]
    - [2, 10]
    - [3, 11]
    - [4, 12]
    - [5, 13]
    - [6, 14]
    - [7, 15]

    2 nodes with 8 GPUs, tp_size = 2, pp_size = 2, cp_size = 2

    4 cp groups:
    - [0, 1]
    - [2, 3]
    - [4, 5]
    - [6, 7]

    4 tp groups:
    - [0, 2]
    - [1, 3]
    - [4, 6]
    - [5, 7]

    4 pp groups:
    - [0, 4]
    - [1, 5]
    - [2, 6]
    - [3, 7]
    """

    def __new__(cls, *args, **kwargs):
        if mpi_disabled():
            return super().__new__(DeviceMeshTopology)
        else:
            return super().__new__(MpiTopology)

    # Intentionally repeated for type hints
    def __init__(
            self,
            world_size=1,
            rank=0,
            gpus_per_node=8,
            *,
            cp_size=1,
            cp_config=None,
            tp_size=1,
            pp_size=1,
            pp_partition=None,
            moe_cluster_size=-1,  # -1 means no moe
            moe_tp_size=-1,  # -1 means no moe
            moe_ep_size=-1,  # -1 means no moe
            attn_tp_size=-1,
            attn_cp_size=-1,
            enable_attention_dp=False,
            enable_lm_head_tp_in_adp=False):
        super().__init__(world_size=world_size,
                         rank=rank,
                         gpus_per_node=gpus_per_node,
                         cp_size=cp_size,
                         cp_config=cp_config,
                         tp_size=tp_size,
                         pp_size=pp_size,
                         pp_partition=pp_partition,
                         moe_cluster_size=moe_cluster_size,
                         moe_tp_size=moe_tp_size,
                         moe_ep_size=moe_ep_size,
                         attn_tp_size=attn_tp_size,
                         attn_cp_size=attn_cp_size,
                         enable_attention_dp=enable_attention_dp,
                         enable_lm_head_tp_in_adp=enable_lm_head_tp_in_adp)

    def repurpose_helix_cp_to_tp(self):
        # In helix parallelism, CP is relevant only for the attention layer. These ranks are repurposed to TP
        # for FFN layers.
        assert self.has_cp_helix()
        return Mapping(
            world_size=self.world_size,
            rank=self.rank,
            gpus_per_node=self.gpus_per_node,
            cp_size=1,
            cp_config={},
            tp_size=self.tp_size * self.cp_size,
            pp_size=self.pp_size,
            pp_partition=self.pp_partition,
            moe_cluster_size=self.moe_cluster_size,
            moe_tp_size=self.moe_tp_size,
            moe_ep_size=self.moe_ep_size,
            # attn_tp_size, attn_cp_size shall be set in the constructor of Mapping.
            enable_attention_dp=self.enable_attention_dp,
            enable_lm_head_tp_in_adp=self.enable_lm_head_tp_in_adp)
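
    # Illustrative example (not executed): a helix mapping with world_size=8,
    # tp_size=2, cp_size=4, pp_size=1 is repurposed into a mapping with
    # tp_size=8 and cp_size=1 for the FFN layers, while pp, MoE and rank
    # settings carry over unchanged.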

    # DeviceMesh specific methods
    @property
    def tp_group_pg(self) -> ProcessGroup:
        raise NotImplementedError("tp_group_pg is not implemented.")

    @property
    def pp_group_pg(self) -> ProcessGroup:
        raise NotImplementedError("pp_group_pg is not implemented.")

    @property
    def cp_group_pg(self) -> ProcessGroup:
        raise NotImplementedError("cp_group_pg is not implemented.")

    @property
    def moe_tp_group_pg(self) -> ProcessGroup:
        raise NotImplementedError("moe_tp_group_pg is not implemented.")

    @property
    def moe_ep_group_pg(self) -> ProcessGroup:
        raise NotImplementedError("moe_ep_group_pg is not implemented.")

    def build_mesh(self):
        raise NotImplementedError("build_mesh is not implemented.")


class MpiTopology(Mapping):
    '''MPI-based mapping implementation'''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._init_parallel_groups()

    @property
    def tp_rank(self) -> int:
        return self.rank % (self.tp_size * self.cp_size) // self.cp_size

    @property
    def pp_rank(self) -> int:
        return self.rank // (self.tp_size * self.cp_size)

    @property
    def cp_rank(self) -> int:
        return self.rank % self.cp_size

    @property
    def tp_group(self) -> List[int]:
        return self.tp_groups[self.pp_rank * self.cp_size + self.cp_rank]

    @property
    def pp_group(self) -> List[int]:
        return self.pp_groups[self.tp_rank * self.cp_size + self.cp_rank]

    @property
    def cp_group(self) -> List[int]:
        return self.cp_groups[self.pp_rank * self.tp_size + self.tp_rank]

    @property
    def moe_tp_group(self) -> List[int]:
        return self.moe_tp_groups[self.pp_rank * self.moe_cluster_size *
                                  self.moe_ep_size +
                                  self.moe_cluster_rank * self.moe_ep_size +
                                  self.moe_ep_rank]

    @property
    def moe_ep_group(self) -> List[int]:
        return self.moe_ep_groups[self.pp_rank * self.moe_tp_size *
                                  self.moe_cluster_size +
                                  self.moe_tp_rank * self.moe_cluster_size +
                                  self.moe_cluster_rank]

    @property
    def moe_cluster_group(self) -> List[int]:
        return self.moe_cluster_groups[self.pp_rank * self.moe_tp_size +
                                       self.moe_tp_rank]

    def _init_parallel_groups(self):
        # init pp group
        for i in range(self.tp_size * self.cp_size):
            ranks = range(i, self.world_size, self.tp_size * self.cp_size)
            self.pp_groups.append(list(ranks))

        # init cp group (consecutive ranks within each tp slice).
        for i in range(self.pp_size):
            for j in range(self.tp_size):
                ranks = range(
                    i * self.tp_size * self.cp_size + j * self.cp_size,
                    i * self.tp_size * self.cp_size + (j + 1) * self.cp_size)
                self.cp_groups.append(list(ranks))

        # init tp group (interleaved ranks with stride of cp_size).
        for i in range(self.pp_size):
            for j in range(self.cp_size):
                ranks = range(i * self.tp_size * self.cp_size + j,
                              (i + 1) * self.tp_size * self.cp_size + j,
                              self.cp_size)
                self.tp_groups.append(list(ranks))

        # init moe tp group
        for i in range(self.pp_size):
            for j in range(self.moe_cluster_size * self.moe_ep_size):
                ranks = range(i * self.moe_tp_cluster_ep_size + j,
                              (i + 1) * self.moe_tp_cluster_ep_size,
                              self.moe_cluster_size * self.moe_ep_size)
                self.moe_tp_groups.append(list(ranks))

        # init moe cluster group
        for i in range(self.pp_size):
            for j in range(self.moe_tp_size):
                ranks = range(
                    i * self.moe_tp_cluster_ep_size +
                    j * self.moe_cluster_size * self.moe_ep_size,
                    i * self.moe_tp_cluster_ep_size +
                    (j + 1) * self.moe_cluster_size * self.moe_ep_size)
                self.moe_cluster_groups.append(list(ranks))

        # init moe ep group
        for i in range(self.pp_size):
            for j in range(self.moe_tp_size):
                for k in range(self.moe_cluster_size):
                    ranks = range(
                        i * self.moe_tp_cluster_ep_size +
                        j * self.moe_cluster_size * self.moe_ep_size +
                        k * self.moe_ep_size, i * self.moe_tp_cluster_ep_size +
                        j * self.moe_cluster_size * self.moe_ep_size +
                        (k + 1) * self.moe_ep_size)
                    self.moe_ep_groups.append(list(ranks))
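
    # Illustrative example (not executed): with world_size=8, tp_size=2,
    # cp_size=2, pp_size=2 the loops above build
    #     pp_groups = [[0, 4], [1, 5], [2, 6], [3, 7]]
    #     cp_groups = [[0, 1], [2, 3], [4, 5], [6, 7]]
    #     tp_groups = [[0, 2], [1, 3], [4, 6], [5, 7]]
    # matching the last example in the Mapping docstring.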


class DeviceMeshTopology(DeviceMeshTopologyImpl, Mapping):
    """PyTorch DeviceMesh-based mapping implementation"""

    def __init__(self, *args, **kwargs):
        assert mpi_disabled(
        ), "DeviceMeshTopology is only available in Ray orchestrator mode."

        super().__init__(*args, **kwargs)