[EPLB] Nixl communicator optimization. Zero-copy transfers (#41633)

Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: Markov Ilya <markovilya19@gmail.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Markov Ilya <markovilya19@gmail.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
This commit is contained in:
Ilya Markov
2026-06-04 05:40:34 +02:00
committed by GitHub
parent f0cd590d62
commit 4f423bd5bc
12 changed files with 287 additions and 219 deletions
+21 -9
View File
@@ -277,12 +277,15 @@ def assert_verification_synced(local_ok: bool, msg: str) -> None:
assert bool(ok_tensor.item()), msg
def create_eplb_communicator_or_raise(*, group_coordinator, backend, expert_weights):
def create_eplb_communicator_or_raise(
*, group_coordinator, backend, expert_weights, expert_buffer
):
try:
return create_eplb_communicator(
group_coordinator=group_coordinator,
backend=backend,
expert_weights=expert_weights,
expert_buffer=expert_buffer,
)
except Exception as exc:
raise RuntimeError(
@@ -355,7 +358,8 @@ def _test_async_transfer_layer_without_mtp_worker(
communicator = create_eplb_communicator_or_raise(
group_coordinator=ep_group_coordinator,
backend=eplb_communicator,
expert_weights=expert_weights[0],
expert_weights=expert_weights,
expert_buffer=expert_buffer,
)
communicator.set_stream(cuda_stream)
@@ -368,6 +372,7 @@ def _test_async_transfer_layer_without_mtp_worker(
ep_group=ep_group,
communicator=communicator,
cuda_stream=cuda_stream,
layer_idx=layer_idx,
)
cuda_stream.synchronize()
move_from_buffer(
@@ -460,10 +465,12 @@ def _test_rearrange_expert_weights_with_redundancy(
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
)
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
communicator = create_eplb_communicator_or_raise(
group_coordinator=ep_group_coordinator,
backend=eplb_communicator,
expert_weights=expert_weights[0],
expert_weights=expert_weights,
expert_buffer=expert_buffer,
)
# Execute weight rearrangement
@@ -471,9 +478,9 @@ def _test_rearrange_expert_weights_with_redundancy(
old_indices,
new_indices,
expert_weights,
expert_buffer,
ep_group,
is_profile=False,
communicator=communicator,
communicator,
)
# Verify the rearrangement result
@@ -593,10 +600,12 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
layer_copy.append(weight.clone())
original_weights.append(layer_copy)
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
communicator = create_eplb_communicator_or_raise(
group_coordinator=ep_group_coordinator,
backend="torch_nccl",
expert_weights=expert_weights[0],
expert_weights=expert_weights,
expert_buffer=expert_buffer,
)
# Execute rearrangement (should be no change)
@@ -604,9 +613,9 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
indices,
indices, # Same indices
expert_weights,
expert_buffer,
ep_group,
communicator,
is_profile=False,
)
# Verify that the weights have not changed
@@ -726,10 +735,12 @@ def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
layer_copy.append(weight.clone())
original_weights.append(layer_copy)
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
communicator = create_eplb_communicator_or_raise(
group_coordinator=ep_group_coordinator,
backend="torch_nccl",
expert_weights=expert_weights[0],
expert_weights=expert_weights,
expert_buffer=expert_buffer,
)
# Execute profile mode rearrangement
@@ -737,9 +748,10 @@ def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
old_indices,
new_indices,
expert_weights,
expert_buffer,
ep_group,
communicator,
is_profile=True, # Profile mode
is_profile=True,
)
# In profile mode, the weights should remain unchanged
+11 -1
View File
@@ -9,9 +9,11 @@ import pytest
import torch
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized,
get_eplb_group,
get_tp_group,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
@@ -213,12 +215,20 @@ def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
for lidx in range(test_config.num_layers):
shuffled_indices[lidx] = torch.randperm(test_config.num_experts)
expert_buffer = [torch.empty_like(w) for w in rank_expert_weights[0]]
communicator = create_eplb_communicator(
group_coordinator=get_eplb_group(),
backend="torch_nccl",
expert_weights=rank_expert_weights,
expert_buffer=expert_buffer,
)
rearrange_expert_weights_inplace(
indices,
shuffled_indices,
rank_expert_weights,
expert_buffer,
ep_group,
is_profile=False,
communicator,
)
num_local_experts = test_config.num_local_experts
@@ -10,11 +10,13 @@ import torch
from tests.kernels.moe.utils import make_test_quant_config
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized,
get_dp_group,
get_eplb_group,
)
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
@@ -171,12 +173,20 @@ def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
for lidx in range(test_config.num_layers):
shuffled_indices[lidx] = torch.randperm(test_config.num_experts)
expert_buffer = [torch.empty_like(w) for w in rank_expert_weights[0]]
communicator = create_eplb_communicator(
group_coordinator=get_eplb_group(),
backend="torch_nccl",
expert_weights=rank_expert_weights,
expert_buffer=expert_buffer,
)
rearrange_expert_weights_inplace(
indices,
shuffled_indices,
rank_expert_weights,
expert_buffer,
ep_group,
is_profile=False,
communicator,
)
num_global_experts = test_config.num_experts
+4 -1
View File
@@ -1287,10 +1287,12 @@ def _test_body_eplb(
expert_weights = [list(eplb_moe_layer.get_expert_weights())]
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
communicator = create_eplb_communicator(
group_coordinator=get_eplb_group(),
backend=vllm_config.parallel_config.eplb_config.communicator,
expert_weights=expert_weights[0],
expert_weights=expert_weights,
expert_buffer=expert_buffer,
)
# Rearrange expert weights across EP ranks
@@ -1298,6 +1300,7 @@ def _test_body_eplb(
old_global_expert_indices=initial_indices.unsqueeze(0),
new_global_expert_indices=shuffled_indices.unsqueeze(0),
expert_weights=expert_weights,
expert_buffer=expert_buffer,
ep_group=cpu_group,
communicator=communicator,
)
@@ -40,11 +40,7 @@ def _can_p2p(rank: int, world_size: int) -> bool:
return True
def is_weak_contiguous(inp: torch.Tensor):
return inp.is_contiguous() or (
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size()
)
from vllm.distributed.utils import is_weak_contiguous # noqa: E402
class CustomAllreduce:
@@ -24,11 +24,7 @@ except Exception:
quick_ar = False
def is_weak_contiguous(inp: torch.Tensor):
return inp.is_contiguous() or (
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size()
)
from vllm.distributed.utils import is_weak_contiguous # noqa: E402, F401
class QuickReduceRegime(Enum):
@@ -470,10 +470,14 @@ class ElasticEPScalingExecutor:
module._replace_quant_method(module.quant_method.old_quant_method)
prepare_communication_buffer_for_model(self.worker.model_runner.model)
eplb_model_state.expert_buffer = [
torch.empty_like(w) for w in model.expert_weights[0]
]
eplb_model_state.communicator = create_eplb_communicator(
group_coordinator=get_eplb_group(),
backend=parallel_config.eplb_config.communicator,
expert_weights=model.expert_weights[0],
expert_weights=model.expert_weights,
expert_buffer=eplb_model_state.expert_buffer,
)
if (
+1
View File
@@ -120,6 +120,7 @@ def transfer_run_periodically(
ep_group=eplb_group,
is_profile=is_profile,
cuda_stream=cuda_stream,
layer_idx=layer_idx,
)
# Wait until all writes to expert_buffer have finished before making the
+199 -186
View File
@@ -30,6 +30,7 @@ from vllm.distributed.parallel_state import (
is_local_first_rank,
)
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.distributed.utils import is_weak_contiguous
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -63,8 +64,22 @@ class EplbCommunicator(ABC):
pass
@abstractmethod
def execute(self, old_indices: np.ndarray | None = None) -> None:
pass
def execute(self) -> None:
"""Complete all enqueued transfers.
Some backends perform communication here; others (e.g. NIXL)
issue transfers eagerly in add_recv and only wait here.
On return, all data is available in the destination buffers.
"""
def set_transfer_context( # noqa: B027
self, old_indices: np.ndarray, layer_idx: int
) -> None:
"""Pre-set layer context before add_recv calls.
Default is a no-op; overridden by backends (e.g. NIXL) that need
layer-level context to issue transfers inside add_recv.
"""
@property
def needs_profile_buffer_reservation(self) -> bool:
@@ -125,7 +140,7 @@ class TorchDistNcclEplbCommunicator(EplbCommunicator):
)
)
def execute(self, old_indices: np.ndarray | None = None) -> None:
def execute(self) -> None:
if not self._p2p_ops:
return
try:
@@ -168,7 +183,7 @@ class TorchDistGlooStagedEplbCommunicator(EplbCommunicator):
for tensor in tensors:
self._ops.append(("recv", tensor, src_rank))
def execute(self, old_indices: np.ndarray | None = None) -> None:
def execute(self) -> None:
if not self._ops:
return
@@ -229,29 +244,47 @@ class NixlEplbCommunicator(EplbCommunicator):
def __init__(
self,
cpu_group: ProcessGroup,
expert_weights: Sequence[torch.Tensor],
cuda_stream: torch.cuda.Stream | None = None,
all_expert_weights: Sequence[Sequence[torch.Tensor]],
expert_buffer: Sequence[torch.Tensor],
) -> None:
assert expert_weights, "NixlEplbCommunicator requires non-empty expert_weights."
assert all_expert_weights, (
"NixlEplbCommunicator requires non-empty all_expert_weights."
)
assert expert_buffer, "NixlEplbCommunicator requires non-empty expert_buffer."
nixl_wrapper_cls = nixl_utils.NixlWrapper
if nixl_wrapper_cls is None:
raise RuntimeError("NIXL/ RIXL is unavailable.")
self._cpu_group = cpu_group
self._cuda_stream = cuda_stream
self._world_size = cpu_group.size()
self._rank = cpu_group.rank()
# expert_id -> weight tensors to pack into the send buffer.
self._expert_send_map: dict[int, list[torch.Tensor]] = {}
# src_rank -> expert_id -> weight tensors to unpack after transfer.
self._recv_map: dict[int, dict[int, list[torch.Tensor]]] = {}
self._num_local_experts: int = expert_weights[0].shape[0]
self._device = expert_weights[0].device
for tensor in expert_weights:
assert tensor.device == self._device, (
"All local EPLB tensors are expected to be on the same device: "
f"expected={self._device}, got={tensor.device}"
self._all_expert_weights = all_expert_weights
self._expert_buffer = expert_buffer
self._num_local_experts: int = all_expert_weights[0][0].shape[0]
self._device = all_expert_weights[0][0].device
for layer_tensors in all_expert_weights:
for tensor in layer_tensors:
assert is_weak_contiguous(tensor), (
"Expert weight tensors must be contiguous in memory"
)
assert tensor.device == self._device, (
"All local EPLB tensors are expected to be on the same "
f"device: expected={self._device}, got={tensor.device}"
)
for tensor in expert_buffer:
assert is_weak_contiguous(tensor), (
"expert_buffer tensors must be contiguous in memory"
)
# (local_dlist, remote_dlist, xfer_handle) for in-flight READs;
# accumulated by add_recv, drained by execute.
self._xfer_entries: list[tuple[int, int, int]] = []
# Per-rank expert_id -> physical row; set by set_transfer_context.
self._expert_to_src_row: list[dict[int, int]] | None = None
self._layer_idx: int | None = None
nixl_agent_config = nixl_utils.nixl_agent_config
config = (
nixl_agent_config(capture_telemetry=False)
@@ -260,15 +293,16 @@ class NixlEplbCommunicator(EplbCommunicator):
)
self._nixl_wrapper = nixl_wrapper_cls(self._make_agent_name(), config)
self._nixl_memory_type = "VRAM"
self._registered_desc: object | None = None
# NIXL registration handles; deregistered in __del__.
self._registered_descs: list[object] = []
self._remote_agents: dict[int, str] = {}
self._remote_send_meta: dict[int, tuple[int, int]] = {}
self._send_buffer: torch.Tensor = torch.empty(0)
self._recv_buffer: torch.Tensor = torch.empty(0)
self._expert_bytes: int = 0
# peer -> (layer, tensor) -> (base_ptr, bytes_per_expert, dev_id).
self._remote_send_meta: dict[
int, dict[tuple[int, int], tuple[int, int, int]]
] = {}
self._cuda_device_id = int(self._device.index or 0)
self._init_step("buffers", self._init_registered_buffers, expert_weights)
self._init_step("buffers", self._init_registered_buffers)
self._init_step("agents", self._init_remote_agents)
self._init_step("send meta", self._exchange_remote_send_meta)
self._log_initialized()
@@ -291,19 +325,34 @@ class NixlEplbCommunicator(EplbCommunicator):
uid = uuid.uuid4().hex[:8]
return f"eplb-{self._rank}{pp_suffix}-{uid}"
def set_stream(self, cuda_stream: torch.cuda.Stream | None) -> None:
pass
def add_send(
self,
tensors: list[torch.Tensor],
dst_rank: int,
expert_id: int,
) -> None:
assert dst_rank != self._rank, (
"EPLB communicator should not enqueue same-rank sends: "
f"rank={self._rank}, dst_rank={dst_rank}"
# No-op: NIXL READ is receiver-initiated. The sender's expert
# weights are pre-registered and always readable in-place.
pass
def set_transfer_context(self, old_indices: np.ndarray, layer_idx: int) -> None:
# Pre-compute expert_id -> src_row mapping for every rank so that
# add_recv can immediately issue NIXL READs.
assert not self._xfer_entries, (
f"set_transfer_context() called with {len(self._xfer_entries)} "
f"pending transfers from layer {self._layer_idx}; "
f"execute() was not called after previous add_recv() calls"
)
# An expert sent to multiple peers is packed only once; skip duplicates.
if expert_id not in self._expert_send_map:
self._expert_send_map[expert_id] = tensors
self._layer_idx = layer_idx
n = self._num_local_experts
rank_experts = old_indices[: self._world_size * n].reshape(self._world_size, n)
self._expert_to_src_row = [
{int(eid): i for i, eid in enumerate(row) if eid != -1}
for row in rank_experts
]
def add_recv(
self,
@@ -311,13 +360,44 @@ class NixlEplbCommunicator(EplbCommunicator):
src_rank: int,
expert_id: int,
) -> None:
assert src_rank != self._rank, (
"EPLB communicator should not enqueue same-rank recvs: "
f"rank={self._rank}, src_rank={src_rank}"
# Build NIXL descriptors and issue the RDMA READ immediately,
# overlapping the transfer with the remaining Python loop in
# move_to_buffer.
assert self._expert_to_src_row is not None and self._layer_idx is not None, (
"set_transfer_context() must be called before add_recv()"
)
recv_experts = self._recv_map.setdefault(src_rank, {})
if expert_id not in recv_experts:
recv_experts[expert_id] = tensors
src_row = self._expert_to_src_row[src_rank][expert_id]
layer_idx = self._layer_idx
local_descs: list[tuple[int, int, int]] = []
remote_descs: list[tuple[int, int, int]] = []
for t_idx, t in enumerate(tensors):
send_base, send_stride, remote_dev = self._remote_send_meta[src_rank][
(layer_idx, t_idx)
]
assert t.nbytes == send_stride, (
f"tensor {t_idx} size {t.nbytes} != remote stride {send_stride}"
)
local_descs.append(
(
t.data_ptr(),
t.nbytes,
self._cuda_device_id,
)
)
remote_descs.append(
(
send_base + src_row * send_stride,
send_stride,
remote_dev,
)
)
local_h, remote_h, xfer_h = self._create_peer_xfer(
src_rank, local_descs, remote_descs
)
self._nixl_wrapper.transfer(xfer_h)
self._xfer_entries.append((local_h, remote_h, xfer_h))
def _init_remote_agents(self) -> None:
local_metadata = self._nixl_wrapper.get_agent_metadata()
@@ -334,73 +414,60 @@ class NixlEplbCommunicator(EplbCommunicator):
peer_metadata
)
def _init_registered_buffers(self, expert_weights: Sequence[torch.Tensor]) -> None:
total_bytes = max(sum(t.nbytes for t in expert_weights), 1)
assert total_bytes % self._num_local_experts == 0, (
f"Number of bytes in moe layer {total_bytes} is not divisible "
f"by number of local experts {self._num_local_experts}"
)
self._expert_bytes = total_bytes // self._num_local_experts
def _init_registered_buffers(self) -> None:
all_tensors: list[torch.Tensor] = []
for layer_tensors in self._all_expert_weights:
all_tensors.extend(layer_tensors)
all_tensors.extend(self._expert_buffer)
self._send_buffer = torch.empty(
total_bytes, device=self._device, dtype=torch.uint8
)
self._recv_buffer = torch.empty(
total_bytes, device=self._device, dtype=torch.uint8
)
descs = self._nixl_wrapper.get_reg_descs([self._send_buffer, self._recv_buffer])
descs = self._nixl_wrapper.get_reg_descs(all_tensors)
self._nixl_wrapper.register_memory(descs)
self._registered_desc = descs
self._registered_descs.append(descs)
def _exchange_remote_send_meta(self) -> None:
"""Exchange send-buffer metadata so each rank can build dynamic
descriptors at execute time."""
local_meta: tuple[int, int] = (
self._send_buffer.data_ptr(),
self._cuda_device_id,
)
gathered_meta: list[tuple[int, int] | None] = [None] * self._world_size
"""Exchange per-layer per-tensor metadata so receivers can compute
remote RDMA addresses at transfer time."""
local_meta: dict[tuple[int, int], tuple[int, int, int]] = {}
for layer_idx, layer_tensors in enumerate(self._all_expert_weights):
for t_idx, t in enumerate(layer_tensors):
nbytes_per_expert = t.nbytes // self._num_local_experts
local_meta[(layer_idx, t_idx)] = (
t.data_ptr(),
nbytes_per_expert,
self._cuda_device_id,
)
# Per-rank map: (layer_idx, tensor_idx) -> (base_ptr, bytes_per_expert, dev_id).
# add_recv uses base_ptr + src_row * bytes_per_expert to compute
# the remote RDMA address for each expert.
gathered_meta: list[dict[tuple[int, int], tuple[int, int, int]] | None] = [
None
] * self._world_size
torch.distributed.all_gather_object(
gathered_meta, local_meta, group=self._cpu_group
)
local_keys = set(local_meta.keys())
for peer in self._remote_agents:
peer_meta = gathered_meta[peer]
assert peer_meta is not None
peer_keys = set(peer_meta.keys())
if peer_keys != local_keys:
raise RuntimeError(
f"NIXL EPLB metadata key mismatch with rank {peer}: "
f"local={sorted(local_keys)}, peer={sorted(peer_keys)}"
)
for key in local_keys:
_, local_stride, _ = local_meta[key]
_, peer_stride, _ = peer_meta[key]
if local_stride != peer_stride:
raise RuntimeError(
f"NIXL EPLB nbytes_per_expert mismatch for {key} "
f"with rank {peer}: "
f"local={local_stride}, peer={peer_stride}"
)
self._remote_send_meta[peer] = peer_meta
@staticmethod
def _pack_send_buffer(
in_tensors: list[torch.Tensor],
send_buffer: torch.Tensor,
byte_offset: int,
) -> None:
for tensor in in_tensors:
raw = tensor.reshape(-1).view(torch.uint8)
if raw.numel() == 0:
continue
send_buffer[byte_offset : byte_offset + raw.numel()].copy_(
raw, non_blocking=True
)
byte_offset += raw.numel()
@staticmethod
def _unpack_recv_buffer(
recv_buffer: torch.Tensor,
out_tensors: list[torch.Tensor],
byte_offset: int,
) -> None:
for tensor in out_tensors:
num_bytes = tensor.numel() * tensor.element_size()
if num_bytes == 0:
continue
tensor.reshape(-1).view(torch.uint8).copy_(
recv_buffer[byte_offset : byte_offset + num_bytes],
non_blocking=True,
)
byte_offset += num_bytes
def _wait_for_all_transfers(self, handles: list[int]) -> None:
pending = set(handles)
while pending:
@@ -456,110 +523,52 @@ class NixlEplbCommunicator(EplbCommunicator):
)
return (local_handle, remote_handle, xfer_handle)
def execute(self, old_indices: np.ndarray | None = None) -> None:
assert old_indices is not None, (
"NixlEplbCommunicator.execute requires old_indices"
def execute(self) -> None:
assert self._layer_idx is not None or not self._xfer_entries, (
"set_transfer_context() must be called before execute() "
"if any add_recv() calls were made"
)
xfer_entries: list[tuple[int, int, int]] = []
try:
n = self._num_local_experts
rank_experts = old_indices[: self._world_size * n].reshape(
self._world_size, n
)
# Build expert_id -> send slot mapping per rank.
expert_to_send_slot: list[dict[int, int]] = [
{int(eid): i for i, eid in enumerate(row) if eid != -1}
for row in rank_experts
]
self._wait_for_all_transfers([x[2] for x in self._xfer_entries])
# Phase 1: pack each expert at its slot offset in the send buffer.
with torch.cuda.stream(self._cuda_stream):
for expert_id, tensors in self._expert_send_map.items():
slot = expert_to_send_slot[self._rank][expert_id]
byte_offset = slot * self._expert_bytes
self._pack_send_buffer(tensors, self._send_buffer, byte_offset)
# Ensure all packed data is visible in device memory before pulls.
if self._cuda_stream is not None:
self._cuda_stream.synchronize()
else:
torch.cuda.current_stream().synchronize()
# READ is receiver-initiated; synchronize all ranks before transfer.
# We use monitored_barrier so a rank that crashes or exits early
# produces a diagnostic timeout instead of a silent hang.
# Post-READ barrier.
# Correctness fence for zero-copy: prevents overwrite-while-
# remote-read race.
torch.distributed.monitored_barrier(
group=self._cpu_group,
timeout=timedelta(minutes=5),
)
# Phase 2: issue one batched READ per peer.
recv_offsets: dict[tuple[int, int], int] = {}
recv_offset = 0
recv_base = self._recv_buffer.data_ptr()
for src in range(self._world_size):
if src == self._rank:
continue
recv_experts = self._recv_map.get(src)
if not recv_experts:
continue
expert_ids = list(recv_experts.keys())
remote_base, remote_dev = self._remote_send_meta[src]
local_descs: list[tuple[int, int, int]] = []
remote_descs: list[tuple[int, int, int]] = []
for expert_id in expert_ids:
slot = expert_to_send_slot[src][expert_id]
remote_off = slot * self._expert_bytes
recv_offsets[(src, expert_id)] = recv_offset
local_descs.append(
(
recv_base + recv_offset,
self._expert_bytes,
self._cuda_device_id,
)
)
remote_descs.append(
(remote_base + remote_off, self._expert_bytes, remote_dev)
)
recv_offset += self._expert_bytes
assert recv_offset <= self._recv_buffer.nbytes
local_h, remote_h, xfer_h = self._create_peer_xfer(
src, local_descs, remote_descs
)
self._nixl_wrapper.transfer(xfer_h)
xfer_entries.append((local_h, remote_h, xfer_h))
# Phase 3: wait for all in-flight transfers, then unpack.
self._wait_for_all_transfers([x[2] for x in xfer_entries])
with torch.cuda.stream(self._cuda_stream):
for (src, expert_id), offset in recv_offsets.items():
self._unpack_recv_buffer(
self._recv_buffer,
self._recv_map[src][expert_id],
offset,
)
finally:
for local_h, remote_h, xfer_h in xfer_entries:
for local_h, remote_h, xfer_h in self._xfer_entries:
with contextlib.suppress(Exception):
self._nixl_wrapper.release_xfer_handle(xfer_h)
with contextlib.suppress(Exception):
self._nixl_wrapper.release_dlist_handle(local_h)
with contextlib.suppress(Exception):
self._nixl_wrapper.release_dlist_handle(remote_h)
self._expert_send_map.clear()
self._recv_map.clear()
self._xfer_entries.clear()
self._expert_to_src_row = None
self._layer_idx = None
def __del__(self) -> None:
try:
if self._registered_desc is not None:
self._nixl_wrapper.deregister_memory(self._registered_desc)
self._registered_desc = None
with contextlib.suppress(Exception):
for local_h, remote_h, xfer_h in self._xfer_entries:
with contextlib.suppress(Exception):
self._nixl_wrapper.release_xfer_handle(xfer_h)
with contextlib.suppress(Exception):
self._nixl_wrapper.release_dlist_handle(local_h)
with contextlib.suppress(Exception):
self._nixl_wrapper.release_dlist_handle(remote_h)
with contextlib.suppress(Exception):
for descs in self._registered_descs:
with contextlib.suppress(Exception):
self._nixl_wrapper.deregister_memory(descs)
self._registered_descs.clear()
with contextlib.suppress(Exception):
for agent_name in self._remote_agents.values():
self._nixl_wrapper.remove_remote_agent(agent_name)
with contextlib.suppress(Exception):
self._nixl_wrapper.remove_remote_agent(agent_name)
self._remote_agents.clear()
except Exception as e:
logger.warning("Error during NixlEplbCommunicator cleanup: %s", e)
class PyNcclEplbCommunicator(EplbCommunicator):
@@ -600,7 +609,7 @@ class PyNcclEplbCommunicator(EplbCommunicator):
for tensor in tensors:
self._pynccl_comm.recv(tensor, src_rank, stream=self._cuda_stream)
def execute(self, old_indices: np.ndarray | None = None) -> None:
def execute(self) -> None:
if self._group_started:
self._pynccl_comm.group_end()
self._group_started = False
@@ -609,7 +618,8 @@ class PyNcclEplbCommunicator(EplbCommunicator):
def create_eplb_communicator(
group_coordinator: GroupCoordinator,
backend: str | None,
expert_weights: Sequence[torch.Tensor],
expert_weights: Sequence[Sequence[torch.Tensor]],
expert_buffer: Sequence[torch.Tensor],
) -> EplbCommunicator:
"""Create an EPLB communicator for the given backend.
@@ -624,16 +634,18 @@ def create_eplb_communicator(
``"pynccl"`` in that case. When tensors reside on CPU,
``"torch_gloo"`` or ``"torch_nccl"`` are used via the CPU
process group.
expert_weights: Expert weight tensors from *one* MoE layer.
NixlEplbCommunicator pre-allocates send/recv buffers sized
to this layer, so all other MoE layers must have the same
tensor count, shapes, and dtypes.
expert_weights: Expert weight tensors for *all* MoE layers.
Shape ``(num_layers)(num_tensors_per_layer)``.
NixlEplbCommunicator registers all layers with NIXL for
zero-copy RDMA reads.
expert_buffer: Pre-allocated receive buffer tensors (one per
weight tensor in a single layer).
"""
# Keep a safe default for callers that have not resolved communicator yet.
if backend is None:
backend = "torch_nccl"
tensor_device_type = expert_weights[0].device.type if expert_weights else "cpu"
first_layer = expert_weights[0] if expert_weights else []
tensor_device_type = first_layer[0].device.type if first_layer else "cpu"
torch_group = (
group_coordinator.cpu_group
if tensor_device_type == "cpu"
@@ -649,7 +661,7 @@ def create_eplb_communicator(
unsupported_dtypes = sorted(
{
tensor.dtype
for tensor in expert_weights
for tensor in first_layer
if not ncclDataTypeEnum.supports_torch_dtype(tensor.dtype)
},
key=str,
@@ -704,7 +716,8 @@ def create_eplb_communicator(
try:
return NixlEplbCommunicator(
cpu_group=group_coordinator.cpu_group,
expert_weights=expert_weights,
all_expert_weights=expert_weights,
expert_buffer=expert_buffer,
)
except Exception as exc:
raise RuntimeError(
+3 -1
View File
@@ -450,7 +450,8 @@ class EplbState:
communicator = create_eplb_communicator(
group_coordinator=get_eplb_group(),
backend=self.parallel_config.eplb_config.communicator,
expert_weights=model.expert_weights[0],
expert_weights=model.expert_weights,
expert_buffer=expert_buffer,
)
model_state = EplbModelState(
@@ -766,6 +767,7 @@ class EplbState:
eplb_model_state.physical_to_logical_map,
new_physical_to_logical_map,
eplb_model_state.model.expert_weights,
eplb_model_state.expert_buffer,
ep_group,
eplb_model_state.communicator,
is_profile,
+16 -9
View File
@@ -178,6 +178,7 @@ def move_to_buffer(
cuda_stream: torch.cuda.Stream | None,
ep_rank: int,
communicator: EplbCommunicator,
layer_idx: int = 0,
) -> TransferMetadata:
"""
Rearranges expert weights during EPLB rebalancing.
@@ -193,6 +194,7 @@ def move_to_buffer(
cuda_stream: CUDA stream for async copies (can be None for sync mode).
ep_rank: Rank of this process in expert parallel group.
communicator: EplbCommunicator instance for P2P communication.
layer_idx: Index of the MoE layer being transferred.
Returns:
TransferMetadata: Metadata needed for completing remote weight transfers.
@@ -265,6 +267,8 @@ def move_to_buffer(
for w, b in zip(expert_weights, expert_weights_buffers):
b[dst].copy_(w[src_local], non_blocking=True)
communicator.set_transfer_context(old_indices, layer_idx)
# 2. Post sends
if send_count > 0:
experts = send_expert_ids[:send_count]
@@ -331,9 +335,8 @@ def move_to_buffer(
expert_id=int(expert),
)
# 4. Execute the P2P operations. The real communication happens here.
communicator.execute(old_indices=old_indices)
# wait for the communication to finish
# 4. Execute transfers and wait for completion.
communicator.execute()
return TransferMetadata(
is_unchanged=is_unchanged,
is_received_locally=is_received_locally,
@@ -431,6 +434,7 @@ def transfer_layer(
is_profile: bool = False,
cuda_stream: torch.cuda.Stream | None = None,
rank_mapping: dict[int, int] | None = None,
layer_idx: int = 0,
) -> TransferMetadata:
"""
Rearranges the expert weights in place according to the new expert indices.
@@ -452,6 +456,7 @@ def transfer_layer(
communications to reserve enough memory for the buffers.
cuda_stream: CUDA stream for async copies (can be None for sync mode).
rank_mapping: Optional rank mapping for elastic expert parallelism.
layer_idx: Index of the MoE layer being transferred.
Returns:
TransferMetadata: Metadata needed for completing remote weight transfers,
@@ -499,6 +504,7 @@ def transfer_layer(
cuda_stream=cuda_stream,
ep_rank=ep_group.rank(),
communicator=communicator,
layer_idx=layer_idx,
)
@@ -506,6 +512,7 @@ def rearrange_expert_weights_inplace(
old_global_expert_indices: torch.Tensor,
new_global_expert_indices: torch.Tensor,
expert_weights: Sequence[Sequence[torch.Tensor]],
expert_buffer: Sequence[torch.Tensor],
ep_group: ProcessGroup,
communicator: EplbCommunicator,
is_profile: bool = False,
@@ -524,6 +531,8 @@ def rearrange_expert_weights_inplace(
of tensors of shape (num_local_physical_experts, hidden_size_i).
For example, a linear layer may have up and down projection,
so weight_count = 2. Each weight's hidden size can be different.
expert_buffer: Pre-allocated receive buffer tensors (one per
weight tensor in a single layer).
ep_group: The device process group for expert parallelism.
communicator: EplbCommunicator instance for P2P communication.
is_profile (bool): If `True`, do not perform any actual weight copy.
@@ -566,10 +575,10 @@ def rearrange_expert_weights_inplace(
# Reserve NCCL communication buffers via a dummy all_gather.
# Backends that pre-allocate their own transfer buffers
# skip this to avoid the extra memory spike during profiling.
weights_buffer: list[torch.Tensor] = [
profile_buffer: list[torch.Tensor] = [
torch.empty_like(w) for w in first_layer_weights
]
for weight, buffer in zip(expert_weights[0], weights_buffer):
for weight, buffer in zip(expert_weights[0], profile_buffer):
dummy_recv_buffer = [buffer for _ in range(ep_size)]
torch.distributed.barrier()
all_gather(
@@ -579,10 +588,7 @@ def rearrange_expert_weights_inplace(
)
return
# Buffers to hold the expert weights during the exchange.
# NOTE: Currently we assume the same weights across different layers
# have the same shape.
weights_buffer = [torch.empty_like(w) for w in first_layer_weights]
weights_buffer = list(expert_buffer)
old_global_expert_indices_cpu = old_global_expert_indices.cpu().numpy()
new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy()
@@ -597,6 +603,7 @@ def rearrange_expert_weights_inplace(
cuda_stream=None,
ep_rank=ep_rank,
communicator=communicator,
layer_idx=layer_idx,
)
move_from_buffer(
+14
View File
@@ -64,6 +64,20 @@ def divide(numerator, denominator):
return numerator // denominator
def is_weak_contiguous(inp: torch.Tensor) -> bool:
    """Return whether *inp* can be treated as one dense block of memory.

    A tensor passes when either:

    * ``inp.is_contiguous()`` is true, or
    * the number of bytes in its underlying storage, counted from
      ``storage_offset()`` to the end of the storage, equals
      ``numel() * element_size()``.

    The second test compares sizes only and never inspects strides, so it
    also accepts layouts that are not C-contiguous (e.g. a transposed view
    that spans its whole storage). Callers that additionally rely on
    row-major element order must check ``is_contiguous()`` themselves.

    Args:
        inp: Tensor to inspect.

    Returns:
        ``True`` if either condition above holds, ``False`` otherwise.
    """
    if inp.is_contiguous():
        return True
    # ``Tensor.storage()`` hands back the deprecated ``TypedStorage`` wrapper
    # and emits a warning; ``untyped_storage()`` reports the same byte count.
    itemsize = inp.element_size()
    storage_bytes = inp.untyped_storage().nbytes()
    bytes_after_offset = storage_bytes - inp.storage_offset() * itemsize
    return bytes_after_offset == inp.numel() * itemsize
def split_tensor_along_last_dim(
tensor: torch.Tensor,
num_partitions: int,