mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[EPLB] Nixl communicator optimization. Zero-copy transfers (#41633)
Signed-off-by: ilmarkov <markovilya197@gmail.com> Signed-off-by: Markov Ilya <markovilya19@gmail.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Markov Ilya <markovilya19@gmail.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
This commit is contained in:
@@ -277,12 +277,15 @@ def assert_verification_synced(local_ok: bool, msg: str) -> None:
|
||||
assert bool(ok_tensor.item()), msg
|
||||
|
||||
|
||||
def create_eplb_communicator_or_raise(*, group_coordinator, backend, expert_weights):
|
||||
def create_eplb_communicator_or_raise(
|
||||
*, group_coordinator, backend, expert_weights, expert_buffer
|
||||
):
|
||||
try:
|
||||
return create_eplb_communicator(
|
||||
group_coordinator=group_coordinator,
|
||||
backend=backend,
|
||||
expert_weights=expert_weights,
|
||||
expert_buffer=expert_buffer,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise RuntimeError(
|
||||
@@ -355,7 +358,8 @@ def _test_async_transfer_layer_without_mtp_worker(
|
||||
communicator = create_eplb_communicator_or_raise(
|
||||
group_coordinator=ep_group_coordinator,
|
||||
backend=eplb_communicator,
|
||||
expert_weights=expert_weights[0],
|
||||
expert_weights=expert_weights,
|
||||
expert_buffer=expert_buffer,
|
||||
)
|
||||
communicator.set_stream(cuda_stream)
|
||||
|
||||
@@ -368,6 +372,7 @@ def _test_async_transfer_layer_without_mtp_worker(
|
||||
ep_group=ep_group,
|
||||
communicator=communicator,
|
||||
cuda_stream=cuda_stream,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
cuda_stream.synchronize()
|
||||
move_from_buffer(
|
||||
@@ -460,10 +465,12 @@ def _test_rearrange_expert_weights_with_redundancy(
|
||||
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
|
||||
)
|
||||
|
||||
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
|
||||
communicator = create_eplb_communicator_or_raise(
|
||||
group_coordinator=ep_group_coordinator,
|
||||
backend=eplb_communicator,
|
||||
expert_weights=expert_weights[0],
|
||||
expert_weights=expert_weights,
|
||||
expert_buffer=expert_buffer,
|
||||
)
|
||||
|
||||
# Execute weight rearrangement
|
||||
@@ -471,9 +478,9 @@ def _test_rearrange_expert_weights_with_redundancy(
|
||||
old_indices,
|
||||
new_indices,
|
||||
expert_weights,
|
||||
expert_buffer,
|
||||
ep_group,
|
||||
is_profile=False,
|
||||
communicator=communicator,
|
||||
communicator,
|
||||
)
|
||||
|
||||
# Verify the rearrangement result
|
||||
@@ -593,10 +600,12 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
|
||||
layer_copy.append(weight.clone())
|
||||
original_weights.append(layer_copy)
|
||||
|
||||
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
|
||||
communicator = create_eplb_communicator_or_raise(
|
||||
group_coordinator=ep_group_coordinator,
|
||||
backend="torch_nccl",
|
||||
expert_weights=expert_weights[0],
|
||||
expert_weights=expert_weights,
|
||||
expert_buffer=expert_buffer,
|
||||
)
|
||||
|
||||
# Execute rearrangement (should be no change)
|
||||
@@ -604,9 +613,9 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
|
||||
indices,
|
||||
indices, # Same indices
|
||||
expert_weights,
|
||||
expert_buffer,
|
||||
ep_group,
|
||||
communicator,
|
||||
is_profile=False,
|
||||
)
|
||||
|
||||
# Verify that the weights have not changed
|
||||
@@ -726,10 +735,12 @@ def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
|
||||
layer_copy.append(weight.clone())
|
||||
original_weights.append(layer_copy)
|
||||
|
||||
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
|
||||
communicator = create_eplb_communicator_or_raise(
|
||||
group_coordinator=ep_group_coordinator,
|
||||
backend="torch_nccl",
|
||||
expert_weights=expert_weights[0],
|
||||
expert_weights=expert_weights,
|
||||
expert_buffer=expert_buffer,
|
||||
)
|
||||
|
||||
# Execute profile mode rearrangement
|
||||
@@ -737,9 +748,10 @@ def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
|
||||
old_indices,
|
||||
new_indices,
|
||||
expert_weights,
|
||||
expert_buffer,
|
||||
ep_group,
|
||||
communicator,
|
||||
is_profile=True, # Profile mode
|
||||
is_profile=True,
|
||||
)
|
||||
|
||||
# In profile mode, the weights should remain unchanged
|
||||
|
||||
@@ -9,9 +9,11 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator
|
||||
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
|
||||
from vllm.distributed.parallel_state import (
|
||||
ensure_model_parallel_initialized,
|
||||
get_eplb_group,
|
||||
get_tp_group,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
@@ -213,12 +215,20 @@ def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
|
||||
for lidx in range(test_config.num_layers):
|
||||
shuffled_indices[lidx] = torch.randperm(test_config.num_experts)
|
||||
|
||||
expert_buffer = [torch.empty_like(w) for w in rank_expert_weights[0]]
|
||||
communicator = create_eplb_communicator(
|
||||
group_coordinator=get_eplb_group(),
|
||||
backend="torch_nccl",
|
||||
expert_weights=rank_expert_weights,
|
||||
expert_buffer=expert_buffer,
|
||||
)
|
||||
rearrange_expert_weights_inplace(
|
||||
indices,
|
||||
shuffled_indices,
|
||||
rank_expert_weights,
|
||||
expert_buffer,
|
||||
ep_group,
|
||||
is_profile=False,
|
||||
communicator,
|
||||
)
|
||||
|
||||
num_local_experts = test_config.num_local_experts
|
||||
|
||||
@@ -10,11 +10,13 @@ import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_quant_config
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator
|
||||
from vllm.distributed.eplb.eplb_state import EplbLayerState
|
||||
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
|
||||
from vllm.distributed.parallel_state import (
|
||||
ensure_model_parallel_initialized,
|
||||
get_dp_group,
|
||||
get_eplb_group,
|
||||
)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
@@ -171,12 +173,20 @@ def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
|
||||
for lidx in range(test_config.num_layers):
|
||||
shuffled_indices[lidx] = torch.randperm(test_config.num_experts)
|
||||
|
||||
expert_buffer = [torch.empty_like(w) for w in rank_expert_weights[0]]
|
||||
communicator = create_eplb_communicator(
|
||||
group_coordinator=get_eplb_group(),
|
||||
backend="torch_nccl",
|
||||
expert_weights=rank_expert_weights,
|
||||
expert_buffer=expert_buffer,
|
||||
)
|
||||
rearrange_expert_weights_inplace(
|
||||
indices,
|
||||
shuffled_indices,
|
||||
rank_expert_weights,
|
||||
expert_buffer,
|
||||
ep_group,
|
||||
is_profile=False,
|
||||
communicator,
|
||||
)
|
||||
|
||||
num_global_experts = test_config.num_experts
|
||||
|
||||
@@ -1287,10 +1287,12 @@ def _test_body_eplb(
|
||||
|
||||
expert_weights = [list(eplb_moe_layer.get_expert_weights())]
|
||||
|
||||
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
|
||||
communicator = create_eplb_communicator(
|
||||
group_coordinator=get_eplb_group(),
|
||||
backend=vllm_config.parallel_config.eplb_config.communicator,
|
||||
expert_weights=expert_weights[0],
|
||||
expert_weights=expert_weights,
|
||||
expert_buffer=expert_buffer,
|
||||
)
|
||||
|
||||
# Rearrange expert weights across EP ranks
|
||||
@@ -1298,6 +1300,7 @@ def _test_body_eplb(
|
||||
old_global_expert_indices=initial_indices.unsqueeze(0),
|
||||
new_global_expert_indices=shuffled_indices.unsqueeze(0),
|
||||
expert_weights=expert_weights,
|
||||
expert_buffer=expert_buffer,
|
||||
ep_group=cpu_group,
|
||||
communicator=communicator,
|
||||
)
|
||||
|
||||
@@ -40,11 +40,7 @@ def _can_p2p(rank: int, world_size: int) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def is_weak_contiguous(inp: torch.Tensor):
|
||||
return inp.is_contiguous() or (
|
||||
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
|
||||
== inp.numel() * inp.element_size()
|
||||
)
|
||||
from vllm.distributed.utils import is_weak_contiguous # noqa: E402
|
||||
|
||||
|
||||
class CustomAllreduce:
|
||||
|
||||
@@ -24,11 +24,7 @@ except Exception:
|
||||
quick_ar = False
|
||||
|
||||
|
||||
def is_weak_contiguous(inp: torch.Tensor):
|
||||
return inp.is_contiguous() or (
|
||||
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
|
||||
== inp.numel() * inp.element_size()
|
||||
)
|
||||
from vllm.distributed.utils import is_weak_contiguous # noqa: E402, F401
|
||||
|
||||
|
||||
class QuickReduceRegime(Enum):
|
||||
|
||||
@@ -470,10 +470,14 @@ class ElasticEPScalingExecutor:
|
||||
module._replace_quant_method(module.quant_method.old_quant_method)
|
||||
prepare_communication_buffer_for_model(self.worker.model_runner.model)
|
||||
|
||||
eplb_model_state.expert_buffer = [
|
||||
torch.empty_like(w) for w in model.expert_weights[0]
|
||||
]
|
||||
eplb_model_state.communicator = create_eplb_communicator(
|
||||
group_coordinator=get_eplb_group(),
|
||||
backend=parallel_config.eplb_config.communicator,
|
||||
expert_weights=model.expert_weights[0],
|
||||
expert_weights=model.expert_weights,
|
||||
expert_buffer=eplb_model_state.expert_buffer,
|
||||
)
|
||||
|
||||
if (
|
||||
|
||||
@@ -120,6 +120,7 @@ def transfer_run_periodically(
|
||||
ep_group=eplb_group,
|
||||
is_profile=is_profile,
|
||||
cuda_stream=cuda_stream,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
|
||||
# Wait until all writes to expert_buffer have finished before making the
|
||||
|
||||
@@ -30,6 +30,7 @@ from vllm.distributed.parallel_state import (
|
||||
is_local_first_rank,
|
||||
)
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
from vllm.distributed.utils import is_weak_contiguous
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -63,8 +64,22 @@ class EplbCommunicator(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, old_indices: np.ndarray | None = None) -> None:
|
||||
pass
|
||||
def execute(self) -> None:
|
||||
"""Complete all enqueued transfers.
|
||||
|
||||
Some backends perform communication here; others (e.g. NIXL)
|
||||
issue transfers eagerly in add_recv and only wait here.
|
||||
On return, all data is available in the destination buffers.
|
||||
"""
|
||||
|
||||
def set_transfer_context( # noqa: B027
|
||||
self, old_indices: np.ndarray, layer_idx: int
|
||||
) -> None:
|
||||
"""Pre-set layer context before add_recv calls.
|
||||
|
||||
Default is a no-op; overridden by backends (e.g. NIXL) that need
|
||||
layer-level context to issue transfers inside add_recv.
|
||||
"""
|
||||
|
||||
@property
|
||||
def needs_profile_buffer_reservation(self) -> bool:
|
||||
@@ -125,7 +140,7 @@ class TorchDistNcclEplbCommunicator(EplbCommunicator):
|
||||
)
|
||||
)
|
||||
|
||||
def execute(self, old_indices: np.ndarray | None = None) -> None:
|
||||
def execute(self) -> None:
|
||||
if not self._p2p_ops:
|
||||
return
|
||||
try:
|
||||
@@ -168,7 +183,7 @@ class TorchDistGlooStagedEplbCommunicator(EplbCommunicator):
|
||||
for tensor in tensors:
|
||||
self._ops.append(("recv", tensor, src_rank))
|
||||
|
||||
def execute(self, old_indices: np.ndarray | None = None) -> None:
|
||||
def execute(self) -> None:
|
||||
if not self._ops:
|
||||
return
|
||||
|
||||
@@ -229,29 +244,47 @@ class NixlEplbCommunicator(EplbCommunicator):
|
||||
def __init__(
|
||||
self,
|
||||
cpu_group: ProcessGroup,
|
||||
expert_weights: Sequence[torch.Tensor],
|
||||
cuda_stream: torch.cuda.Stream | None = None,
|
||||
all_expert_weights: Sequence[Sequence[torch.Tensor]],
|
||||
expert_buffer: Sequence[torch.Tensor],
|
||||
) -> None:
|
||||
assert expert_weights, "NixlEplbCommunicator requires non-empty expert_weights."
|
||||
assert all_expert_weights, (
|
||||
"NixlEplbCommunicator requires non-empty all_expert_weights."
|
||||
)
|
||||
assert expert_buffer, "NixlEplbCommunicator requires non-empty expert_buffer."
|
||||
nixl_wrapper_cls = nixl_utils.NixlWrapper
|
||||
if nixl_wrapper_cls is None:
|
||||
raise RuntimeError("NIXL/ RIXL is unavailable.")
|
||||
|
||||
self._cpu_group = cpu_group
|
||||
self._cuda_stream = cuda_stream
|
||||
self._world_size = cpu_group.size()
|
||||
self._rank = cpu_group.rank()
|
||||
# expert_id -> weight tensors to pack into the send buffer.
|
||||
self._expert_send_map: dict[int, list[torch.Tensor]] = {}
|
||||
# src_rank -> expert_id -> weight tensors to unpack after transfer.
|
||||
self._recv_map: dict[int, dict[int, list[torch.Tensor]]] = {}
|
||||
self._num_local_experts: int = expert_weights[0].shape[0]
|
||||
self._device = expert_weights[0].device
|
||||
for tensor in expert_weights:
|
||||
assert tensor.device == self._device, (
|
||||
"All local EPLB tensors are expected to be on the same device: "
|
||||
f"expected={self._device}, got={tensor.device}"
|
||||
|
||||
self._all_expert_weights = all_expert_weights
|
||||
self._expert_buffer = expert_buffer
|
||||
self._num_local_experts: int = all_expert_weights[0][0].shape[0]
|
||||
self._device = all_expert_weights[0][0].device
|
||||
|
||||
for layer_tensors in all_expert_weights:
|
||||
for tensor in layer_tensors:
|
||||
assert is_weak_contiguous(tensor), (
|
||||
"Expert weight tensors must be contiguous in memory"
|
||||
)
|
||||
assert tensor.device == self._device, (
|
||||
"All local EPLB tensors are expected to be on the same "
|
||||
f"device: expected={self._device}, got={tensor.device}"
|
||||
)
|
||||
for tensor in expert_buffer:
|
||||
assert is_weak_contiguous(tensor), (
|
||||
"expert_buffer tensors must be contiguous in memory"
|
||||
)
|
||||
|
||||
# (local_dlist, remote_dlist, xfer_handle) for in-flight READs;
|
||||
# accumulated by add_recv, drained by execute.
|
||||
self._xfer_entries: list[tuple[int, int, int]] = []
|
||||
# Per-rank expert_id -> physical row; set by set_transfer_context.
|
||||
self._expert_to_src_row: list[dict[int, int]] | None = None
|
||||
self._layer_idx: int | None = None
|
||||
|
||||
nixl_agent_config = nixl_utils.nixl_agent_config
|
||||
config = (
|
||||
nixl_agent_config(capture_telemetry=False)
|
||||
@@ -260,15 +293,16 @@ class NixlEplbCommunicator(EplbCommunicator):
|
||||
)
|
||||
self._nixl_wrapper = nixl_wrapper_cls(self._make_agent_name(), config)
|
||||
self._nixl_memory_type = "VRAM"
|
||||
self._registered_desc: object | None = None
|
||||
# NIXL registration handles; deregistered in __del__.
|
||||
self._registered_descs: list[object] = []
|
||||
self._remote_agents: dict[int, str] = {}
|
||||
self._remote_send_meta: dict[int, tuple[int, int]] = {}
|
||||
self._send_buffer: torch.Tensor = torch.empty(0)
|
||||
self._recv_buffer: torch.Tensor = torch.empty(0)
|
||||
self._expert_bytes: int = 0
|
||||
# peer -> (layer, tensor) -> (base_ptr, bytes_per_expert, dev_id).
|
||||
self._remote_send_meta: dict[
|
||||
int, dict[tuple[int, int], tuple[int, int, int]]
|
||||
] = {}
|
||||
|
||||
self._cuda_device_id = int(self._device.index or 0)
|
||||
self._init_step("buffers", self._init_registered_buffers, expert_weights)
|
||||
self._init_step("buffers", self._init_registered_buffers)
|
||||
self._init_step("agents", self._init_remote_agents)
|
||||
self._init_step("send meta", self._exchange_remote_send_meta)
|
||||
self._log_initialized()
|
||||
@@ -291,19 +325,34 @@ class NixlEplbCommunicator(EplbCommunicator):
|
||||
uid = uuid.uuid4().hex[:8]
|
||||
return f"eplb-{self._rank}{pp_suffix}-{uid}"
|
||||
|
||||
def set_stream(self, cuda_stream: torch.cuda.Stream | None) -> None:
|
||||
pass
|
||||
|
||||
def add_send(
|
||||
self,
|
||||
tensors: list[torch.Tensor],
|
||||
dst_rank: int,
|
||||
expert_id: int,
|
||||
) -> None:
|
||||
assert dst_rank != self._rank, (
|
||||
"EPLB communicator should not enqueue same-rank sends: "
|
||||
f"rank={self._rank}, dst_rank={dst_rank}"
|
||||
# No-op: NIXL READ is receiver-initiated. The sender's expert
|
||||
# weights are pre-registered and always readable in-place.
|
||||
pass
|
||||
|
||||
def set_transfer_context(self, old_indices: np.ndarray, layer_idx: int) -> None:
|
||||
# Pre-compute expert_id -> src_row mapping for every rank so that
|
||||
# add_recv can immediately issue NIXL READs.
|
||||
assert not self._xfer_entries, (
|
||||
f"set_transfer_context() called with {len(self._xfer_entries)} "
|
||||
f"pending transfers from layer {self._layer_idx}; "
|
||||
f"execute() was not called after previous add_recv() calls"
|
||||
)
|
||||
# An expert sent to multiple peers is packed only once; skip duplicates.
|
||||
if expert_id not in self._expert_send_map:
|
||||
self._expert_send_map[expert_id] = tensors
|
||||
self._layer_idx = layer_idx
|
||||
n = self._num_local_experts
|
||||
rank_experts = old_indices[: self._world_size * n].reshape(self._world_size, n)
|
||||
self._expert_to_src_row = [
|
||||
{int(eid): i for i, eid in enumerate(row) if eid != -1}
|
||||
for row in rank_experts
|
||||
]
|
||||
|
||||
def add_recv(
|
||||
self,
|
||||
@@ -311,13 +360,44 @@ class NixlEplbCommunicator(EplbCommunicator):
|
||||
src_rank: int,
|
||||
expert_id: int,
|
||||
) -> None:
|
||||
assert src_rank != self._rank, (
|
||||
"EPLB communicator should not enqueue same-rank recvs: "
|
||||
f"rank={self._rank}, src_rank={src_rank}"
|
||||
# Build NIXL descriptors and issue the RDMA READ immediately,
|
||||
# overlapping the transfer with the remaining Python loop in
|
||||
# move_to_buffer.
|
||||
assert self._expert_to_src_row is not None and self._layer_idx is not None, (
|
||||
"set_transfer_context() must be called before add_recv()"
|
||||
)
|
||||
recv_experts = self._recv_map.setdefault(src_rank, {})
|
||||
if expert_id not in recv_experts:
|
||||
recv_experts[expert_id] = tensors
|
||||
src_row = self._expert_to_src_row[src_rank][expert_id]
|
||||
layer_idx = self._layer_idx
|
||||
|
||||
local_descs: list[tuple[int, int, int]] = []
|
||||
remote_descs: list[tuple[int, int, int]] = []
|
||||
for t_idx, t in enumerate(tensors):
|
||||
send_base, send_stride, remote_dev = self._remote_send_meta[src_rank][
|
||||
(layer_idx, t_idx)
|
||||
]
|
||||
assert t.nbytes == send_stride, (
|
||||
f"tensor {t_idx} size {t.nbytes} != remote stride {send_stride}"
|
||||
)
|
||||
local_descs.append(
|
||||
(
|
||||
t.data_ptr(),
|
||||
t.nbytes,
|
||||
self._cuda_device_id,
|
||||
)
|
||||
)
|
||||
remote_descs.append(
|
||||
(
|
||||
send_base + src_row * send_stride,
|
||||
send_stride,
|
||||
remote_dev,
|
||||
)
|
||||
)
|
||||
|
||||
local_h, remote_h, xfer_h = self._create_peer_xfer(
|
||||
src_rank, local_descs, remote_descs
|
||||
)
|
||||
self._nixl_wrapper.transfer(xfer_h)
|
||||
self._xfer_entries.append((local_h, remote_h, xfer_h))
|
||||
|
||||
def _init_remote_agents(self) -> None:
|
||||
local_metadata = self._nixl_wrapper.get_agent_metadata()
|
||||
@@ -334,73 +414,60 @@ class NixlEplbCommunicator(EplbCommunicator):
|
||||
peer_metadata
|
||||
)
|
||||
|
||||
def _init_registered_buffers(self, expert_weights: Sequence[torch.Tensor]) -> None:
|
||||
total_bytes = max(sum(t.nbytes for t in expert_weights), 1)
|
||||
assert total_bytes % self._num_local_experts == 0, (
|
||||
f"Number of bytes in moe layer {total_bytes} is not divisible "
|
||||
f"by number of local experts {self._num_local_experts}"
|
||||
)
|
||||
self._expert_bytes = total_bytes // self._num_local_experts
|
||||
def _init_registered_buffers(self) -> None:
|
||||
all_tensors: list[torch.Tensor] = []
|
||||
for layer_tensors in self._all_expert_weights:
|
||||
all_tensors.extend(layer_tensors)
|
||||
all_tensors.extend(self._expert_buffer)
|
||||
|
||||
self._send_buffer = torch.empty(
|
||||
total_bytes, device=self._device, dtype=torch.uint8
|
||||
)
|
||||
self._recv_buffer = torch.empty(
|
||||
total_bytes, device=self._device, dtype=torch.uint8
|
||||
)
|
||||
|
||||
descs = self._nixl_wrapper.get_reg_descs([self._send_buffer, self._recv_buffer])
|
||||
descs = self._nixl_wrapper.get_reg_descs(all_tensors)
|
||||
self._nixl_wrapper.register_memory(descs)
|
||||
self._registered_desc = descs
|
||||
self._registered_descs.append(descs)
|
||||
|
||||
def _exchange_remote_send_meta(self) -> None:
|
||||
"""Exchange send-buffer metadata so each rank can build dynamic
|
||||
descriptors at execute time."""
|
||||
local_meta: tuple[int, int] = (
|
||||
self._send_buffer.data_ptr(),
|
||||
self._cuda_device_id,
|
||||
)
|
||||
gathered_meta: list[tuple[int, int] | None] = [None] * self._world_size
|
||||
"""Exchange per-layer per-tensor metadata so receivers can compute
|
||||
remote RDMA addresses at transfer time."""
|
||||
local_meta: dict[tuple[int, int], tuple[int, int, int]] = {}
|
||||
for layer_idx, layer_tensors in enumerate(self._all_expert_weights):
|
||||
for t_idx, t in enumerate(layer_tensors):
|
||||
nbytes_per_expert = t.nbytes // self._num_local_experts
|
||||
local_meta[(layer_idx, t_idx)] = (
|
||||
t.data_ptr(),
|
||||
nbytes_per_expert,
|
||||
self._cuda_device_id,
|
||||
)
|
||||
|
||||
# Per-rank map: (layer_idx, tensor_idx) -> (base_ptr, bytes_per_expert, dev_id).
|
||||
# add_recv uses base_ptr + src_row * bytes_per_expert to compute
|
||||
# the remote RDMA address for each expert.
|
||||
gathered_meta: list[dict[tuple[int, int], tuple[int, int, int]] | None] = [
|
||||
None
|
||||
] * self._world_size
|
||||
torch.distributed.all_gather_object(
|
||||
gathered_meta, local_meta, group=self._cpu_group
|
||||
)
|
||||
|
||||
local_keys = set(local_meta.keys())
|
||||
for peer in self._remote_agents:
|
||||
peer_meta = gathered_meta[peer]
|
||||
assert peer_meta is not None
|
||||
peer_keys = set(peer_meta.keys())
|
||||
if peer_keys != local_keys:
|
||||
raise RuntimeError(
|
||||
f"NIXL EPLB metadata key mismatch with rank {peer}: "
|
||||
f"local={sorted(local_keys)}, peer={sorted(peer_keys)}"
|
||||
)
|
||||
for key in local_keys:
|
||||
_, local_stride, _ = local_meta[key]
|
||||
_, peer_stride, _ = peer_meta[key]
|
||||
if local_stride != peer_stride:
|
||||
raise RuntimeError(
|
||||
f"NIXL EPLB nbytes_per_expert mismatch for {key} "
|
||||
f"with rank {peer}: "
|
||||
f"local={local_stride}, peer={peer_stride}"
|
||||
)
|
||||
self._remote_send_meta[peer] = peer_meta
|
||||
|
||||
@staticmethod
|
||||
def _pack_send_buffer(
|
||||
in_tensors: list[torch.Tensor],
|
||||
send_buffer: torch.Tensor,
|
||||
byte_offset: int,
|
||||
) -> None:
|
||||
for tensor in in_tensors:
|
||||
raw = tensor.reshape(-1).view(torch.uint8)
|
||||
if raw.numel() == 0:
|
||||
continue
|
||||
send_buffer[byte_offset : byte_offset + raw.numel()].copy_(
|
||||
raw, non_blocking=True
|
||||
)
|
||||
byte_offset += raw.numel()
|
||||
|
||||
@staticmethod
|
||||
def _unpack_recv_buffer(
|
||||
recv_buffer: torch.Tensor,
|
||||
out_tensors: list[torch.Tensor],
|
||||
byte_offset: int,
|
||||
) -> None:
|
||||
for tensor in out_tensors:
|
||||
num_bytes = tensor.numel() * tensor.element_size()
|
||||
if num_bytes == 0:
|
||||
continue
|
||||
tensor.reshape(-1).view(torch.uint8).copy_(
|
||||
recv_buffer[byte_offset : byte_offset + num_bytes],
|
||||
non_blocking=True,
|
||||
)
|
||||
byte_offset += num_bytes
|
||||
|
||||
def _wait_for_all_transfers(self, handles: list[int]) -> None:
|
||||
pending = set(handles)
|
||||
while pending:
|
||||
@@ -456,110 +523,52 @@ class NixlEplbCommunicator(EplbCommunicator):
|
||||
)
|
||||
return (local_handle, remote_handle, xfer_handle)
|
||||
|
||||
def execute(self, old_indices: np.ndarray | None = None) -> None:
|
||||
assert old_indices is not None, (
|
||||
"NixlEplbCommunicator.execute requires old_indices"
|
||||
def execute(self) -> None:
|
||||
assert self._layer_idx is not None or not self._xfer_entries, (
|
||||
"set_transfer_context() must be called before execute() "
|
||||
"if any add_recv() calls were made"
|
||||
)
|
||||
|
||||
xfer_entries: list[tuple[int, int, int]] = []
|
||||
try:
|
||||
n = self._num_local_experts
|
||||
rank_experts = old_indices[: self._world_size * n].reshape(
|
||||
self._world_size, n
|
||||
)
|
||||
# Build expert_id -> send slot mapping per rank.
|
||||
expert_to_send_slot: list[dict[int, int]] = [
|
||||
{int(eid): i for i, eid in enumerate(row) if eid != -1}
|
||||
for row in rank_experts
|
||||
]
|
||||
self._wait_for_all_transfers([x[2] for x in self._xfer_entries])
|
||||
|
||||
# Phase 1: pack each expert at its slot offset in the send buffer.
|
||||
with torch.cuda.stream(self._cuda_stream):
|
||||
for expert_id, tensors in self._expert_send_map.items():
|
||||
slot = expert_to_send_slot[self._rank][expert_id]
|
||||
byte_offset = slot * self._expert_bytes
|
||||
self._pack_send_buffer(tensors, self._send_buffer, byte_offset)
|
||||
|
||||
# Ensure all packed data is visible in device memory before pulls.
|
||||
if self._cuda_stream is not None:
|
||||
self._cuda_stream.synchronize()
|
||||
else:
|
||||
torch.cuda.current_stream().synchronize()
|
||||
# READ is receiver-initiated; synchronize all ranks before transfer.
|
||||
# We use monitored_barrier so a rank that crashes or exits early
|
||||
# produces a diagnostic timeout instead of a silent hang.
|
||||
# Post-READ barrier.
|
||||
# Correctness fence for zero-copy: prevents overwrite-while-
|
||||
# remote-read race.
|
||||
torch.distributed.monitored_barrier(
|
||||
group=self._cpu_group,
|
||||
timeout=timedelta(minutes=5),
|
||||
)
|
||||
|
||||
# Phase 2: issue one batched READ per peer.
|
||||
recv_offsets: dict[tuple[int, int], int] = {}
|
||||
recv_offset = 0
|
||||
recv_base = self._recv_buffer.data_ptr()
|
||||
for src in range(self._world_size):
|
||||
if src == self._rank:
|
||||
continue
|
||||
recv_experts = self._recv_map.get(src)
|
||||
if not recv_experts:
|
||||
continue
|
||||
expert_ids = list(recv_experts.keys())
|
||||
remote_base, remote_dev = self._remote_send_meta[src]
|
||||
local_descs: list[tuple[int, int, int]] = []
|
||||
remote_descs: list[tuple[int, int, int]] = []
|
||||
for expert_id in expert_ids:
|
||||
slot = expert_to_send_slot[src][expert_id]
|
||||
remote_off = slot * self._expert_bytes
|
||||
recv_offsets[(src, expert_id)] = recv_offset
|
||||
local_descs.append(
|
||||
(
|
||||
recv_base + recv_offset,
|
||||
self._expert_bytes,
|
||||
self._cuda_device_id,
|
||||
)
|
||||
)
|
||||
remote_descs.append(
|
||||
(remote_base + remote_off, self._expert_bytes, remote_dev)
|
||||
)
|
||||
recv_offset += self._expert_bytes
|
||||
assert recv_offset <= self._recv_buffer.nbytes
|
||||
local_h, remote_h, xfer_h = self._create_peer_xfer(
|
||||
src, local_descs, remote_descs
|
||||
)
|
||||
self._nixl_wrapper.transfer(xfer_h)
|
||||
xfer_entries.append((local_h, remote_h, xfer_h))
|
||||
|
||||
# Phase 3: wait for all in-flight transfers, then unpack.
|
||||
self._wait_for_all_transfers([x[2] for x in xfer_entries])
|
||||
|
||||
with torch.cuda.stream(self._cuda_stream):
|
||||
for (src, expert_id), offset in recv_offsets.items():
|
||||
self._unpack_recv_buffer(
|
||||
self._recv_buffer,
|
||||
self._recv_map[src][expert_id],
|
||||
offset,
|
||||
)
|
||||
finally:
|
||||
for local_h, remote_h, xfer_h in xfer_entries:
|
||||
for local_h, remote_h, xfer_h in self._xfer_entries:
|
||||
with contextlib.suppress(Exception):
|
||||
self._nixl_wrapper.release_xfer_handle(xfer_h)
|
||||
with contextlib.suppress(Exception):
|
||||
self._nixl_wrapper.release_dlist_handle(local_h)
|
||||
with contextlib.suppress(Exception):
|
||||
self._nixl_wrapper.release_dlist_handle(remote_h)
|
||||
self._expert_send_map.clear()
|
||||
self._recv_map.clear()
|
||||
self._xfer_entries.clear()
|
||||
self._expert_to_src_row = None
|
||||
self._layer_idx = None
|
||||
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
if self._registered_desc is not None:
|
||||
self._nixl_wrapper.deregister_memory(self._registered_desc)
|
||||
self._registered_desc = None
|
||||
with contextlib.suppress(Exception):
|
||||
for local_h, remote_h, xfer_h in self._xfer_entries:
|
||||
with contextlib.suppress(Exception):
|
||||
self._nixl_wrapper.release_xfer_handle(xfer_h)
|
||||
with contextlib.suppress(Exception):
|
||||
self._nixl_wrapper.release_dlist_handle(local_h)
|
||||
with contextlib.suppress(Exception):
|
||||
self._nixl_wrapper.release_dlist_handle(remote_h)
|
||||
with contextlib.suppress(Exception):
|
||||
for descs in self._registered_descs:
|
||||
with contextlib.suppress(Exception):
|
||||
self._nixl_wrapper.deregister_memory(descs)
|
||||
self._registered_descs.clear()
|
||||
with contextlib.suppress(Exception):
|
||||
for agent_name in self._remote_agents.values():
|
||||
self._nixl_wrapper.remove_remote_agent(agent_name)
|
||||
with contextlib.suppress(Exception):
|
||||
self._nixl_wrapper.remove_remote_agent(agent_name)
|
||||
self._remote_agents.clear()
|
||||
except Exception as e:
|
||||
logger.warning("Error during NixlEplbCommunicator cleanup: %s", e)
|
||||
|
||||
|
||||
class PyNcclEplbCommunicator(EplbCommunicator):
|
||||
@@ -600,7 +609,7 @@ class PyNcclEplbCommunicator(EplbCommunicator):
|
||||
for tensor in tensors:
|
||||
self._pynccl_comm.recv(tensor, src_rank, stream=self._cuda_stream)
|
||||
|
||||
def execute(self, old_indices: np.ndarray | None = None) -> None:
|
||||
def execute(self) -> None:
|
||||
if self._group_started:
|
||||
self._pynccl_comm.group_end()
|
||||
self._group_started = False
|
||||
@@ -609,7 +618,8 @@ class PyNcclEplbCommunicator(EplbCommunicator):
|
||||
def create_eplb_communicator(
|
||||
group_coordinator: GroupCoordinator,
|
||||
backend: str | None,
|
||||
expert_weights: Sequence[torch.Tensor],
|
||||
expert_weights: Sequence[Sequence[torch.Tensor]],
|
||||
expert_buffer: Sequence[torch.Tensor],
|
||||
) -> EplbCommunicator:
|
||||
"""Create an EPLB communicator for the given backend.
|
||||
|
||||
@@ -624,16 +634,18 @@ def create_eplb_communicator(
|
||||
``"pynccl"`` in that case. When tensors reside on CPU,
|
||||
``"torch_gloo"`` or ``"torch_nccl"`` are used via the CPU
|
||||
process group.
|
||||
expert_weights: Expert weight tensors from *one* MoE layer.
|
||||
NixlEplbCommunicator pre-allocates send/recv buffers sized
|
||||
to this layer, so all other MoE layers must have the same
|
||||
tensor count, shapes, and dtypes.
|
||||
expert_weights: Expert weight tensors for *all* MoE layers.
|
||||
Shape ``(num_layers)(num_tensors_per_layer)``.
|
||||
NixlEplbCommunicator registers all layers with NIXL for
|
||||
zero-copy RDMA reads.
|
||||
expert_buffer: Pre-allocated receive buffer tensors (one per
|
||||
weight tensor in a single layer).
|
||||
"""
|
||||
# Keep a safe default for callers that have not resolved communicator yet.
|
||||
if backend is None:
|
||||
backend = "torch_nccl"
|
||||
|
||||
tensor_device_type = expert_weights[0].device.type if expert_weights else "cpu"
|
||||
first_layer = expert_weights[0] if expert_weights else []
|
||||
tensor_device_type = first_layer[0].device.type if first_layer else "cpu"
|
||||
torch_group = (
|
||||
group_coordinator.cpu_group
|
||||
if tensor_device_type == "cpu"
|
||||
@@ -649,7 +661,7 @@ def create_eplb_communicator(
|
||||
unsupported_dtypes = sorted(
|
||||
{
|
||||
tensor.dtype
|
||||
for tensor in expert_weights
|
||||
for tensor in first_layer
|
||||
if not ncclDataTypeEnum.supports_torch_dtype(tensor.dtype)
|
||||
},
|
||||
key=str,
|
||||
@@ -704,7 +716,8 @@ def create_eplb_communicator(
|
||||
try:
|
||||
return NixlEplbCommunicator(
|
||||
cpu_group=group_coordinator.cpu_group,
|
||||
expert_weights=expert_weights,
|
||||
all_expert_weights=expert_weights,
|
||||
expert_buffer=expert_buffer,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise RuntimeError(
|
||||
|
||||
@@ -450,7 +450,8 @@ class EplbState:
|
||||
communicator = create_eplb_communicator(
|
||||
group_coordinator=get_eplb_group(),
|
||||
backend=self.parallel_config.eplb_config.communicator,
|
||||
expert_weights=model.expert_weights[0],
|
||||
expert_weights=model.expert_weights,
|
||||
expert_buffer=expert_buffer,
|
||||
)
|
||||
|
||||
model_state = EplbModelState(
|
||||
@@ -766,6 +767,7 @@ class EplbState:
|
||||
eplb_model_state.physical_to_logical_map,
|
||||
new_physical_to_logical_map,
|
||||
eplb_model_state.model.expert_weights,
|
||||
eplb_model_state.expert_buffer,
|
||||
ep_group,
|
||||
eplb_model_state.communicator,
|
||||
is_profile,
|
||||
|
||||
@@ -178,6 +178,7 @@ def move_to_buffer(
|
||||
cuda_stream: torch.cuda.Stream | None,
|
||||
ep_rank: int,
|
||||
communicator: EplbCommunicator,
|
||||
layer_idx: int = 0,
|
||||
) -> TransferMetadata:
|
||||
"""
|
||||
Rearranges expert weights during EPLB rebalancing.
|
||||
@@ -193,6 +194,7 @@ def move_to_buffer(
|
||||
cuda_stream: CUDA stream for async copies (can be None for sync mode).
|
||||
ep_rank: Rank of this process in expert parallel group.
|
||||
communicator: EplbCommunicator instance for P2P communication.
|
||||
layer_idx: Index of the MoE layer being transferred.
|
||||
|
||||
Returns:
|
||||
TransferMetadata: Metadata needed for completing remote weight transfers.
|
||||
@@ -265,6 +267,8 @@ def move_to_buffer(
|
||||
for w, b in zip(expert_weights, expert_weights_buffers):
|
||||
b[dst].copy_(w[src_local], non_blocking=True)
|
||||
|
||||
communicator.set_transfer_context(old_indices, layer_idx)
|
||||
|
||||
# 2. Post sends
|
||||
if send_count > 0:
|
||||
experts = send_expert_ids[:send_count]
|
||||
@@ -331,9 +335,8 @@ def move_to_buffer(
|
||||
expert_id=int(expert),
|
||||
)
|
||||
|
||||
# 4. Execute the P2P operations. The real communication happens here.
|
||||
communicator.execute(old_indices=old_indices)
|
||||
# wait for the communication to finish
|
||||
# 4. Execute transfers and wait for completion.
|
||||
communicator.execute()
|
||||
return TransferMetadata(
|
||||
is_unchanged=is_unchanged,
|
||||
is_received_locally=is_received_locally,
|
||||
@@ -431,6 +434,7 @@ def transfer_layer(
|
||||
is_profile: bool = False,
|
||||
cuda_stream: torch.cuda.Stream | None = None,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
layer_idx: int = 0,
|
||||
) -> TransferMetadata:
|
||||
"""
|
||||
Rearranges the expert weights in place according to the new expert indices.
|
||||
@@ -452,6 +456,7 @@ def transfer_layer(
|
||||
communications to reserve enough memory for the buffers.
|
||||
cuda_stream: CUDA stream for async copies (can be None for sync mode).
|
||||
rank_mapping: Optional rank mapping for elastic expert parallelism.
|
||||
layer_idx: Index of the MoE layer being transferred.
|
||||
|
||||
Returns:
|
||||
TransferMetadata: Metadata needed for completing remote weight transfers,
|
||||
@@ -499,6 +504,7 @@ def transfer_layer(
|
||||
cuda_stream=cuda_stream,
|
||||
ep_rank=ep_group.rank(),
|
||||
communicator=communicator,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
|
||||
|
||||
@@ -506,6 +512,7 @@ def rearrange_expert_weights_inplace(
|
||||
old_global_expert_indices: torch.Tensor,
|
||||
new_global_expert_indices: torch.Tensor,
|
||||
expert_weights: Sequence[Sequence[torch.Tensor]],
|
||||
expert_buffer: Sequence[torch.Tensor],
|
||||
ep_group: ProcessGroup,
|
||||
communicator: EplbCommunicator,
|
||||
is_profile: bool = False,
|
||||
@@ -524,6 +531,8 @@ def rearrange_expert_weights_inplace(
|
||||
of tensors of shape (num_local_physical_experts, hidden_size_i).
|
||||
For example, a linear layer may have up and down projection,
|
||||
so weight_count = 2. Each weight's hidden size can be different.
|
||||
expert_buffer: Pre-allocated receive buffer tensors (one per
|
||||
weight tensor in a single layer).
|
||||
ep_group: The device process group for expert parallelism.
|
||||
communicator: EplbCommunicator instance for P2P communication.
|
||||
is_profile (bool): If `True`, do not perform any actual weight copy.
|
||||
@@ -566,10 +575,10 @@ def rearrange_expert_weights_inplace(
|
||||
# Reserve NCCL communication buffers via a dummy all_gather.
|
||||
# Backends that pre-allocate their own transfer buffers
|
||||
# skip this to avoid the extra memory spike during profiling.
|
||||
weights_buffer: list[torch.Tensor] = [
|
||||
profile_buffer: list[torch.Tensor] = [
|
||||
torch.empty_like(w) for w in first_layer_weights
|
||||
]
|
||||
for weight, buffer in zip(expert_weights[0], weights_buffer):
|
||||
for weight, buffer in zip(expert_weights[0], profile_buffer):
|
||||
dummy_recv_buffer = [buffer for _ in range(ep_size)]
|
||||
torch.distributed.barrier()
|
||||
all_gather(
|
||||
@@ -579,10 +588,7 @@ def rearrange_expert_weights_inplace(
|
||||
)
|
||||
return
|
||||
|
||||
# Buffers to hold the expert weights during the exchange.
|
||||
# NOTE: Currently we assume the same weights across different layers
|
||||
# have the same shape.
|
||||
weights_buffer = [torch.empty_like(w) for w in first_layer_weights]
|
||||
weights_buffer = list(expert_buffer)
|
||||
|
||||
old_global_expert_indices_cpu = old_global_expert_indices.cpu().numpy()
|
||||
new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy()
|
||||
@@ -597,6 +603,7 @@ def rearrange_expert_weights_inplace(
|
||||
cuda_stream=None,
|
||||
ep_rank=ep_rank,
|
||||
communicator=communicator,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
|
||||
move_from_buffer(
|
||||
|
||||
@@ -64,6 +64,20 @@ def divide(numerator, denominator):
|
||||
return numerator // denominator
|
||||
|
||||
|
||||
def is_weak_contiguous(inp: torch.Tensor) -> bool:
|
||||
"""Check that *inp* occupies a single contiguous block of memory.
|
||||
|
||||
Unlike ``torch.Tensor.is_contiguous()``, this also accepts tensors
|
||||
whose strides are not strictly C-contiguous (e.g. column-major) as
|
||||
long as the underlying storage from the tensor's offset onward is
|
||||
exactly ``numel * element_size`` bytes.
|
||||
"""
|
||||
return inp.is_contiguous() or (
|
||||
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
|
||||
== inp.numel() * inp.element_size()
|
||||
)
|
||||
|
||||
|
||||
def split_tensor_along_last_dim(
|
||||
tensor: torch.Tensor,
|
||||
num_partitions: int,
|
||||
|
||||
Reference in New Issue
Block a user