# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ctypes
import os
import platform
import sys
from dataclasses import dataclass
from typing import List, Optional, Union

import pynvml
import torch

try:
    from cuda.bindings import driver as cuda
except ImportError:
    from cuda import cuda

from ._dlpack_utils import pack_strided_memory
from ._utils import mpi_comm
from .logger import logger
from .mapping import Mapping


def _check_cu_result(cu_func_ret):
if isinstance(cu_func_ret, tuple):
cu_result, *others = cu_func_ret
if cu_result != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError(cu_result)
if len(others) == 1:
return others[0]
elif len(others) > 1:
return tuple(others)
else:
return None
else:
if cu_func_ret != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError(cu_func_ret)
return None
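# Illustrative usage of the helper above (a sketch mirroring calls made later in this
# file): cuda-python driver APIs return either a bare CUresult or a
# (CUresult, value, ...) tuple, and _check_cu_result unwraps the payload while raising
# on any non-success status.
#   dev = _check_cu_result(cuda.cuCtxGetDevice())   # returns the CUdevice
#   _check_cu_result(cuda.cuMemUnmap(ptr, size))    # returns None on success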


class MnnvlMemory:
"""MNNVL memory management for tensor parallel (TP) operations."""
# Shared across all subclasses (global/device state).
initialized: bool = False
allocation_granularity: int = 0
fabric_page_size: int = 1 << 29 # 512 MB.
    dev_id: Optional[int] = None
# Per-class state attributes. These will be auto-initialized for each subclass
# to avoid polluting the parent class's state. Use callable (e.g., dict) for mutable defaults.
_per_class_attrs = {
"current_mem_offset": 0,
"current_rank_stride": 0, # stride for ranks and also address space size.
"current_start_address": 0,
"comm": None, # MPI communicator.
"allocated_map": dict, # callable for fresh dict.
"address_refcnt": dict, # callable for fresh dict.
}
# Initialize per-class state for the base class.
current_mem_offset: int = 0
current_rank_stride: int = 0
current_start_address: int = 0
comm = None
allocated_map = {}
address_refcnt = {}
def __init_subclass__(cls, **kwargs):
"""Auto-initialize per-class attributes for each subclass to avoid sharing state with parent."""
super().__init_subclass__(**kwargs)
for attr, default in cls._per_class_attrs.items():
if callable(default):
setattr(cls, attr, default()) # e.g., dict() creates a fresh dict.
else:
setattr(cls, attr, default)
def __init__(self, mapping: Mapping, size: int):
self.mapping = mapping
self.segment_size = size
self.ptr, self.rank_stride = type(self).open_mnnvl_memory(self.mapping, size)
def __del__(self):
if not sys.is_finalizing():
if hasattr(self, "ptr"):
type(self).close_mnnvl_memory(self.ptr)
def as_torch_strided_tensor(self, dtype):
num_segments = type(self).comm.Get_size()
return pack_strided_memory(
self.ptr, self.segment_size, self.rank_stride, num_segments, dtype, MnnvlMemory.dev_id
)
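    # Illustrative usage (a sketch; `mapping` stands for a configured Mapping and the
    # size is arbitrary):
    #   mem = MnnvlMemory(mapping, 1 << 20)              # 1 MiB segment per rank
    #   view = mem.as_torch_strided_tensor(torch.uint8)  # one segment per comm rank
    # Each segment in the strided view is the same offset into a different rank's
    # fabric-mapped allocation, so rank i's memory can be addressed directly.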
@staticmethod
def initialize():
if not MnnvlMemory.initialized:
# use a dummy torch CUDA tensor to trigger CUDA context initialization
_ = torch.empty(1, device="cuda")
# ensure nvml is initialized.
try:
pynvml.nvmlDeviceGetCount()
except pynvml.NVMLError_Uninitialized:
pynvml.nvmlInit()
MnnvlMemory.initialized = True
@classmethod
def get_comm(cls, mapping: Mapping):
"""Get TP-based communicator (ranks grouped by PP+CP+MOE_TP, ordered by TP rank)."""
if cls.comm is not None:
return cls.comm
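        # MPI Split semantics: ranks that pass the same color (the combined
        # PP/CP/MOE_TP coordinate below) land in the same new communicator, ordered
        # by the key (their TP rank), so each communicator spans exactly one TP group.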
comm = mpi_comm().Split(
(mapping.pp_rank * mapping.cp_size + mapping.cp_rank) * mapping.moe_tp_size
+ mapping.moe_tp_rank,
mapping.tp_rank,
)
cls.comm = comm
return comm
@staticmethod
def get_allocation_prop(dev_id: int):
location = cuda.CUmemLocation()
location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
location.id = dev_id
allocation_prop = cuda.CUmemAllocationProp()
allocation_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
        # TODO: We differentiate FABRIC for GB200 (aarch64) and POSIX_FILE_DESCRIPTOR
        # for B200 (x86_64). We may need to find a better way to handle this.
arch = platform.machine().lower()
is_on_aarch64 = "aarch64" in arch
if is_on_aarch64:
allocation_prop.requestedHandleTypes = (
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
)
else:
allocation_prop.requestedHandleTypes = (
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
)
allocation_prop.location = location
return allocation_prop
@staticmethod
def get_allocation_granularity(dev_id: int):
if MnnvlMemory.allocation_granularity != 0:
return MnnvlMemory.allocation_granularity
allocation_prop = MnnvlMemory.get_allocation_prop(dev_id)
option = cuda.CUmemAllocationGranularity_flags(
cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_RECOMMENDED
)
granularity = _check_cu_result(
cuda.cuMemGetAllocationGranularity(prop=allocation_prop, option=option)
)
MnnvlMemory.allocation_granularity = granularity
return MnnvlMemory.allocation_granularity
@classmethod
def new_mnnvl_memory_address(cls, mapping: Mapping, size: int):
page_count = (size + MnnvlMemory.fabric_page_size - 1) // MnnvlMemory.fabric_page_size
current_rank_stride = page_count * MnnvlMemory.fabric_page_size
logger.info(f"[{cls.__name__}] creating address with stride={current_rank_stride}")
comm = cls.get_comm(mapping)
comm_size = comm.Get_size()
address_size = current_rank_stride * comm_size
ptr = _check_cu_result(
cuda.cuMemAddressReserve(address_size, MnnvlMemory.fabric_page_size, 0, 0)
)
cls.current_start_address = int(ptr)
cls.current_rank_stride = current_rank_stride
cls.current_mem_offset = 0
@classmethod
def open_mnnvl_memory(cls, mapping: Mapping, size: int):
# Ensure MnnvlMemory is initialized (for dev_id and allocation_granularity)
MnnvlMemory.initialize()
dev = _check_cu_result(cuda.cuCtxGetDevice())
dev_id = int(dev)
if MnnvlMemory.dev_id is None:
MnnvlMemory.dev_id = dev_id
assert dev_id == MnnvlMemory.dev_id, (
f"Different dev_id found dev_id={dev_id} but MnnvlMemory.dev_id={MnnvlMemory.dev_id}"
)
comm = cls.get_comm(mapping)
comm_rank = comm.Get_rank()
comm_size = comm.Get_size()
all_rank_allocate_sizes = comm.allgather(size)
assert len(all_rank_allocate_sizes) == comm_size
        assert all(x == size for x in all_rank_allocate_sizes), "Not all ranks allocated the same size."
granularity = MnnvlMemory.get_allocation_granularity(dev_id)
aligned_size = (size + granularity - 1) // granularity * granularity
if cls.current_mem_offset + aligned_size > cls.current_rank_stride:
cls.new_mnnvl_memory_address(mapping, aligned_size)
assert cls.current_mem_offset + aligned_size <= cls.current_rank_stride
allocation_prop = MnnvlMemory.get_allocation_prop(dev_id)
allocated_mem_handle = _check_cu_result(
cuda.cuMemCreate(aligned_size, allocation_prop, flags=0)
)
exported_fabric_handle = _check_cu_result(
cuda.cuMemExportToShareableHandle(
allocated_mem_handle, allocation_prop.requestedHandleTypes, 0
)
)
if (
allocation_prop.requestedHandleTypes
== cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
):
all_handles_data = comm.allgather(exported_fabric_handle.data)
else:
all_handles_data = comm.allgather(exported_fabric_handle)
all_pids = comm.allgather(os.getpid())
libc = ctypes.CDLL(None, use_errno=True)
syscall = libc.syscall
SYS_pidfd_open = 434
SYS_pidfd_getfd = 438
pidfds = []
for i, pid in enumerate(all_pids):
pidfd = syscall(SYS_pidfd_open, pid, 0)
if pidfd < 0:
err = ctypes.get_errno()
raise RuntimeError(
f"pidfd_open({pid}) failed with errno {err}: {os.strerror(err)}"
)
pidfds.append(pidfd)
remote_fds = []
for i, (pidfd, fd) in enumerate(zip(pidfds, all_handles_data)):
remote_fd = syscall(SYS_pidfd_getfd, pidfd, fd, 0)
if remote_fd < 0:
err = ctypes.get_errno()
error_msg = f"pidfd_getfd(pidfd={pidfd}, fd={fd}) failed with errno {err}: {os.strerror(err)}."
if err == 1: # EPERM
error_msg += (
" Permission denied. If running in a container, try adding --cap-add=SYS_PTRACE "
"to your docker run command."
)
else:
error_msg += " This may be due to kernel version (requires Linux 5.6+)."
raise RuntimeError(error_msg)
remote_fds.append(remote_fd)
all_handles_data = remote_fds
# all_handles_data like b'\x00\x00\x00 \x00\x00\x00\x00\x8f\xec\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x00\x00\x00\x00\x00\x1d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' # noqa: E501
        # A memoryview(data) can be used for the import when the handle is passed around as a plain buffer.
madesc = cuda.CUmemAccessDesc()
madesc.location = allocation_prop.location
madesc.flags = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE
mem_handles = [None] * comm_size
for i, remote_handle_data in enumerate(all_handles_data):
rank_ptr = (
cls.current_start_address + cls.current_rank_stride * i + cls.current_mem_offset
)
if i == comm_rank:
# Local memory mapping
mem_handles[i] = allocated_mem_handle
_check_cu_result(cuda.cuMemMap(rank_ptr, aligned_size, 0, allocated_mem_handle, 0))
else:
# Fabric memory mapping
imported_mem_handle = _check_cu_result(
cuda.cuMemImportFromShareableHandle(
remote_handle_data, allocation_prop.requestedHandleTypes
)
)
mem_handles[i] = imported_mem_handle
_check_cu_result(cuda.cuMemMap(rank_ptr, aligned_size, 0, imported_mem_handle, 0))
_check_cu_result(cuda.cuMemSetAccess(rank_ptr, aligned_size, [madesc], 1))
ptr = cls.current_start_address + cls.current_mem_offset
stride = cls.current_rank_stride
cls.allocated_map[ptr] = (
mapping,
aligned_size,
mem_handles,
cls.current_start_address,
cls.current_rank_stride,
cls.current_mem_offset,
)
cls.address_refcnt[cls.current_start_address] = (
cls.address_refcnt.get(cls.current_start_address, 0) + 1
)
cls.current_mem_offset += aligned_size
return ptr, stride
@classmethod
def close_mnnvl_memory(cls, ptr: int):
mapping, aligned_size, mem_handles, start_address, rank_stride, address_offset = (
cls.allocated_map.pop(ptr)
)
comm = cls.get_comm(mapping)
comm_size = comm.Get_size()
for i in range(comm_size):
rank_ptr = start_address + i * rank_stride + address_offset
_check_cu_result(cuda.cuMemUnmap(rank_ptr, aligned_size))
_check_cu_result(cuda.cuMemRelease(mem_handles[i]))
cls.address_refcnt[start_address] -= 1
if cls.address_refcnt[start_address] == 0:
cls.address_refcnt.pop(start_address)
device_ptr = cuda.CUdeviceptr(start_address)
_check_cu_result(cuda.cuMemAddressFree(device_ptr, comm_size * rank_stride))
if start_address == cls.current_start_address:
cls.current_start_address = 0
cls.current_rank_stride = 0
cls.current_mem_offset = 0
@staticmethod
def support_nvlink(need_all_up: bool = True):
dev_id = torch.cuda.current_device()
handle = pynvml.nvmlDeviceGetHandleByIndex(dev_id)
link_count = pynvml.NVML_NVLINK_MAX_LINKS
active_links = 0
available_links = 0
for link_idx in range(link_count):
try:
if pynvml.nvmlDeviceGetNvLinkCapability(
handle, link_idx, pynvml.NVML_NVLINK_CAP_P2P_SUPPORTED
):
available_links += 1
is_active = pynvml.nvmlDeviceGetNvLinkState(handle, link_idx)
if is_active:
active_links += 1
except pynvml.NVMLError_NotSupported:
continue
return (
active_links == available_links and available_links > 0
if need_all_up
else available_links > 0
)
@staticmethod
def supports_mnnvl() -> bool:
        # TODO: We currently only check that all NVLink links are up, which is not
        # equivalent to MNNVL support. A more precise capability check may be needed.
support_nvlink_and_all_up = MnnvlMemory.support_nvlink(True)
return support_nvlink_and_all_up
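    # Illustrative gating pattern (a sketch; how callers might use the check above,
    # with `mapping` and `size` supplied by the caller and the fallback path left to
    # them):
    #   if MnnvlMemory.supports_mnnvl():
    #       workspace = MnnvlMemory(mapping, size)
    #   else:
    #       ...  # fall back to a non-MNNVL communication path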


class HelixCpMnnvlMemory(MnnvlMemory):
"""MNNVL memory management for Helix context parallel (CP) operations.
Per-class state (current_mem_offset, comm, allocated_map, etc.) is automatically
initialized via __init_subclass__ in the parent class, ensuring this class has
its own isolated state separate from MnnvlMemory.
"""
@classmethod
def get_comm(cls, mapping: Mapping):
"""Get CP-based communicator (ranks grouped by PP+TP+MOE_TP, ordered by CP rank)."""
if cls.comm is not None:
return cls.comm
comm = mpi_comm().Split(
mapping.pp_rank * mapping.tp_size * mapping.moe_tp_size
+ mapping.tp_rank * mapping.moe_tp_size
+ mapping.moe_tp_rank,
mapping.cp_rank,
)
cls.comm = comm
return comm
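    # State-isolation sketch (illustrative): __init_subclass__ re-creates the
    # per-class attributes, so allocations made through HelixCpMnnvlMemory never touch
    # MnnvlMemory's bookkeeping.
    #   tp_mem = MnnvlMemory(mapping, size)          # tracked in MnnvlMemory.allocated_map
    #   cp_mem = HelixCpMnnvlMemory(mapping, size)   # tracked in HelixCpMnnvlMemory.allocated_map
    #   assert MnnvlMemory.allocated_map is not HelixCpMnnvlMemory.allocated_map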


@dataclass
class MoEAlltoallInfo:
    local_gather_indices: Optional[torch.Tensor]
send_rank_count_cumsum: torch.Tensor
send_rank_local_indices: torch.Tensor
recv_rank_count_cumsum: torch.Tensor
recv_rank_local_indices: torch.Tensor
backward_recv_rank_local_indices: torch.Tensor
local_token_allocation_count: int


class MnnvlMoe:
    moe_workspace: Optional[MnnvlMemory] = None
    moe_prepare_workspace: Optional[MnnvlMemory] = None
    moe_workspace_tensor: Optional[torch.Tensor] = None
    moe_prepare_workspace_tensor: Optional[torch.Tensor] = None
    moe_mapping: Optional[Mapping] = None
@staticmethod
def get_moe_workspaces(mapping: Mapping):
if MnnvlMoe.moe_workspace is not None:
            assert mapping == MnnvlMoe.moe_mapping, "only one MoE mapping is supported for now"
return MnnvlMoe.moe_workspace_tensor
MnnvlMoe.moe_mapping = mapping
workspace_size_per_rank = torch.ops.trtllm.get_moe_commworkspace_size_per_rank(
mapping.moe_ep_size
)
MnnvlMoe.moe_workspace = MnnvlMemory(mapping, workspace_size_per_rank)
MnnvlMoe.moe_workspace_tensor = MnnvlMoe.moe_workspace.as_torch_strided_tensor(torch.uint64)
torch.ops.trtllm.moe_initialize_workspace(
MnnvlMoe.moe_workspace_tensor, mapping.moe_ep_rank, mapping.moe_ep_size
)
torch.cuda.synchronize()
MnnvlMoe.moe_workspace.comm.barrier()
return MnnvlMoe.moe_workspace_tensor
@staticmethod
def get_moe_prepare_workspace(mapping: Mapping):
if MnnvlMoe.moe_prepare_workspace_tensor is not None:
            assert mapping == MnnvlMoe.moe_mapping, "only one MoE mapping is supported for now"
return MnnvlMoe.moe_prepare_workspace_tensor
workspace_size_per_rank = torch.ops.trtllm.get_moe_prepare_workspace_size_per_rank(
mapping.moe_ep_size
)
MnnvlMoe.moe_prepare_workspace = MnnvlMemory(mapping, workspace_size_per_rank)
MnnvlMoe.moe_prepare_workspace_tensor = (
MnnvlMoe.moe_prepare_workspace.as_torch_strided_tensor(torch.uint64)
)
return MnnvlMoe.moe_prepare_workspace_tensor
@staticmethod
def compute_target_rank_id(
token_selected_experts: torch.Tensor, expert_count: int, ep_size: int
):
assert expert_count % ep_size == 0, "expert_count should be divisible by ep_size"
expert_per_rank = expert_count // ep_size
token_target_rank_ids = token_selected_experts // expert_per_rank
return token_target_rank_ids
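    # Worked example (illustrative numbers): with expert_count=8 and ep_size=4, each
    # rank owns expert_per_rank=2 experts, so selected expert ids [0, 3, 5, 7] map to
    # target ranks [0, 1, 2, 3] via integer division by 2.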
@staticmethod
def mnnvl_moe_alltoallv_prepare_without_allgather(
expert_ids: torch.Tensor,
expert_statics: Optional[torch.Tensor],
workspace: torch.Tensor,
max_token_count_per_rank: int,
ep_rank: int,
ep_size: int,
expert_count: int,
slot_count: int,
top_k: int,
):
(
local_send_rank_count_cumsum,
local_send_rank_indices,
local_recv_rank_count_cumsum,
local_recv_rank_indices,
backward_local_recv_rank_indices,
gathered_expert_statics,
) = torch.ops.trtllm.mnnvl_moe_alltoallv_prepare_without_allgather(
expert_ids,
expert_statics,
workspace,
max_token_count_per_rank,
ep_rank,
ep_size,
expert_count,
slot_count,
top_k,
)
local_token_allocation_count = max_token_count_per_rank * ep_size
        # local_gather_indices is not needed on this path, so it stays None.
local_gather_indices = None
alltoall_info = MoEAlltoallInfo(
local_gather_indices,
local_send_rank_count_cumsum,
local_send_rank_indices,
local_recv_rank_count_cumsum,
local_recv_rank_indices,
backward_local_recv_rank_indices,
local_token_allocation_count,
)
return alltoall_info, gathered_expert_statics
@staticmethod
def mnnvl_moe_expert_static_allgather(
expert_ids: torch.Tensor,
workspace: torch.Tensor,
ep_rank: int,
ep_size: int,
expert_count: int,
):
gathered_expert_ids = torch.ops.trtllm.mnnvl_moe_expert_static_allgather(
expert_ids, workspace, ep_rank, ep_size, expert_count
)
return gathered_expert_ids
@staticmethod
def mnnvl_moe_alltoallv_prepare(
gathered_target_rank_ids: torch.Tensor,
real_rank_token_count_cumsum: Optional[torch.Tensor],
gathered_expert_ids: torch.Tensor,
gathered_scales: Optional[torch.Tensor],
max_token_count_per_rank: int,
expert_count: int,
top_k: int,
ep_rank: int,
ep_size: int,
):
(
local_gather_indices,
send_rank_count_cumsum,
send_rank_local_indices,
recv_rank_count_cumsum,
recv_rank_local_indices,
backward_recv_rank_local_indices,
) = torch.ops.trtllm.moe_comm_prepare_indices(
gathered_target_rank_ids,
real_rank_token_count_cumsum,
max_token_count_per_rank,
expert_count,
top_k,
ep_rank,
ep_size,
)
local_token_allocation_count = max_token_count_per_rank * ep_size
local_expert_ids = torch.empty(
local_token_allocation_count, top_k, dtype=torch.int32, device=torch.device("cuda")
)
if gathered_scales is None:
local_scales = None
else:
local_scales = torch.empty(
local_token_allocation_count,
top_k,
dtype=torch.float32,
device=torch.device("cuda"),
)
torch.ops.trtllm.moe_local_gather(
recv_rank_count_cumsum,
local_gather_indices,
gathered_expert_ids,
gathered_scales,
local_expert_ids,
local_scales,
max_token_count_per_rank,
expert_count,
top_k,
ep_rank,
ep_size,
)
alltoall_info = MoEAlltoallInfo(
local_gather_indices,
send_rank_count_cumsum,
send_rank_local_indices,
recv_rank_count_cumsum,
recv_rank_local_indices,
backward_recv_rank_local_indices,
local_token_allocation_count,
)
return alltoall_info, local_expert_ids, local_scales
@staticmethod
def mnnvl_moe_alltoallv(
x: Union[torch.Tensor, List[Optional[torch.Tensor]]],
alltoall_info: MoEAlltoallInfo,
workspace: torch.Tensor,
ep_rank: int,
ep_size: int,
) -> Union[torch.Tensor, List[Optional[torch.Tensor]]]:
# Convert single tensor to list for unified handling
is_single_tensor = not isinstance(x, list)
if is_single_tensor:
            assert x.dim() == 2, "only 2D tensors are supported; please reshape."
x = [x]
assert len(x) > 0, "Empty tensor list not supported"
# Filter out None values
valid_list = [tensor is not None for tensor in x]
valid_tensors = [tensor for tensor in x if tensor is not None]
if len(valid_tensors) == 0:
# All tensors are None, return list of None
result = [None] * len(x)
else:
first_dim = None
for tensor in valid_tensors:
# Validate dimensions of valid tensors
                assert tensor.dim() == 2, "only 2D tensors are supported; please reshape."
if first_dim is None:
first_dim = tensor.shape[0]
else:
assert tensor.shape[0] == first_dim, (
f"All tensors must have the same first dimension, got {tensor.shape[0]} vs {first_dim}"
)
# Process only valid tensors
output_tensors = torch.ops.trtllm.moe_comm(
valid_tensors,
alltoall_info.send_rank_count_cumsum,
alltoall_info.send_rank_local_indices,
alltoall_info.recv_rank_count_cumsum,
alltoall_info.recv_rank_local_indices,
workspace,
alltoall_info.local_token_allocation_count,
ep_rank,
ep_size,
)
# Restore None positions in output
idx = 0
result = []
for is_valid in valid_list:
if is_valid:
result.append(output_tensors[idx])
idx += 1
else:
result.append(None)
# If input was a single tensor, return a single tensor
if is_single_tensor:
result = result[0]
return result
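    # Illustrative dispatch call (a sketch; `hidden_states`, `scales`, `alltoall_info`,
    # and `mapping` are assumed to come from the caller, and any list entry may be
    # None to skip that payload):
    #   workspace = MnnvlMoe.get_moe_workspaces(mapping)
    #   dispatched, dispatched_scales = MnnvlMoe.mnnvl_moe_alltoallv(
    #       [hidden_states, scales], alltoall_info, workspace,
    #       mapping.moe_ep_rank, mapping.moe_ep_size,
    #   )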
@staticmethod
def mnnvl_moe_alltoallv_combine(
x: torch.Tensor,
alltoall_info: MoEAlltoallInfo,
workspace: torch.Tensor,
ep_rank: int,
ep_size: int,
top_k: int,
token_count: int,
use_low_precision_combine: bool = False,
do_reduce: bool = True,
):
        assert x.dim() == 2, "only 2D tensors are supported; please reshape."
output_tensors = torch.ops.trtllm.moe_comm(
[x],
alltoall_info.recv_rank_count_cumsum,
alltoall_info.recv_rank_local_indices,
alltoall_info.send_rank_count_cumsum,
alltoall_info.backward_recv_rank_local_indices,
workspace,
token_count * top_k,
ep_rank,
ep_size,
[True],
use_low_precision_combine,
)
output_tensor = output_tensors[0].reshape(token_count, top_k, x.shape[1])
if do_reduce:
return torch.sum(output_tensor, dim=1, keepdim=False)
else:
return output_tensor
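    # Illustrative combine call (a sketch; token_count and top_k are assumed to match
    # the values used when preparing alltoall_info):
    #   combined = MnnvlMoe.mnnvl_moe_alltoallv_combine(
    #       expert_output, alltoall_info, workspace,
    #       ep_rank=mapping.moe_ep_rank, ep_size=mapping.moe_ep_size,
    #       top_k=top_k, token_count=token_count,
    #   )
    # With do_reduce=True (the default), the top_k expert contributions per token are
    # summed, yielding a [token_count, hidden_size] tensor.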