From b0ed553028df09178f285f41a4d88ba1e5d0287d Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 3 Jun 2026 15:46:10 -0400 Subject: [PATCH] [Distributed] Add UCX one-shot AllReduce for DP metadata sync Replace Gloo TCP AllReduce with UCX tag-matching over IB RDMA for the per-iteration DP metadata synchronization in data-parallel mode. The one-shot allreduce sends each rank's full tensor to every other rank and locally sums, staying well within the RDMA eager threshold (no rendezvous handshake). This reduces P99 DP sync latency from ~100ms (Gloo/TCP) to ~0.1ms (UCX/RDMA) on InfiniBand clusters. Enabled by setting VLLM_DP_SYNC_BACKEND=ucx. Falls back to Gloo automatically if UCX is unavailable or initialization fails. New files: - ucx_dp_communicator.py: Python wrapper with auto-compile of C library - ucx_dp_sync.c: C-level UCX one-shot allreduce primitive - test_ucx_allreduce.py: Latency benchmark (Gloo vs UCX) - test_ucx_correctness.py: Multi-pattern correctness tests Modified files: - vllm/v1/worker/dp_utils.py: UCX path for V1 model runner - vllm/v1/worker/gpu/dp_utils.py: UCX path for V2 model runner Co-authored-by: Claude Signed-off-by: Tyler Michael Smith --- .../test_ucx_allreduce.py | 143 ++++++++++ .../test_ucx_correctness.py | 200 +++++++++++++ .../ucx_dp_communicator.py | 263 ++++++++++++++++++ .../device_communicators/ucx_dp_sync.c | 251 +++++++++++++++++ vllm/v1/worker/dp_utils.py | 76 ++++- vllm/v1/worker/gpu/dp_utils.py | 61 +++- 6 files changed, 989 insertions(+), 5 deletions(-) create mode 100644 vllm/distributed/device_communicators/test_ucx_allreduce.py create mode 100644 vllm/distributed/device_communicators/test_ucx_correctness.py create mode 100644 vllm/distributed/device_communicators/ucx_dp_communicator.py create mode 100644 vllm/distributed/device_communicators/ucx_dp_sync.c diff --git a/vllm/distributed/device_communicators/test_ucx_allreduce.py b/vllm/distributed/device_communicators/test_ucx_allreduce.py new file mode 100644 index 00000000000..51c1b5f34ed --- /dev/null +++ b/vllm/distributed/device_communicators/test_ucx_allreduce.py @@ -0,0 +1,143 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Standalone test for UCX DP allreduce. + +Single-node quick test with 4 ranks: + torchrun --nproc-per-node=4 \ + vllm/distributed/device_communicators/test_ucx_allreduce.py + +Cross-node test (4 nodes x 4 ranks): + torchrun --nproc-per-node=4 --nnodes=4 \ + --master-addr=$MASTER_ADDR --master-port=29500 \ + --node-rank=$NODE_RANK \ + vllm/distributed/device_communicators/test_ucx_allreduce.py +""" + +import os +import sys +import time + +import torch +import torch.distributed as dist + + +def test_gloo_baseline(rank, world_size, group, n_iters=100): + """Measure Gloo TCP allreduce latency.""" + tensor = torch.zeros(4, world_size, dtype=torch.int32) + # warmup + for _ in range(10): + tensor.zero_() + tensor[0][rank] = 1 + dist.all_reduce(tensor, group=group) + + latencies = [] + for _ in range(n_iters): + tensor.zero_() + tensor[0][rank] = rank + 1 + tensor[1][rank] = (rank + 1) * 10 + tensor[2][rank] = 1 + tensor[3][rank] = 2 + t0 = time.monotonic() + dist.all_reduce(tensor, group=group) + latencies.append(time.monotonic() - t0) + + for i in range(world_size): + assert tensor[0][i].item() == i + 1, f"Gloo: rank {i} col 0 wrong" + assert tensor[1][i].item() == (i + 1) * 10, f"Gloo: rank {i} col 1 wrong" + + return latencies + + +def test_ucx_allreduce(rank, world_size, gloo_group, n_iters=100): + """Measure UCX RDMA allreduce latency.""" + here = os.path.dirname(os.path.abspath(__file__)) + sys.path.insert(0, here) + from ucx_dp_communicator import UCXDPCommunicator + + comm = UCXDPCommunicator(rank, world_size, max_msg_bytes=1024) + comm.bootstrap(gloo_group) + print(f"[rank {rank}] UCX communicator initialized") + + tensor = torch.zeros(4, world_size, dtype=torch.int32) + # warmup + for _ in range(10): + tensor.zero_() + tensor[0][rank] = 1 + comm.allreduce_inplace(tensor) + + latencies = [] + for _ in range(n_iters): + tensor.zero_() + tensor[0][rank] = rank + 1 + tensor[1][rank] = (rank + 1) * 10 + tensor[2][rank] = 1 + tensor[3][rank] = 2 + t0 = time.monotonic() + comm.allreduce_inplace(tensor) + latencies.append(time.monotonic() - t0) + + for i in range(world_size): + assert tensor[0][i].item() == i + 1, ( + f"UCX: rank {i} col 0 = {tensor[0][i].item()}, expected {i + 1}" + ) + assert tensor[1][i].item() == (i + 1) * 10, ( + f"UCX: rank {i} col 1 = {tensor[1][i].item()}, expected {(i + 1) * 10}" + ) + + comm.finalize() + return latencies + + +def percentile(data, p): + data = sorted(data) + idx = int(len(data) * p / 100) + return data[min(idx, len(data) - 1)] + + +def print_stats(name, latencies, rank): + if latencies is None: + print(f"[rank {rank}] {name}: SKIPPED") + return + us = [t * 1e6 for t in latencies] + print( + f"[rank {rank}] {name}: " + f"p50={percentile(us, 50):.1f}us " + f"p95={percentile(us, 95):.1f}us " + f"p99={percentile(us, 99):.1f}us " + f"mean={sum(us) / len(us):.1f}us " + f"min={min(us):.1f}us " + f"max={max(us):.1f}us" + ) + + +def main(): + dist.init_process_group(backend="gloo") + rank = dist.get_rank() + world_size = dist.get_world_size() + gloo_group = dist.group.WORLD + + print(f"[rank {rank}] world_size={world_size}") + + n_iters = int(os.environ.get("TEST_ITERS", "200")) + + gloo_latencies = test_gloo_baseline(rank, world_size, gloo_group, n_iters) + print_stats("Gloo TCP", gloo_latencies, rank) + + ucx_latencies = test_ucx_allreduce(rank, world_size, gloo_group, n_iters) + print_stats("UCX RDMA", ucx_latencies, rank) + + if ucx_latencies and rank == 0: + gloo_p50 = percentile(gloo_latencies, 50) * 1e6 + ucx_p50 = percentile(ucx_latencies, 50) * 1e6 + gloo_p99 = percentile(gloo_latencies, 99) * 1e6 + ucx_p99 = percentile(ucx_latencies, 99) * 1e6 + print("\n=== Speedup ===") + print(f" P50: {gloo_p50:.1f}us -> {ucx_p50:.1f}us ({gloo_p50 / ucx_p50:.1f}x)") + print(f" P99: {gloo_p99:.1f}us -> {ucx_p99:.1f}us ({gloo_p99 / ucx_p99:.1f}x)") + + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/vllm/distributed/device_communicators/test_ucx_correctness.py b/vllm/distributed/device_communicators/test_ucx_correctness.py new file mode 100644 index 00000000000..d5942482f22 --- /dev/null +++ b/vllm/distributed/device_communicators/test_ucx_correctness.py @@ -0,0 +1,200 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Correctness test for UCX allreduce — exercises cross-node paths. + +Tests multiple patterns: +1. Each rank writes its column (the actual DP sync pattern) +2. Each rank writes a unique value — verifies every peer's data arrives +3. Sequential rounds — verifies round counter stays in sync +4. Stress test with many rapid calls + +Single-node quick test: + torchrun --nproc-per-node=4 test_ucx_correctness.py + +Cross-node (4 nodes x 4 ranks): + torchrun --nproc-per-node=4 --nnodes=4 \ + --master-addr=$MASTER_ADDR --master-port=29500 \ + --node-rank=$NODE_RANK test_ucx_correctness.py +""" + +import os +import sys +import time + +import torch +import torch.distributed as dist + + +def load_ucx_communicator(rank, world_size, gloo_group): + here = os.path.dirname(os.path.abspath(__file__)) + sys.path.insert(0, here) + from ucx_dp_communicator import UCXDPCommunicator + + comm = UCXDPCommunicator(rank, world_size, max_msg_bytes=1024) + comm.bootstrap(gloo_group) + return comm + + +def test_column_pattern(comm, rank, world_size, n_iters=50): + """The actual DP sync pattern: each rank fills its column.""" + failures = 0 + for it in range(n_iters): + tensor = torch.zeros(4, world_size, dtype=torch.int32) + tensor[0][rank] = rank + 1 + tensor[1][rank] = (rank + 1) * 10 + tensor[2][rank] = 1 + tensor[3][rank] = it % 3 + + comm.allreduce_inplace(tensor) + + for r in range(world_size): + expected = [r + 1, (r + 1) * 10, 1, it % 3] + actual = [tensor[row][r].item() for row in range(4)] + if actual != expected: + print( + f"[rank {rank}] FAIL iter={it} col={r}: " + f"expected {expected} got {actual}" + ) + failures += 1 + if failures > 5: + return failures + return failures + + +def test_unique_values(comm, rank, world_size, n_iters=50): + """Each rank writes a unique value in every cell of its + column.""" + failures = 0 + for it in range(n_iters): + tensor = torch.zeros(4, world_size, dtype=torch.int32) + val = rank * 1000 + it + for row in range(4): + tensor[row][rank] = val + row + + comm.allreduce_inplace(tensor) + + for r in range(world_size): + base = r * 1000 + it + for row in range(4): + expected = base + row + actual = tensor[row][r].item() + if actual != expected: + print( + f"[rank {rank}] FAIL iter={it} " + f"row={row} col={r}: " + f"expected {expected} got {actual}" + ) + failures += 1 + if failures > 5: + return failures + return failures + + +def test_gloo_comparison(comm, rank, world_size, gloo_group, n_iters=50): + """Run same allreduce via both UCX and Gloo, compare.""" + failures = 0 + for it in range(n_iters): + tensor_ucx = torch.zeros(4, world_size, dtype=torch.int32) + tensor_gloo = torch.zeros(4, world_size, dtype=torch.int32) + + val = rank * 100 + it + for row in range(4): + tensor_ucx[row][rank] = val + row + tensor_gloo[row][rank] = val + row + + comm.allreduce_inplace(tensor_ucx) + dist.all_reduce(tensor_gloo, group=gloo_group) + + if not torch.equal(tensor_ucx, tensor_gloo): + diff_mask = tensor_ucx != tensor_gloo + print( + f"[rank {rank}] FAIL iter={it}: " + f"UCX != Gloo\n" + f" UCX: {tensor_ucx.tolist()}\n" + f" Gloo: {tensor_gloo.tolist()}\n" + f" Diff: {diff_mask.nonzero().tolist()}" + ) + failures += 1 + if failures > 5: + return failures + return failures + + +def test_rapid_fire(comm, rank, world_size, n_iters=200): + """Many rapid consecutive allreduces — stress the round + counter.""" + failures = 0 + for it in range(n_iters): + tensor = torch.zeros(4, world_size, dtype=torch.int32) + tensor[0][rank] = it + 1 + + comm.allreduce_inplace(tensor) + + total = tensor[0].sum().item() + expected = (it + 1) * world_size + if total != expected: + print( + f"[rank {rank}] FAIL iter={it}: " + f"sum={total} expected={expected} " + f"row0={tensor[0].tolist()}" + ) + failures += 1 + if failures > 5: + return failures + return failures + + +def main(): + dist.init_process_group(backend="gloo") + rank = dist.get_rank() + world_size = dist.get_world_size() + gloo_group = dist.group.WORLD + + print(f"[rank {rank}] world_size={world_size}") + + comm = load_ucx_communicator(rank, world_size, gloo_group) + print(f"[rank {rank}] UCX communicator ready") + + tests = [ + ( + "column_pattern", + lambda: test_column_pattern(comm, rank, world_size), + ), + ( + "unique_values", + lambda: test_unique_values(comm, rank, world_size), + ), + ( + "gloo_comparison", + lambda: test_gloo_comparison(comm, rank, world_size, gloo_group), + ), + ( + "rapid_fire", + lambda: test_rapid_fire(comm, rank, world_size), + ), + ] + + all_pass = True + for name, fn in tests: + dist.barrier(group=gloo_group) + t0 = time.monotonic() + failures = fn() + elapsed = time.monotonic() - t0 + status = "PASS" if failures == 0 else f"FAIL ({failures} failures)" + print(f"[rank {rank}] {name}: {status} ({elapsed:.2f}s)") + if failures > 0: + all_pass = False + + comm.finalize() + dist.barrier(group=gloo_group) + + if rank == 0: + result = "ALL TESTS PASSED" if all_pass else "SOME TESTS FAILED" + print(f"\n{result}") + + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/vllm/distributed/device_communicators/ucx_dp_communicator.py b/vllm/distributed/device_communicators/ucx_dp_communicator.py new file mode 100644 index 00000000000..7424ba39ba8 --- /dev/null +++ b/vllm/distributed/device_communicators/ucx_dp_communicator.py @@ -0,0 +1,263 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +UCX-based one-shot AllReduce for DP metadata synchronization. + +Replaces Gloo's TCP AllReduce with UCX tag-matching over IB RDMA. +Falls back to Gloo if UCX is unavailable or init fails. + +Usage: + from vllm.distributed.device_communicators.ucx_dp_communicator import ( + try_init_ucx_dp, + get_ucx_dp_communicator, + ) + + # Once, after DP group is created: + try_init_ucx_dp(dp_rank, dp_size, gloo_group) + + # Per-iteration: + comm = get_ucx_dp_communicator() + if comm is not None: + comm.allreduce_inplace(tensor) # UCX/RDMA path + else: + dist.all_reduce(tensor, group=gloo_group) # fallback +""" + +import ctypes +import ctypes.util +import logging +import os +import subprocess +import threading + +import torch +import torch.distributed as dist + +logger = logging.getLogger(__name__) + +_communicator: "UCXDPCommunicator | None" = None +_init_lock = threading.Lock() + + +def get_ucx_dp_communicator() -> "UCXDPCommunicator | None": + return _communicator + + +def try_init_ucx_dp( + rank: int, + world_size: int, + gloo_group, + max_msg_bytes: int = 1024, +) -> bool: + global _communicator + with _init_lock: + if _communicator is not None: + return True + try: + comm = UCXDPCommunicator(rank, world_size, max_msg_bytes) + comm.bootstrap(gloo_group) + _communicator = comm + logger.info( + "UCX DP communicator ready (rank %d/%d, RDMA allreduce)", + rank, + world_size, + ) + return True + except Exception: + logger.warning( + "UCX DP communicator unavailable, using Gloo fallback", + exc_info=True, + ) + return False + + +class UCXDPCommunicator: + def __init__( + self, + rank: int, + world_size: int, + max_msg_bytes: int = 1024, + ): + self.rank = rank + self.world_size = world_size + self._lib = _load_library() + self._state = ctypes.c_void_p() + + addr_ptr = ctypes.c_void_p() + addr_len = ctypes.c_size_t() + rc = self._lib.ucx_dp_init( + rank, + world_size, + max_msg_bytes, + ctypes.byref(self._state), + ctypes.byref(addr_ptr), + ctypes.byref(addr_len), + ) + if rc != 0: + raise RuntimeError("ucx_dp_init failed") + + assert addr_ptr.value is not None + self._address = ctypes.string_at(addr_ptr.value, addr_len.value) + self._lib.ucx_dp_release_address(self._state, addr_ptr) + + def bootstrap(self, gloo_group) -> None: + """Exchange UCX worker addresses via existing Gloo group.""" + all_addrs: list = [None] * self.world_size + dist.all_gather_object(all_addrs, self._address, group=gloo_group) + + for i, addr in enumerate(all_addrs): + if i == self.rank: + continue + rc = self._lib.ucx_dp_connect( + self._state, + i, + addr, + len(addr), + ) + if rc != 0: + raise RuntimeError(f"ucx_dp_connect to rank {i} failed") + + def allreduce_inplace(self, tensor: torch.Tensor) -> None: + """In-place sum allreduce of a contiguous CPU int32 + tensor.""" + assert tensor.is_contiguous() and tensor.device.type == "cpu" + nbytes = tensor.nelement() * tensor.element_size() + rc = self._lib.ucx_dp_allreduce_inplace( + self._state, + ctypes.c_void_p(tensor.data_ptr()), + nbytes, + ) + if rc != 0: + raise RuntimeError("ucx_dp_allreduce_inplace failed") + + def finalize(self) -> None: + if self._state: + self._lib.ucx_dp_finalize(self._state) + self._state = ctypes.c_void_p() + + def __del__(self): + self.finalize() + + +# ---- library loading ---- + +_lib_cache: ctypes.CDLL | None = None + + +def _load_library() -> ctypes.CDLL: + global _lib_cache + if _lib_cache is not None: + return _lib_cache + + so_path = _find_or_compile() + if so_path is None: + raise RuntimeError( + "Cannot find or compile _ucx_dp_sync.so. " + "Place ucx_dp_sync.c next to this file, or set " + "VLLM_UCX_DP_LIB to the .so path." + ) + + lib = ctypes.CDLL(so_path) + _setup_signatures(lib) + _lib_cache = lib + return lib + + +def _setup_signatures(lib: ctypes.CDLL) -> None: + c_int = ctypes.c_int + c_size_t = ctypes.c_size_t + c_void_p = ctypes.c_void_p + POINTER = ctypes.POINTER + + lib.ucx_dp_init.restype = c_int + lib.ucx_dp_init.argtypes = [ + c_int, + c_int, + c_size_t, + POINTER(c_void_p), + POINTER(c_void_p), + POINTER(c_size_t), + ] + + lib.ucx_dp_release_address.restype = None + lib.ucx_dp_release_address.argtypes = [c_void_p, c_void_p] + + lib.ucx_dp_connect.restype = c_int + lib.ucx_dp_connect.argtypes = [ + c_void_p, + c_int, + c_void_p, + c_size_t, + ] + + lib.ucx_dp_allreduce_inplace.restype = c_int + lib.ucx_dp_allreduce_inplace.argtypes = [ + c_void_p, + c_void_p, + c_size_t, + ] + + lib.ucx_dp_finalize.restype = None + lib.ucx_dp_finalize.argtypes = [c_void_p] + + +def _find_or_compile() -> str | None: + env_path = os.environ.get("VLLM_UCX_DP_LIB") + if env_path and os.path.isfile(env_path): + return env_path + + here = os.path.dirname(os.path.abspath(__file__)) + so_path = os.path.join(here, "_ucx_dp_sync.so") + if os.path.isfile(so_path): + return so_path + + c_path = os.path.join(here, "ucx_dp_sync.c") + if not os.path.isfile(c_path): + return None + + if not _has_ucp(): + logger.warning("libucp.so not found — cannot compile UCX DP sync") + return None + + logger.info("Compiling _ucx_dp_sync.so …") + cflags: list[str] = [] + ldflags: list[str] = [] + try: + pc = subprocess.run( + ["pkg-config", "--cflags", "--libs", "ucx"], + capture_output=True, + text=True, + timeout=5, + ) + if pc.returncode == 0: + for tok in pc.stdout.strip().split(): + if tok.startswith(("-I", "-D")): + cflags.append(tok) + else: + ldflags.append(tok) + except Exception: + pass + + cmd = [ + "gcc", + "-shared", + "-fPIC", + "-O2", + *cflags, + "-o", + so_path, + c_path, + "-lucp", + "-lucs", + *ldflags, + ] + result = subprocess.run(cmd, capture_output=True, text=True, timeout=30) + if result.returncode != 0: + logger.warning("Compile failed: %s", result.stderr) + return None + + return so_path + + +def _has_ucp() -> bool: + return ctypes.util.find_library("ucp") is not None diff --git a/vllm/distributed/device_communicators/ucx_dp_sync.c b/vllm/distributed/device_communicators/ucx_dp_sync.c new file mode 100644 index 00000000000..00a444ca377 --- /dev/null +++ b/vllm/distributed/device_communicators/ucx_dp_sync.c @@ -0,0 +1,251 @@ +/* + * ucx_dp_sync.c — One-shot AllReduce for vLLM DP metadata over UCX/RDMA + * + * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright contributors to the vLLM project + * + * Replaces Gloo TCP AllReduce (~100ms P99) with UCX tag-matching over + * InfiniBand RDMA (~0.1ms P99) for the per-iteration DP metadata sync. + * + * Build: + * gcc -shared -fPIC -O2 -o _ucx_dp_sync.so ucx_dp_sync.c -lucp -lucs + * + * The allreduce is one-shot: every rank sends its full tensor to every + * other rank, then locally sums. For 256 bytes x 15 peers this is 3.8KB + * total — well within RDMA eager threshold, so all sends complete in a + * single round with no rendezvous handshake. + */ + +#include +#include +#include +#include +#include +#include + +#define MAX_RANKS 256 + +typedef struct { + ucp_context_h ctx; + ucp_worker_h worker; + ucp_ep_h *eps; /* [world_size], NULL for self */ + int rank; + int world_size; + uint64_t round; /* monotonic counter for tag uniqueness */ + uint8_t **recv_bufs; /* [world_size] pre-allocated */ + uint8_t *send_staging; + size_t max_bytes; +} ucx_dp_state_t; + +/* ---- request completion ---- */ + +static void req_init(void *request) { + *(int *)request = 0; +} + +static void send_cb(void *request, ucs_status_t status, void *user_data) { + *(int *)request = 1; +} + +static void recv_cb(void *request, ucs_status_t status, + const ucp_tag_recv_info_t *info, void *user_data) { + *(int *)request = 1; +} + +/* ---- public API ---- */ + +int ucx_dp_init(int rank, int world_size, size_t max_bytes, + void **state_out, void **addr_out, size_t *addr_len_out) { + ucs_status_t st; + + if (world_size > MAX_RANKS) return -1; + + ucx_dp_state_t *s = (ucx_dp_state_t *)calloc(1, sizeof(*s)); + if (!s) return -1; + s->rank = rank; + s->world_size = world_size; + s->max_bytes = max_bytes; + + /* context */ + ucp_config_t *config; + st = ucp_config_read(NULL, NULL, &config); + if (st != UCS_OK) goto fail_s; + + ucp_params_t params; + memset(¶ms, 0, sizeof(params)); + params.field_mask = UCP_PARAM_FIELD_FEATURES + | UCP_PARAM_FIELD_REQUEST_SIZE + | UCP_PARAM_FIELD_REQUEST_INIT; + params.features = UCP_FEATURE_TAG; + params.request_size = sizeof(int); + params.request_init = req_init; + + st = ucp_init(¶ms, config, &s->ctx); + ucp_config_release(config); + if (st != UCS_OK) goto fail_s; + + /* worker */ + ucp_worker_params_t wp; + memset(&wp, 0, sizeof(wp)); + wp.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; + wp.thread_mode = UCS_THREAD_MODE_SINGLE; + + st = ucp_worker_create(s->ctx, &wp, &s->worker); + if (st != UCS_OK) goto fail_ctx; + + /* worker address */ + ucp_address_t *addr; + size_t addr_len; + st = ucp_worker_get_address(s->worker, &addr, &addr_len); + if (st != UCS_OK) goto fail_worker; + + /* buffers */ + s->eps = (ucp_ep_h *)calloc(world_size, sizeof(ucp_ep_h)); + s->recv_bufs = (uint8_t **)calloc(world_size, sizeof(uint8_t *)); + s->send_staging = (uint8_t *)malloc(max_bytes); + for (int i = 0; i < world_size; i++) { + if (i != rank) + s->recv_bufs[i] = (uint8_t *)malloc(max_bytes); + } + + *state_out = s; + *addr_out = addr; /* caller copies, then calls release */ + *addr_len_out = addr_len; + return 0; + +fail_worker: ucp_worker_destroy(s->worker); +fail_ctx: ucp_cleanup(s->ctx); +fail_s: free(s); + return -1; +} + +void ucx_dp_release_address(void *state, void *addr) { + ucx_dp_state_t *s = (ucx_dp_state_t *)state; + ucp_worker_release_address(s->worker, (ucp_address_t *)addr); +} + +int ucx_dp_connect(void *state, int peer, const void *addr, size_t len) { + ucx_dp_state_t *s = (ucx_dp_state_t *)state; + + ucp_ep_params_t ep; + memset(&ep, 0, sizeof(ep)); + ep.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS; + ep.address = (const ucp_address_t *)addr; + + ucs_status_t st = ucp_ep_create(s->worker, &ep, &s->eps[peer]); + return (st == UCS_OK) ? 0 : -1; +} + +int ucx_dp_allreduce_inplace(void *state, void *buf, size_t nbytes) { + ucx_dp_state_t *s = (ucx_dp_state_t *)state; + if (nbytes > s->max_bytes) return -1; + + uint64_t round = s->round++; + int ws = s->world_size; + + memcpy(s->send_staging, buf, nbytes); + + ucs_status_ptr_t recv_reqs[MAX_RANKS]; + ucs_status_ptr_t send_reqs[MAX_RANKS]; + + /* post all receives */ + for (int i = 0; i < ws; i++) { + if (i == s->rank) { recv_reqs[i] = NULL; continue; } + + ucp_request_param_t p; + memset(&p, 0, sizeof(p)); + p.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK + | UCP_OP_ATTR_FIELD_FLAGS; + p.cb.recv = recv_cb; + p.flags = UCP_OP_ATTR_FLAG_NO_IMM_CMPL; + + ucp_tag_t tag = (round << 16) | (uint64_t)i; + recv_reqs[i] = ucp_tag_recv_nbx(s->worker, s->recv_bufs[i], + nbytes, tag, ~(ucp_tag_t)0, &p); + } + + /* post all sends */ + ucp_tag_t my_tag = (round << 16) | (uint64_t)s->rank; + for (int i = 0; i < ws; i++) { + if (i == s->rank) { send_reqs[i] = NULL; continue; } + + ucp_request_param_t p; + memset(&p, 0, sizeof(p)); + p.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK; + p.cb.send = send_cb; + + send_reqs[i] = ucp_tag_send_nbx(s->eps[i], s->send_staging, + nbytes, my_tag, &p); + } + + /* progress until every request completes */ + for (;;) { + ucp_worker_progress(s->worker); + int done = 1; + for (int i = 0; i < ws; i++) { + if (i == s->rank) continue; + if (recv_reqs[i] && !UCS_PTR_IS_ERR(recv_reqs[i])) { + if (ucp_request_check_status(recv_reqs[i]) == UCS_INPROGRESS) { + done = 0; break; + } + } + if (send_reqs[i] && !UCS_PTR_IS_ERR(send_reqs[i])) { + if (ucp_request_check_status(send_reqs[i]) == UCS_INPROGRESS) { + done = 0; break; + } + } + } + if (done) break; + } + + /* free requests */ + for (int i = 0; i < ws; i++) { + if (i == s->rank) continue; + if (recv_reqs[i] && !UCS_PTR_IS_ERR(recv_reqs[i])) + ucp_request_free(recv_reqs[i]); + if (send_reqs[i] && !UCS_PTR_IS_ERR(send_reqs[i])) + ucp_request_free(send_reqs[i]); + } + + /* + * Memory fence: ensure all recv buffer writes from UCX are + * visible before we read them for the reduction. On ARM + * (aarch64) the store from the transport thread and our read + * may not be ordered without this. + */ + __sync_synchronize(); + + /* local reduce: buf = local + sum(received) */ + int32_t *out = (int32_t *)buf; + int32_t *local = (int32_t *)s->send_staging; + int count = (int)(nbytes / sizeof(int32_t)); + memcpy(out, local, nbytes); + for (int i = 0; i < ws; i++) { + if (i == s->rank) continue; + int32_t *peer = (int32_t *)s->recv_bufs[i]; + for (int j = 0; j < count; j++) + out[j] += peer[j]; + } + + return 0; +} + +void ucx_dp_finalize(void *state) { + ucx_dp_state_t *s = (ucx_dp_state_t *)state; + if (!s) return; + + /* flush pending endpoint ops */ + for (int i = 0; i < 64; i++) + ucp_worker_progress(s->worker); + + if (s->worker) ucp_worker_destroy(s->worker); + if (s->ctx) ucp_cleanup(s->ctx); + + for (int i = 0; i < s->world_size; i++) { + if (s->recv_bufs && s->recv_bufs[i]) free(s->recv_bufs[i]); + } + free(s->eps); + free(s->recv_bufs); + free(s->send_staging); + free(s); +} diff --git a/vllm/v1/worker/dp_utils.py b/vllm/v1/worker/dp_utils.py index e7c6d81a992..63e568c788d 100644 --- a/vllm/v1/worker/dp_utils.py +++ b/vllm/v1/worker/dp_utils.py @@ -1,6 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import threading +import time + import torch import torch.distributed as dist @@ -14,6 +17,53 @@ from vllm.v1.worker.ubatch_utils import ( logger = init_logger(__name__) +_dp_sync_stats_lock = threading.Lock() +_dp_sync_stats: list[float] = [] + +_ucx_init_attempted = False + + +def _maybe_init_ucx(parallel_config: ParallelConfig) -> None: + """Lazily initialize the UCX DP communicator on first use. + + Gated on VLLM_DP_SYNC_BACKEND=ucx. Falls back to Gloo on + failure. + """ + global _ucx_init_attempted + if _ucx_init_attempted: + return + _ucx_init_attempted = True + + import os + + if os.environ.get("VLLM_DP_SYNC_BACKEND", "").lower() != "ucx": + return + + try: + from vllm.distributed.device_communicators.ucx_dp_communicator import ( + try_init_ucx_dp, + ) + + gloo_group = get_dp_group().cpu_group + try_init_ucx_dp( + rank=parallel_config.data_parallel_rank, + world_size=parallel_config.data_parallel_size, + gloo_group=gloo_group, + max_msg_bytes=1024, + ) + except Exception: + logger.warning("UCX DP init failed, using Gloo", exc_info=True) + + +def get_dp_sync_stats() -> list[float] | None: + """Return and clear the list of DP sync latencies.""" + with _dp_sync_stats_lock: + if not _dp_sync_stats: + return None + result = list(_dp_sync_stats) + _dp_sync_stats.clear() + return result + def _get_device_and_group(parallel_config: ParallelConfig): # Use the actual device assigned to the DP group, not just the device type @@ -40,17 +90,37 @@ def _run_ar( cudagraph_mode: int, parallel_config: ParallelConfig, ) -> torch.Tensor: + _maybe_init_ucx(parallel_config) + dp_size = parallel_config.data_parallel_size dp_rank = parallel_config.data_parallel_rank - device, group = _get_device_and_group(parallel_config) + # Populate this rank's contribution on CPU to reduce GPU syncs. tensor_cpu = torch.zeros(4, dp_size, dtype=torch.int32) tensor_cpu[0][dp_rank] = orig_num_tokens_per_ubatch tensor_cpu[1][dp_rank] = padded_num_tokens_per_ubatch tensor_cpu[2][dp_rank] = 1 if should_ubatch else 0 tensor_cpu[3][dp_rank] = cudagraph_mode - tensor = tensor_cpu.to(device, non_blocking=True) - dist.all_reduce(tensor, group=group) + + t0 = time.monotonic() + + from vllm.distributed.device_communicators.ucx_dp_communicator import ( + get_ucx_dp_communicator, + ) + + ucx = get_ucx_dp_communicator() + if ucx is not None: + ucx.allreduce_inplace(tensor_cpu) + tensor = tensor_cpu + else: + device, group = _get_device_and_group(parallel_config) + tensor = tensor_cpu.to(device, non_blocking=True) + dist.all_reduce(tensor, group=group) + + elapsed = time.monotonic() - t0 + with _dp_sync_stats_lock: + _dp_sync_stats.append(elapsed) + return tensor diff --git a/vllm/v1/worker/gpu/dp_utils.py b/vllm/v1/worker/gpu/dp_utils.py index b3c172738c3..6c1a14767ca 100644 --- a/vllm/v1/worker/gpu/dp_utils.py +++ b/vllm/v1/worker/gpu/dp_utils.py @@ -2,16 +2,55 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from __future__ import annotations +import os +import time + import torch import torch.distributed as dist from vllm.config.compilation import CUDAGraphMode from vllm.distributed.parallel_state import get_dp_group +from vllm.logger import init_logger +from vllm.v1.worker.dp_utils import _dp_sync_stats, _dp_sync_stats_lock from vllm.v1.worker.gpu.cudagraph_utils import ( BatchExecutionDescriptor, CudaGraphManager, ) +logger = init_logger(__name__) + +_ucx_init_attempted = False + + +def _maybe_init_ucx(dp_rank: int, dp_size: int) -> None: + """Lazily initialize the UCX DP communicator on first use. + + Gated on VLLM_DP_SYNC_BACKEND=ucx. Falls back to Gloo on + failure. + """ + global _ucx_init_attempted + if _ucx_init_attempted: + return + _ucx_init_attempted = True + + if os.environ.get("VLLM_DP_SYNC_BACKEND", "").lower() != "ucx": + return + + try: + from vllm.distributed.device_communicators.ucx_dp_communicator import ( + try_init_ucx_dp, + ) + + gloo_group = get_dp_group().cpu_group + try_init_ucx_dp( + rank=dp_rank, + world_size=dp_size, + gloo_group=gloo_group, + max_msg_bytes=1024, + ) + except Exception: + logger.warning("UCX DP init failed, using Gloo", exc_info=True) + def sync_cudagraph_and_dp_padding( cudagraph_manager: CudaGraphManager | None, @@ -28,12 +67,30 @@ def sync_cudagraph_and_dp_padding( Returns (synced_batch_desc, num_tokens_across_dp). """ assert dp_size > 1, "DP size must be greater than 1" - group = get_dp_group().cpu_group + + _maybe_init_ucx(dp_rank, dp_size) + tensor = torch.zeros(3, dp_size, dtype=torch.int32, device="cpu") tensor[0][dp_rank] = num_tokens tensor[1][dp_rank] = desired_batch_desc.cg_mode.value tensor[2][dp_rank] = uniform_token_count or 0 # (0 means None) - dist.all_reduce(tensor, group=group) + + t0 = time.monotonic() + + from vllm.distributed.device_communicators.ucx_dp_communicator import ( + get_ucx_dp_communicator, + ) + + ucx = get_ucx_dp_communicator() + if ucx is not None: + ucx.allreduce_inplace(tensor) + else: + group = get_dp_group().cpu_group + dist.all_reduce(tensor, group=group) + + elapsed = time.monotonic() - t0 + with _dp_sync_stats_lock: + _dp_sync_stats.append(elapsed) num_tokens_across_dp = tensor[0] cg_mode_across_dp = tensor[1]