TensorRT-LLMs/tensorrt_llm/_mnnvl_utils.py

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ctypes
import os
import platform
import sys
from dataclasses import dataclass
from typing import List, Optional, Union

import pynvml
import torch

try:
    from cuda.bindings import driver as cuda
except ImportError:
    from cuda import cuda

from ._dlpack_utils import pack_strided_memory
from ._utils import mpi_comm
from .logger import logger
from .mapping import Mapping


def _check_cu_result(cu_func_ret):
    if isinstance(cu_func_ret, tuple):
        cu_result, *others = cu_func_ret
        if cu_result != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(cu_result)
        if len(others) == 1:
            return others[0]
        elif len(others) > 1:
            return tuple(others)
        else:
            return None
    else:
        if cu_func_ret != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(cu_func_ret)
        return None
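

# Example (illustrative): the cuda-python driver API returns tuples whose first element is a
# CUresult, e.g. ``err, dev = cuda.cuCtxGetDevice()``. The helper above raises on any
# non-success status and unwraps the remaining return value(s):
#
#     dev = _check_cu_result(cuda.cuCtxGetDevice())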


class MnnvlMemory:
    """MNNVL memory management for tensor parallel (TP) operations."""

    # Shared across all subclasses (global/device state).
    initialized: bool = False
    allocation_granularity: int = 0
    fabric_page_size: int = 1 << 29  # 512 MB.
    dev_id: int = None

    # Per-class state attributes. These will be auto-initialized for each subclass
    # to avoid polluting the parent class's state. Use callable (e.g., dict) for mutable defaults.
    _per_class_attrs = {
        "current_mem_offset": 0,
        "current_rank_stride": 0,  # stride for ranks and also address space size.
        "current_start_address": 0,
        "comm": None,  # MPI communicator.
        "allocated_map": dict,  # callable for fresh dict.
        "address_refcnt": dict,  # callable for fresh dict.
    }

    # Initialize per-class state for the base class.
    current_mem_offset: int = 0
    current_rank_stride: int = 0
    current_start_address: int = 0
    comm = None
    allocated_map = {}
    address_refcnt = {}

    def __init_subclass__(cls, **kwargs):
        """Auto-initialize per-class attributes for each subclass to avoid sharing state with parent."""
        super().__init_subclass__(**kwargs)
        for attr, default in cls._per_class_attrs.items():
            if callable(default):
                setattr(cls, attr, default())  # e.g., dict() creates a fresh dict.
            else:
                setattr(cls, attr, default)

    def __init__(self, mapping: Mapping, size: int):
        self.mapping = mapping
        self.segment_size = size
        self.ptr, self.rank_stride = type(self).open_mnnvl_memory(self.mapping, size)

    def __del__(self):
        if not sys.is_finalizing():
            if hasattr(self, "ptr"):
                type(self).close_mnnvl_memory(self.ptr)

    def as_torch_strided_tensor(self, dtype):
        num_segments = type(self).comm.Get_size()
        return pack_strided_memory(
            self.ptr, self.segment_size, self.rank_stride, num_segments, dtype, MnnvlMemory.dev_id
        )
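
    # Note (illustrative): the strided view exposes one ``segment_size``-byte window per rank of
    # the communicator, with consecutive segments spaced ``rank_stride`` bytes apart in the
    # shared VA range, so segment ``i`` addresses this allocation inside rank ``i``'s slice.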

    @staticmethod
    def initialize():
        if not MnnvlMemory.initialized:
            # use a dummy torch CUDA tensor to trigger CUDA context initialization
            _ = torch.empty(1, device="cuda")
            # ensure nvml is initialized.
            try:
                pynvml.nvmlDeviceGetCount()
            except pynvml.NVMLError_Uninitialized:
                pynvml.nvmlInit()
            MnnvlMemory.initialized = True

    @classmethod
    def get_comm(cls, mapping: Mapping):
        """Get TP-based communicator (ranks grouped by PP+CP+MOE_TP, ordered by TP rank)."""
        if cls.comm is not None:
            return cls.comm
        comm = mpi_comm().Split(
            (mapping.pp_rank * mapping.cp_size + mapping.cp_rank) * mapping.moe_tp_size
            + mapping.moe_tp_rank,
            mapping.tp_rank,
        )
        cls.comm = comm
        return comm

    @staticmethod
    def get_allocation_prop(dev_id: int):
        location = cuda.CUmemLocation()
        location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
        location.id = dev_id
        allocation_prop = cuda.CUmemAllocationProp()
        allocation_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
        # TODO: We differentiate FABRIC for GB200 (aarch64) and POSIX_FILE_DESCRIPTOR for BB200 (x86_64).
        # May need to find a better way to handle this.
        arch = platform.machine().lower()
        is_on_aarch64 = "aarch64" in arch
        if is_on_aarch64:
            allocation_prop.requestedHandleTypes = (
                cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
            )
        else:
            allocation_prop.requestedHandleTypes = (
                cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
            )
        allocation_prop.location = location
        return allocation_prop

    @staticmethod
    def get_allocation_granularity(dev_id: int):
        if MnnvlMemory.allocation_granularity != 0:
            return MnnvlMemory.allocation_granularity
        allocation_prop = MnnvlMemory.get_allocation_prop(dev_id)
        option = cuda.CUmemAllocationGranularity_flags(
            cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_RECOMMENDED
        )
        granularity = _check_cu_result(
            cuda.cuMemGetAllocationGranularity(prop=allocation_prop, option=option)
        )
        MnnvlMemory.allocation_granularity = granularity
        return MnnvlMemory.allocation_granularity

    @classmethod
    def new_mnnvl_memory_address(cls, mapping: Mapping, size: int):
        page_count = (size + MnnvlMemory.fabric_page_size - 1) // MnnvlMemory.fabric_page_size
        current_rank_stride = page_count * MnnvlMemory.fabric_page_size
        logger.info(f"[{cls.__name__}] creating address with stride={current_rank_stride}")
        comm = cls.get_comm(mapping)
        comm_size = comm.Get_size()
        address_size = current_rank_stride * comm_size
        ptr = _check_cu_result(
            cuda.cuMemAddressReserve(address_size, MnnvlMemory.fabric_page_size, 0, 0)
        )
        cls.current_start_address = int(ptr)
        cls.current_rank_stride = current_rank_stride
        cls.current_mem_offset = 0

    @classmethod
    def open_mnnvl_memory(cls, mapping: Mapping, size: int):
        # Ensure MnnvlMemory is initialized (for dev_id and allocation_granularity)
        MnnvlMemory.initialize()
        dev = _check_cu_result(cuda.cuCtxGetDevice())
        dev_id = int(dev)
        if MnnvlMemory.dev_id is None:
            MnnvlMemory.dev_id = dev_id
        assert dev_id == MnnvlMemory.dev_id, (
            f"Different dev_id found dev_id={dev_id} but MnnvlMemory.dev_id={MnnvlMemory.dev_id}"
        )
        comm = cls.get_comm(mapping)
        comm_rank = comm.Get_rank()
        comm_size = comm.Get_size()
        all_rank_allocate_sizes = comm.allgather(size)
        assert len(all_rank_allocate_sizes) == comm_size
        assert all(x == size for x in all_rank_allocate_sizes), (
            "Not all ranks are allocating the same size."
        )
        granularity = MnnvlMemory.get_allocation_granularity(dev_id)
        aligned_size = (size + granularity - 1) // granularity * granularity

        if cls.current_mem_offset + aligned_size > cls.current_rank_stride:
            cls.new_mnnvl_memory_address(mapping, aligned_size)
        assert cls.current_mem_offset + aligned_size <= cls.current_rank_stride

        allocation_prop = MnnvlMemory.get_allocation_prop(dev_id)
        allocated_mem_handle = _check_cu_result(
            cuda.cuMemCreate(aligned_size, allocation_prop, flags=0)
        )
        exported_fabric_handle = _check_cu_result(
            cuda.cuMemExportToShareableHandle(
                allocated_mem_handle, allocation_prop.requestedHandleTypes, 0
            )
        )
        if (
            allocation_prop.requestedHandleTypes
            == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
        ):
            all_handles_data = comm.allgather(exported_fabric_handle.data)
        else:
            all_handles_data = comm.allgather(exported_fabric_handle)
            all_pids = comm.allgather(os.getpid())
            libc = ctypes.CDLL(None, use_errno=True)
            syscall = libc.syscall
            SYS_pidfd_open = 434
            SYS_pidfd_getfd = 438

            pidfds = []
            for i, pid in enumerate(all_pids):
                pidfd = syscall(SYS_pidfd_open, pid, 0)
                if pidfd < 0:
                    err = ctypes.get_errno()
                    raise RuntimeError(
                        f"pidfd_open({pid}) failed with errno {err}: {os.strerror(err)}"
                    )
                pidfds.append(pidfd)

            remote_fds = []
            for i, (pidfd, fd) in enumerate(zip(pidfds, all_handles_data)):
                remote_fd = syscall(SYS_pidfd_getfd, pidfd, fd, 0)
                if remote_fd < 0:
                    err = ctypes.get_errno()
                    error_msg = f"pidfd_getfd(pidfd={pidfd}, fd={fd}) failed with errno {err}: {os.strerror(err)}."
                    if err == 1:  # EPERM
                        error_msg += (
                            " Permission denied. If running in a container, try adding --cap-add=SYS_PTRACE "
                            "to your docker run command."
                        )
                    else:
                        error_msg += " This may be due to kernel version (requires Linux 5.6+)."
                    raise RuntimeError(error_msg)
                remote_fds.append(remote_fd)

            all_handles_data = remote_fds
# all_handles_data like b'\x00\x00\x00 \x00\x00\x00\x00\x8f\xec\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x00\x00\x00\x00\x00\x1d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' # noqa: E501
# can use buf = memoryview(data) to import if using plain buffer for data.
        madesc = cuda.CUmemAccessDesc()
        madesc.location = allocation_prop.location
        madesc.flags = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE

        mem_handles = [None] * comm_size
        for i, remote_handle_data in enumerate(all_handles_data):
            rank_ptr = (
                cls.current_start_address + cls.current_rank_stride * i + cls.current_mem_offset
            )
            if i == comm_rank:
                # Local memory mapping
                mem_handles[i] = allocated_mem_handle
                _check_cu_result(cuda.cuMemMap(rank_ptr, aligned_size, 0, allocated_mem_handle, 0))
            else:
                # Fabric memory mapping
                imported_mem_handle = _check_cu_result(
                    cuda.cuMemImportFromShareableHandle(
                        remote_handle_data, allocation_prop.requestedHandleTypes
                    )
                )
                mem_handles[i] = imported_mem_handle
                _check_cu_result(cuda.cuMemMap(rank_ptr, aligned_size, 0, imported_mem_handle, 0))
            _check_cu_result(cuda.cuMemSetAccess(rank_ptr, aligned_size, [madesc], 1))

        ptr = cls.current_start_address + cls.current_mem_offset
        stride = cls.current_rank_stride
        cls.allocated_map[ptr] = (
            mapping,
            aligned_size,
            mem_handles,
            cls.current_start_address,
            cls.current_rank_stride,
            cls.current_mem_offset,
        )
        cls.address_refcnt[cls.current_start_address] = (
            cls.address_refcnt.get(cls.current_start_address, 0) + 1
        )
        cls.current_mem_offset += aligned_size
        return ptr, stride

    @classmethod
    def close_mnnvl_memory(cls, ptr: int):
        mapping, aligned_size, mem_handles, start_address, rank_stride, address_offset = (
            cls.allocated_map.pop(ptr)
        )
        comm = cls.get_comm(mapping)
        comm_size = comm.Get_size()
        for i in range(comm_size):
            rank_ptr = start_address + i * rank_stride + address_offset
            _check_cu_result(cuda.cuMemUnmap(rank_ptr, aligned_size))
            _check_cu_result(cuda.cuMemRelease(mem_handles[i]))
        cls.address_refcnt[start_address] -= 1

        if cls.address_refcnt[start_address] == 0:
            cls.address_refcnt.pop(start_address)
            device_ptr = cuda.CUdeviceptr(start_address)
            _check_cu_result(cuda.cuMemAddressFree(device_ptr, comm_size * rank_stride))
            if start_address == cls.current_start_address:
                cls.current_start_address = 0
                cls.current_rank_stride = 0
                cls.current_mem_offset = 0

    @staticmethod
    def support_nvlink(need_all_up: bool = True):
        dev_id = torch.cuda.current_device()
        handle = pynvml.nvmlDeviceGetHandleByIndex(dev_id)
        link_count = pynvml.NVML_NVLINK_MAX_LINKS
        active_links = 0
        available_links = 0
        for link_idx in range(link_count):
            try:
                if pynvml.nvmlDeviceGetNvLinkCapability(
                    handle, link_idx, pynvml.NVML_NVLINK_CAP_P2P_SUPPORTED
                ):
                    available_links += 1
                    is_active = pynvml.nvmlDeviceGetNvLinkState(handle, link_idx)
                    if is_active:
                        active_links += 1
            except pynvml.NVMLError_NotSupported:
                continue
        return (
            active_links == available_links and available_links > 0
            if need_all_up
            else available_links > 0
        )

    @staticmethod
    def supports_mnnvl() -> bool:
        # TODO:
        # We check if it has all NVLink up now.
        # But it is not equivalent to MNNVL support.
        # May need better support check.
        support_nvlink_and_all_up = MnnvlMemory.support_nvlink(True)
        return support_nvlink_and_all_up
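

# Illustrative sketch (hypothetical helper, not part of the module API): the intended calling
# sequence is initialize, capability check, allocate, then view the allocation as a strided
# tensor spanning all ranks of the communicator. The returned MnnvlMemory object must outlive
# the tensor view, since __del__ unmaps the memory.
def _example_mnnvl_memory_usage(mapping: Mapping, size_in_bytes: int = 1 << 21):
    """Allocate an MNNVL buffer and expose it as a per-rank strided tensor (illustrative)."""
    MnnvlMemory.initialize()
    assert MnnvlMemory.supports_mnnvl(), "NVLink links must be up to use MNNVL memory"
    # Every rank of the TP communicator must request the same size (checked via allgather).
    mem = MnnvlMemory(mapping, size_in_bytes)
    # dtype mirrors the MoE workspace usage below; one segment per communicator rank.
    view = mem.as_torch_strided_tensor(torch.uint64)
    return mem, view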


class HelixCpMnnvlMemory(MnnvlMemory):
    """MNNVL memory management for Helix context parallel (CP) operations.

    Per-class state (current_mem_offset, comm, allocated_map, etc.) is automatically
    initialized via __init_subclass__ in the parent class, ensuring this class has
    its own isolated state separate from MnnvlMemory.
    """

    @classmethod
    def get_comm(cls, mapping: Mapping):
        """Get CP-based communicator (ranks grouped by PP+TP+MOE_TP, ordered by CP rank)."""
        if cls.comm is not None:
            return cls.comm
        comm = mpi_comm().Split(
            mapping.pp_rank * mapping.tp_size + mapping.tp_rank,
            mapping.cp_rank,
        )
        cls.comm = comm
        return comm
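

# Illustrative check (hypothetical helper, not part of the module API) of the state isolation
# described in the class docstring above: __init_subclass__ gives the subclass fresh per-class
# containers, so Helix CP allocations and the CP communicator never alias the TP ones.
def _example_per_class_state_isolation() -> None:
    """Show that HelixCpMnnvlMemory keeps state separate from MnnvlMemory (illustrative)."""
    assert MnnvlMemory.allocated_map is not HelixCpMnnvlMemory.allocated_map
    assert MnnvlMemory.address_refcnt is not HelixCpMnnvlMemory.address_refcnt
    # Both classes start with comm=None; each caches its own communicator on first get_comm().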


def init_helix_cp_comm(mapping: Mapping) -> None:
    """Pre-initialize the Helix CP communicator.

    This function MUST be called during model initialization when all ranks
    are synchronized (before any PP pipeline divergence). The MPI Split operation
    is collective and requires all ranks in the communicator to participate.

    In PP (pipeline parallel) mode, different PP stages execute different parts
    of the model at different times. If the communicator is initialized lazily
    during the first forward pass, ranks in different PP stages may not reach
    the Split operation at the same time, causing a deadlock.

    Args:
        mapping: The mapping object containing parallelism configuration.
    """
    if mapping.has_cp_helix() and not mapping.cp_config.get("use_nccl_for_alltoall", True):
        HelixCpMnnvlMemory.get_comm(mapping)
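

# Typical call site (illustrative): right after the Mapping is constructed during model
# initialization, and before the first forward pass, every rank calls
#
#     init_helix_cp_comm(mapping)
#
# so the collective MPI Split executes while all PP stages are still synchronized.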


@dataclass
class MoEAlltoallInfo:
    local_gather_indices: torch.Tensor
    send_rank_count_cumsum: torch.Tensor
    send_rank_local_indices: torch.Tensor
    recv_rank_count_cumsum: torch.Tensor
    recv_rank_local_indices: torch.Tensor
    backward_recv_rank_local_indices: torch.Tensor
    local_token_allocation_count: int


class MnnvlMoe:
    moe_workspace: MnnvlMemory = None
    moe_prepare_workspace: MnnvlMemory = None
    moe_workspace_tensor: torch.Tensor = None
    moe_prepare_workspace_tensor: torch.Tensor = None
    moe_mapping: Mapping = None

    @staticmethod
    def get_moe_workspaces(mapping: Mapping):
        if MnnvlMoe.moe_workspace is not None:
            assert mapping == MnnvlMoe.moe_mapping, "only one moe mapping supported now"
            return MnnvlMoe.moe_workspace_tensor

        MnnvlMoe.moe_mapping = mapping
        workspace_size_per_rank = torch.ops.trtllm.get_moe_commworkspace_size_per_rank(
            mapping.moe_ep_size
        )
        MnnvlMoe.moe_workspace = MnnvlMemory(mapping, workspace_size_per_rank)
        MnnvlMoe.moe_workspace_tensor = MnnvlMoe.moe_workspace.as_torch_strided_tensor(torch.uint64)
        torch.ops.trtllm.moe_initialize_workspace(
            MnnvlMoe.moe_workspace_tensor, mapping.moe_ep_rank, mapping.moe_ep_size
        )
        torch.cuda.synchronize()
        MnnvlMoe.moe_workspace.comm.barrier()
        return MnnvlMoe.moe_workspace_tensor

    @staticmethod
    def get_moe_prepare_workspace(mapping: Mapping):
        if MnnvlMoe.moe_prepare_workspace_tensor is not None:
            assert mapping == MnnvlMoe.moe_mapping, "only one moe mapping supported now"
            return MnnvlMoe.moe_prepare_workspace_tensor

        workspace_size_per_rank = torch.ops.trtllm.get_moe_prepare_workspace_size_per_rank(
            mapping.moe_ep_size
        )
        MnnvlMoe.moe_prepare_workspace = MnnvlMemory(mapping, workspace_size_per_rank)
        MnnvlMoe.moe_prepare_workspace_tensor = (
            MnnvlMoe.moe_prepare_workspace.as_torch_strided_tensor(torch.uint64)
        )
        return MnnvlMoe.moe_prepare_workspace_tensor

    @staticmethod
    def compute_target_rank_id(
        token_selected_experts: torch.Tensor, expert_count: int, ep_size: int
    ):
        assert expert_count % ep_size == 0, "expert_count should be divisible by ep_size"
        expert_per_rank = expert_count // ep_size
        token_target_rank_ids = token_selected_experts // expert_per_rank
        return token_target_rank_ids
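
    # Worked example (illustrative): with expert_count=8 and ep_size=4, expert_per_rank=2,
    # so expert ids {0, 1} route to rank 0, {2, 3} to rank 1, {4, 5} to rank 2, and
    # {6, 7} to rank 3 via the integer division above.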

    @staticmethod
    def mnnvl_moe_alltoallv_prepare_without_allgather(
        expert_ids: torch.Tensor,
        expert_statics: Optional[torch.Tensor],
        workspace: torch.Tensor,
        max_token_count_per_rank: int,
        ep_rank: int,
        ep_size: int,
        expert_count: int,
        slot_count: int,
        top_k: int,
    ):
        (
            local_send_rank_count_cumsum,
            local_send_rank_indices,
            local_recv_rank_count_cumsum,
            local_recv_rank_indices,
            backward_local_recv_rank_indices,
            gathered_expert_statics,
        ) = torch.ops.trtllm.mnnvl_moe_alltoallv_prepare_without_allgather(
            expert_ids,
            expert_statics,
            workspace,
            max_token_count_per_rank,
            ep_rank,
            ep_size,
            expert_count,
            slot_count,
            top_k,
        )

        local_token_allocation_count = max_token_count_per_rank * ep_size
        # Looks like we don't need this.
        local_gather_indices = None

        alltoall_info = MoEAlltoallInfo(
            local_gather_indices,
            local_send_rank_count_cumsum,
            local_send_rank_indices,
            local_recv_rank_count_cumsum,
            local_recv_rank_indices,
            backward_local_recv_rank_indices,
            local_token_allocation_count,
        )
        return alltoall_info, gathered_expert_statics

    @staticmethod
    def mnnvl_moe_expert_static_allgather(
        expert_ids: torch.Tensor,
        workspace: torch.Tensor,
        ep_rank: int,
        ep_size: int,
        expert_count: int,
    ):
        gathered_expert_ids = torch.ops.trtllm.mnnvl_moe_expert_static_allgather(
            expert_ids, workspace, ep_rank, ep_size, expert_count
        )
        return gathered_expert_ids

    @staticmethod
    def mnnvl_moe_alltoallv_prepare(
        gathered_target_rank_ids: torch.Tensor,
        real_rank_token_count_cumsum: Optional[torch.Tensor],
        gathered_expert_ids: torch.Tensor,
        gathered_scales: Optional[torch.Tensor],
        max_token_count_per_rank: int,
        expert_count: int,
        top_k: int,
        ep_rank: int,
        ep_size: int,
    ):
        (
            local_gather_indices,
            send_rank_count_cumsum,
            send_rank_local_indices,
            recv_rank_count_cumsum,
            recv_rank_local_indices,
            backward_recv_rank_local_indices,
        ) = torch.ops.trtllm.moe_comm_prepare_indices(
            gathered_target_rank_ids,
            real_rank_token_count_cumsum,
            max_token_count_per_rank,
            expert_count,
            top_k,
            ep_rank,
            ep_size,
        )

        local_token_allocation_count = max_token_count_per_rank * ep_size
        local_expert_ids = torch.empty(
            local_token_allocation_count, top_k, dtype=torch.int32, device=torch.device("cuda")
        )
        if gathered_scales is None:
            local_scales = None
        else:
            local_scales = torch.empty(
                local_token_allocation_count,
                top_k,
                dtype=torch.float32,
                device=torch.device("cuda"),
            )

        torch.ops.trtllm.moe_local_gather(
            recv_rank_count_cumsum,
            local_gather_indices,
            gathered_expert_ids,
            gathered_scales,
            local_expert_ids,
            local_scales,
            max_token_count_per_rank,
            expert_count,
            top_k,
            ep_rank,
            ep_size,
        )

        alltoall_info = MoEAlltoallInfo(
            local_gather_indices,
            send_rank_count_cumsum,
            send_rank_local_indices,
            recv_rank_count_cumsum,
            recv_rank_local_indices,
            backward_recv_rank_local_indices,
            local_token_allocation_count,
        )
        return alltoall_info, local_expert_ids, local_scales

    @staticmethod
    def mnnvl_moe_alltoallv(
        x: Union[torch.Tensor, List[Optional[torch.Tensor]]],
        alltoall_info: MoEAlltoallInfo,
        workspace: torch.Tensor,
        ep_rank: int,
        ep_size: int,
    ) -> Union[torch.Tensor, List[Optional[torch.Tensor]]]:
        # Convert single tensor to list for unified handling
        is_single_tensor = not isinstance(x, list)
        if is_single_tensor:
            assert x.dim() == 2, "only 2D tensor supported, please reshape."
            x = [x]

        assert len(x) > 0, "Empty tensor list not supported"

        # Filter out None values
        valid_list = [tensor is not None for tensor in x]
        valid_tensors = [tensor for tensor in x if tensor is not None]

        if len(valid_tensors) == 0:
            # All tensors are None, return list of None
            result = [None] * len(x)
        else:
            first_dim = None
            for tensor in valid_tensors:
                # Validate dimensions of valid tensors
                assert tensor.dim() == 2, "only 2D tensor supported, please reshape."
                if first_dim is None:
                    first_dim = tensor.shape[0]
                else:
                    assert tensor.shape[0] == first_dim, (
                        f"All tensors must have the same first dimension, got {tensor.shape[0]} vs {first_dim}"
                    )

            # Process only valid tensors
            output_tensors = torch.ops.trtllm.moe_comm(
                valid_tensors,
                alltoall_info.send_rank_count_cumsum,
                alltoall_info.send_rank_local_indices,
                alltoall_info.recv_rank_count_cumsum,
                alltoall_info.recv_rank_local_indices,
                workspace,
                alltoall_info.local_token_allocation_count,
                ep_rank,
                ep_size,
            )

            # Restore None positions in output
            idx = 0
            result = []
            for is_valid in valid_list:
                if is_valid:
                    result.append(output_tensors[idx])
                    idx += 1
                else:
                    result.append(None)

        # If input was a single tensor, return a single tensor
        if is_single_tensor:
            result = result[0]
        return result

    @staticmethod
    def mnnvl_moe_alltoallv_combine(
        x: torch.Tensor,
        alltoall_info: MoEAlltoallInfo,
        workspace: torch.Tensor,
        ep_rank: int,
        ep_size: int,
        top_k: int,
        token_count: int,
        use_low_precision_combine: bool = False,
        do_reduce: bool = True,
    ):
        assert x.dim() == 2, "only 2D tensor supported, please reshape."
        output_tensors = torch.ops.trtllm.moe_comm(
            [x],
            alltoall_info.recv_rank_count_cumsum,
            alltoall_info.recv_rank_local_indices,
            alltoall_info.send_rank_count_cumsum,
            alltoall_info.backward_recv_rank_local_indices,
            workspace,
            token_count * top_k,
            ep_rank,
            ep_size,
            [True],
            use_low_precision_combine,
        )
        output_tensor = output_tensors[0].reshape(token_count, top_k, x.shape[1])
        if do_reduce:
            return torch.sum(output_tensor, dim=1, keepdim=False)
        else:
            return output_tensor
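

# The helpers above are typically used together as dispatch, expert computation, then combine.
# The sketch below is illustrative only (hypothetical helper, not part of the module API); the
# per-expert computation is elided and replaced by an identity placeholder, so it only shows
# how the prepare/alltoallv/combine calls fit together.
def _example_moe_alltoall_flow(
    mapping: Mapping,
    hidden_states: torch.Tensor,
    expert_ids: torch.Tensor,
    max_token_count_per_rank: int,
    expert_count: int,
    slot_count: int,
    top_k: int,
) -> torch.Tensor:
    """Dispatch tokens to expert-owning ranks, (pretend to) compute, and combine (illustrative)."""
    ep_rank, ep_size = mapping.moe_ep_rank, mapping.moe_ep_size
    workspace = MnnvlMoe.get_moe_workspaces(mapping)
    prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(mapping)

    # Build routing metadata without allgathering the full routing tables.
    alltoall_info, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
        expert_ids,
        None,
        prepare_workspace,
        max_token_count_per_rank,
        ep_rank,
        ep_size,
        expert_count,
        slot_count,
        top_k,
    )

    # Dispatch: send each token to the ranks owning its selected experts.
    dispatched = MnnvlMoe.mnnvl_moe_alltoallv(
        hidden_states, alltoall_info, workspace, ep_rank, ep_size
    )

    expert_output = dispatched  # placeholder for the real per-expert computation

    # Combine: send expert outputs back and reduce the top_k partial results per token.
    return MnnvlMoe.mnnvl_moe_alltoallv_combine(
        expert_output,
        alltoall_info,
        workspace,
        ep_rank,
        ep_size,
        top_k,
        token_count=hidden_states.shape[0],
        do_reduce=True,
    )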