diff --git a/vllm/distributed/device_communicators/test_ucx_allreduce.py b/vllm/distributed/device_communicators/test_ucx_allreduce.py new file mode 100644 index 00000000000..51c1b5f34ed --- /dev/null +++ b/vllm/distributed/device_communicators/test_ucx_allreduce.py @@ -0,0 +1,143 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Standalone test for UCX DP allreduce. + +Single-node quick test with 4 ranks: + torchrun --nproc-per-node=4 \ + vllm/distributed/device_communicators/test_ucx_allreduce.py + +Cross-node test (4 nodes x 4 ranks): + torchrun --nproc-per-node=4 --nnodes=4 \ + --master-addr=$MASTER_ADDR --master-port=29500 \ + --node-rank=$NODE_RANK \ + vllm/distributed/device_communicators/test_ucx_allreduce.py +""" + +import os +import sys +import time + +import torch +import torch.distributed as dist + + +def test_gloo_baseline(rank, world_size, group, n_iters=100): + """Measure Gloo TCP allreduce latency.""" + tensor = torch.zeros(4, world_size, dtype=torch.int32) + # warmup + for _ in range(10): + tensor.zero_() + tensor[0][rank] = 1 + dist.all_reduce(tensor, group=group) + + latencies = [] + for _ in range(n_iters): + tensor.zero_() + tensor[0][rank] = rank + 1 + tensor[1][rank] = (rank + 1) * 10 + tensor[2][rank] = 1 + tensor[3][rank] = 2 + t0 = time.monotonic() + dist.all_reduce(tensor, group=group) + latencies.append(time.monotonic() - t0) + + for i in range(world_size): + assert tensor[0][i].item() == i + 1, f"Gloo: rank {i} col 0 wrong" + assert tensor[1][i].item() == (i + 1) * 10, f"Gloo: rank {i} col 1 wrong" + + return latencies + + +def test_ucx_allreduce(rank, world_size, gloo_group, n_iters=100): + """Measure UCX RDMA allreduce latency.""" + here = os.path.dirname(os.path.abspath(__file__)) + sys.path.insert(0, here) + from ucx_dp_communicator import UCXDPCommunicator + + comm = UCXDPCommunicator(rank, world_size, max_msg_bytes=1024) + comm.bootstrap(gloo_group) + print(f"[rank {rank}] UCX communicator initialized") + + tensor = torch.zeros(4, world_size, dtype=torch.int32) + # warmup + for _ in range(10): + tensor.zero_() + tensor[0][rank] = 1 + comm.allreduce_inplace(tensor) + + latencies = [] + for _ in range(n_iters): + tensor.zero_() + tensor[0][rank] = rank + 1 + tensor[1][rank] = (rank + 1) * 10 + tensor[2][rank] = 1 + tensor[3][rank] = 2 + t0 = time.monotonic() + comm.allreduce_inplace(tensor) + latencies.append(time.monotonic() - t0) + + for i in range(world_size): + assert tensor[0][i].item() == i + 1, ( + f"UCX: rank {i} col 0 = {tensor[0][i].item()}, expected {i + 1}" + ) + assert tensor[1][i].item() == (i + 1) * 10, ( + f"UCX: rank {i} col 1 = {tensor[1][i].item()}, expected {(i + 1) * 10}" + ) + + comm.finalize() + return latencies + + +def percentile(data, p): + data = sorted(data) + idx = int(len(data) * p / 100) + return data[min(idx, len(data) - 1)] + + +def print_stats(name, latencies, rank): + if latencies is None: + print(f"[rank {rank}] {name}: SKIPPED") + return + us = [t * 1e6 for t in latencies] + print( + f"[rank {rank}] {name}: " + f"p50={percentile(us, 50):.1f}us " + f"p95={percentile(us, 95):.1f}us " + f"p99={percentile(us, 99):.1f}us " + f"mean={sum(us) / len(us):.1f}us " + f"min={min(us):.1f}us " + f"max={max(us):.1f}us" + ) + + +def main(): + dist.init_process_group(backend="gloo") + rank = dist.get_rank() + world_size = dist.get_world_size() + gloo_group = dist.group.WORLD + + print(f"[rank {rank}] world_size={world_size}") + + n_iters = int(os.environ.get("TEST_ITERS", "200")) + + gloo_latencies = test_gloo_baseline(rank, world_size, gloo_group, n_iters) + print_stats("Gloo TCP", gloo_latencies, rank) + + ucx_latencies = test_ucx_allreduce(rank, world_size, gloo_group, n_iters) + print_stats("UCX RDMA", ucx_latencies, rank) + + if ucx_latencies and rank == 0: + gloo_p50 = percentile(gloo_latencies, 50) * 1e6 + ucx_p50 = percentile(ucx_latencies, 50) * 1e6 + gloo_p99 = percentile(gloo_latencies, 99) * 1e6 + ucx_p99 = percentile(ucx_latencies, 99) * 1e6 + print("\n=== Speedup ===") + print(f" P50: {gloo_p50:.1f}us -> {ucx_p50:.1f}us ({gloo_p50 / ucx_p50:.1f}x)") + print(f" P99: {gloo_p99:.1f}us -> {ucx_p99:.1f}us ({gloo_p99 / ucx_p99:.1f}x)") + + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/vllm/distributed/device_communicators/test_ucx_correctness.py b/vllm/distributed/device_communicators/test_ucx_correctness.py new file mode 100644 index 00000000000..d5942482f22 --- /dev/null +++ b/vllm/distributed/device_communicators/test_ucx_correctness.py @@ -0,0 +1,200 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Correctness test for UCX allreduce — exercises cross-node paths. + +Tests multiple patterns: +1. Each rank writes its column (the actual DP sync pattern) +2. Each rank writes a unique value — verifies every peer's data arrives +3. Sequential rounds — verifies round counter stays in sync +4. Stress test with many rapid calls + +Single-node quick test: + torchrun --nproc-per-node=4 test_ucx_correctness.py + +Cross-node (4 nodes x 4 ranks): + torchrun --nproc-per-node=4 --nnodes=4 \ + --master-addr=$MASTER_ADDR --master-port=29500 \ + --node-rank=$NODE_RANK test_ucx_correctness.py +""" + +import os +import sys +import time + +import torch +import torch.distributed as dist + + +def load_ucx_communicator(rank, world_size, gloo_group): + here = os.path.dirname(os.path.abspath(__file__)) + sys.path.insert(0, here) + from ucx_dp_communicator import UCXDPCommunicator + + comm = UCXDPCommunicator(rank, world_size, max_msg_bytes=1024) + comm.bootstrap(gloo_group) + return comm + + +def test_column_pattern(comm, rank, world_size, n_iters=50): + """The actual DP sync pattern: each rank fills its column.""" + failures = 0 + for it in range(n_iters): + tensor = torch.zeros(4, world_size, dtype=torch.int32) + tensor[0][rank] = rank + 1 + tensor[1][rank] = (rank + 1) * 10 + tensor[2][rank] = 1 + tensor[3][rank] = it % 3 + + comm.allreduce_inplace(tensor) + + for r in range(world_size): + expected = [r + 1, (r + 1) * 10, 1, it % 3] + actual = [tensor[row][r].item() for row in range(4)] + if actual != expected: + print( + f"[rank {rank}] FAIL iter={it} col={r}: " + f"expected {expected} got {actual}" + ) + failures += 1 + if failures > 5: + return failures + return failures + + +def test_unique_values(comm, rank, world_size, n_iters=50): + """Each rank writes a unique value in every cell of its + column.""" + failures = 0 + for it in range(n_iters): + tensor = torch.zeros(4, world_size, dtype=torch.int32) + val = rank * 1000 + it + for row in range(4): + tensor[row][rank] = val + row + + comm.allreduce_inplace(tensor) + + for r in range(world_size): + base = r * 1000 + it + for row in range(4): + expected = base + row + actual = tensor[row][r].item() + if actual != expected: + print( + f"[rank {rank}] FAIL iter={it} " + f"row={row} col={r}: " + f"expected {expected} got {actual}" + ) + failures += 1 + if failures > 5: + return failures + return failures + + +def test_gloo_comparison(comm, rank, world_size, gloo_group, n_iters=50): + """Run same allreduce via both UCX and Gloo, compare.""" + failures = 0 + for it in range(n_iters): + tensor_ucx = torch.zeros(4, world_size, dtype=torch.int32) + tensor_gloo = torch.zeros(4, world_size, dtype=torch.int32) + + val = rank * 100 + it + for row in range(4): + tensor_ucx[row][rank] = val + row + tensor_gloo[row][rank] = val + row + + comm.allreduce_inplace(tensor_ucx) + dist.all_reduce(tensor_gloo, group=gloo_group) + + if not torch.equal(tensor_ucx, tensor_gloo): + diff_mask = tensor_ucx != tensor_gloo + print( + f"[rank {rank}] FAIL iter={it}: " + f"UCX != Gloo\n" + f" UCX: {tensor_ucx.tolist()}\n" + f" Gloo: {tensor_gloo.tolist()}\n" + f" Diff: {diff_mask.nonzero().tolist()}" + ) + failures += 1 + if failures > 5: + return failures + return failures + + +def test_rapid_fire(comm, rank, world_size, n_iters=200): + """Many rapid consecutive allreduces — stress the round + counter.""" + failures = 0 + for it in range(n_iters): + tensor = torch.zeros(4, world_size, dtype=torch.int32) + tensor[0][rank] = it + 1 + + comm.allreduce_inplace(tensor) + + total = tensor[0].sum().item() + expected = (it + 1) * world_size + if total != expected: + print( + f"[rank {rank}] FAIL iter={it}: " + f"sum={total} expected={expected} " + f"row0={tensor[0].tolist()}" + ) + failures += 1 + if failures > 5: + return failures + return failures + + +def main(): + dist.init_process_group(backend="gloo") + rank = dist.get_rank() + world_size = dist.get_world_size() + gloo_group = dist.group.WORLD + + print(f"[rank {rank}] world_size={world_size}") + + comm = load_ucx_communicator(rank, world_size, gloo_group) + print(f"[rank {rank}] UCX communicator ready") + + tests = [ + ( + "column_pattern", + lambda: test_column_pattern(comm, rank, world_size), + ), + ( + "unique_values", + lambda: test_unique_values(comm, rank, world_size), + ), + ( + "gloo_comparison", + lambda: test_gloo_comparison(comm, rank, world_size, gloo_group), + ), + ( + "rapid_fire", + lambda: test_rapid_fire(comm, rank, world_size), + ), + ] + + all_pass = True + for name, fn in tests: + dist.barrier(group=gloo_group) + t0 = time.monotonic() + failures = fn() + elapsed = time.monotonic() - t0 + status = "PASS" if failures == 0 else f"FAIL ({failures} failures)" + print(f"[rank {rank}] {name}: {status} ({elapsed:.2f}s)") + if failures > 0: + all_pass = False + + comm.finalize() + dist.barrier(group=gloo_group) + + if rank == 0: + result = "ALL TESTS PASSED" if all_pass else "SOME TESTS FAILED" + print(f"\n{result}") + + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/vllm/distributed/device_communicators/ucx_dp_communicator.py b/vllm/distributed/device_communicators/ucx_dp_communicator.py new file mode 100644 index 00000000000..7424ba39ba8 --- /dev/null +++ b/vllm/distributed/device_communicators/ucx_dp_communicator.py @@ -0,0 +1,263 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +UCX-based one-shot AllReduce for DP metadata synchronization. + +Replaces Gloo's TCP AllReduce with UCX tag-matching over IB RDMA. +Falls back to Gloo if UCX is unavailable or init fails. + +Usage: + from vllm.distributed.device_communicators.ucx_dp_communicator import ( + try_init_ucx_dp, + get_ucx_dp_communicator, + ) + + # Once, after DP group is created: + try_init_ucx_dp(dp_rank, dp_size, gloo_group) + + # Per-iteration: + comm = get_ucx_dp_communicator() + if comm is not None: + comm.allreduce_inplace(tensor) # UCX/RDMA path + else: + dist.all_reduce(tensor, group=gloo_group) # fallback +""" + +import ctypes +import ctypes.util +import logging +import os +import subprocess +import threading + +import torch +import torch.distributed as dist + +logger = logging.getLogger(__name__) + +_communicator: "UCXDPCommunicator | None" = None +_init_lock = threading.Lock() + + +def get_ucx_dp_communicator() -> "UCXDPCommunicator | None": + return _communicator + + +def try_init_ucx_dp( + rank: int, + world_size: int, + gloo_group, + max_msg_bytes: int = 1024, +) -> bool: + global _communicator + with _init_lock: + if _communicator is not None: + return True + try: + comm = UCXDPCommunicator(rank, world_size, max_msg_bytes) + comm.bootstrap(gloo_group) + _communicator = comm + logger.info( + "UCX DP communicator ready (rank %d/%d, RDMA allreduce)", + rank, + world_size, + ) + return True + except Exception: + logger.warning( + "UCX DP communicator unavailable, using Gloo fallback", + exc_info=True, + ) + return False + + +class UCXDPCommunicator: + def __init__( + self, + rank: int, + world_size: int, + max_msg_bytes: int = 1024, + ): + self.rank = rank + self.world_size = world_size + self._lib = _load_library() + self._state = ctypes.c_void_p() + + addr_ptr = ctypes.c_void_p() + addr_len = ctypes.c_size_t() + rc = self._lib.ucx_dp_init( + rank, + world_size, + max_msg_bytes, + ctypes.byref(self._state), + ctypes.byref(addr_ptr), + ctypes.byref(addr_len), + ) + if rc != 0: + raise RuntimeError("ucx_dp_init failed") + + assert addr_ptr.value is not None + self._address = ctypes.string_at(addr_ptr.value, addr_len.value) + self._lib.ucx_dp_release_address(self._state, addr_ptr) + + def bootstrap(self, gloo_group) -> None: + """Exchange UCX worker addresses via existing Gloo group.""" + all_addrs: list = [None] * self.world_size + dist.all_gather_object(all_addrs, self._address, group=gloo_group) + + for i, addr in enumerate(all_addrs): + if i == self.rank: + continue + rc = self._lib.ucx_dp_connect( + self._state, + i, + addr, + len(addr), + ) + if rc != 0: + raise RuntimeError(f"ucx_dp_connect to rank {i} failed") + + def allreduce_inplace(self, tensor: torch.Tensor) -> None: + """In-place sum allreduce of a contiguous CPU int32 + tensor.""" + assert tensor.is_contiguous() and tensor.device.type == "cpu" + nbytes = tensor.nelement() * tensor.element_size() + rc = self._lib.ucx_dp_allreduce_inplace( + self._state, + ctypes.c_void_p(tensor.data_ptr()), + nbytes, + ) + if rc != 0: + raise RuntimeError("ucx_dp_allreduce_inplace failed") + + def finalize(self) -> None: + if self._state: + self._lib.ucx_dp_finalize(self._state) + self._state = ctypes.c_void_p() + + def __del__(self): + self.finalize() + + +# ---- library loading ---- + +_lib_cache: ctypes.CDLL | None = None + + +def _load_library() -> ctypes.CDLL: + global _lib_cache + if _lib_cache is not None: + return _lib_cache + + so_path = _find_or_compile() + if so_path is None: + raise RuntimeError( + "Cannot find or compile _ucx_dp_sync.so. " + "Place ucx_dp_sync.c next to this file, or set " + "VLLM_UCX_DP_LIB to the .so path." + ) + + lib = ctypes.CDLL(so_path) + _setup_signatures(lib) + _lib_cache = lib + return lib + + +def _setup_signatures(lib: ctypes.CDLL) -> None: + c_int = ctypes.c_int + c_size_t = ctypes.c_size_t + c_void_p = ctypes.c_void_p + POINTER = ctypes.POINTER + + lib.ucx_dp_init.restype = c_int + lib.ucx_dp_init.argtypes = [ + c_int, + c_int, + c_size_t, + POINTER(c_void_p), + POINTER(c_void_p), + POINTER(c_size_t), + ] + + lib.ucx_dp_release_address.restype = None + lib.ucx_dp_release_address.argtypes = [c_void_p, c_void_p] + + lib.ucx_dp_connect.restype = c_int + lib.ucx_dp_connect.argtypes = [ + c_void_p, + c_int, + c_void_p, + c_size_t, + ] + + lib.ucx_dp_allreduce_inplace.restype = c_int + lib.ucx_dp_allreduce_inplace.argtypes = [ + c_void_p, + c_void_p, + c_size_t, + ] + + lib.ucx_dp_finalize.restype = None + lib.ucx_dp_finalize.argtypes = [c_void_p] + + +def _find_or_compile() -> str | None: + env_path = os.environ.get("VLLM_UCX_DP_LIB") + if env_path and os.path.isfile(env_path): + return env_path + + here = os.path.dirname(os.path.abspath(__file__)) + so_path = os.path.join(here, "_ucx_dp_sync.so") + if os.path.isfile(so_path): + return so_path + + c_path = os.path.join(here, "ucx_dp_sync.c") + if not os.path.isfile(c_path): + return None + + if not _has_ucp(): + logger.warning("libucp.so not found — cannot compile UCX DP sync") + return None + + logger.info("Compiling _ucx_dp_sync.so …") + cflags: list[str] = [] + ldflags: list[str] = [] + try: + pc = subprocess.run( + ["pkg-config", "--cflags", "--libs", "ucx"], + capture_output=True, + text=True, + timeout=5, + ) + if pc.returncode == 0: + for tok in pc.stdout.strip().split(): + if tok.startswith(("-I", "-D")): + cflags.append(tok) + else: + ldflags.append(tok) + except Exception: + pass + + cmd = [ + "gcc", + "-shared", + "-fPIC", + "-O2", + *cflags, + "-o", + so_path, + c_path, + "-lucp", + "-lucs", + *ldflags, + ] + result = subprocess.run(cmd, capture_output=True, text=True, timeout=30) + if result.returncode != 0: + logger.warning("Compile failed: %s", result.stderr) + return None + + return so_path + + +def _has_ucp() -> bool: + return ctypes.util.find_library("ucp") is not None diff --git a/vllm/distributed/device_communicators/ucx_dp_sync.c b/vllm/distributed/device_communicators/ucx_dp_sync.c new file mode 100644 index 00000000000..00a444ca377 --- /dev/null +++ b/vllm/distributed/device_communicators/ucx_dp_sync.c @@ -0,0 +1,251 @@ +/* + * ucx_dp_sync.c — One-shot AllReduce for vLLM DP metadata over UCX/RDMA + * + * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright contributors to the vLLM project + * + * Replaces Gloo TCP AllReduce (~100ms P99) with UCX tag-matching over + * InfiniBand RDMA (~0.1ms P99) for the per-iteration DP metadata sync. + * + * Build: + * gcc -shared -fPIC -O2 -o _ucx_dp_sync.so ucx_dp_sync.c -lucp -lucs + * + * The allreduce is one-shot: every rank sends its full tensor to every + * other rank, then locally sums. For 256 bytes x 15 peers this is 3.8KB + * total — well within RDMA eager threshold, so all sends complete in a + * single round with no rendezvous handshake. + */ + +#include +#include +#include +#include +#include +#include + +#define MAX_RANKS 256 + +typedef struct { + ucp_context_h ctx; + ucp_worker_h worker; + ucp_ep_h *eps; /* [world_size], NULL for self */ + int rank; + int world_size; + uint64_t round; /* monotonic counter for tag uniqueness */ + uint8_t **recv_bufs; /* [world_size] pre-allocated */ + uint8_t *send_staging; + size_t max_bytes; +} ucx_dp_state_t; + +/* ---- request completion ---- */ + +static void req_init(void *request) { + *(int *)request = 0; +} + +static void send_cb(void *request, ucs_status_t status, void *user_data) { + *(int *)request = 1; +} + +static void recv_cb(void *request, ucs_status_t status, + const ucp_tag_recv_info_t *info, void *user_data) { + *(int *)request = 1; +} + +/* ---- public API ---- */ + +int ucx_dp_init(int rank, int world_size, size_t max_bytes, + void **state_out, void **addr_out, size_t *addr_len_out) { + ucs_status_t st; + + if (world_size > MAX_RANKS) return -1; + + ucx_dp_state_t *s = (ucx_dp_state_t *)calloc(1, sizeof(*s)); + if (!s) return -1; + s->rank = rank; + s->world_size = world_size; + s->max_bytes = max_bytes; + + /* context */ + ucp_config_t *config; + st = ucp_config_read(NULL, NULL, &config); + if (st != UCS_OK) goto fail_s; + + ucp_params_t params; + memset(¶ms, 0, sizeof(params)); + params.field_mask = UCP_PARAM_FIELD_FEATURES + | UCP_PARAM_FIELD_REQUEST_SIZE + | UCP_PARAM_FIELD_REQUEST_INIT; + params.features = UCP_FEATURE_TAG; + params.request_size = sizeof(int); + params.request_init = req_init; + + st = ucp_init(¶ms, config, &s->ctx); + ucp_config_release(config); + if (st != UCS_OK) goto fail_s; + + /* worker */ + ucp_worker_params_t wp; + memset(&wp, 0, sizeof(wp)); + wp.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; + wp.thread_mode = UCS_THREAD_MODE_SINGLE; + + st = ucp_worker_create(s->ctx, &wp, &s->worker); + if (st != UCS_OK) goto fail_ctx; + + /* worker address */ + ucp_address_t *addr; + size_t addr_len; + st = ucp_worker_get_address(s->worker, &addr, &addr_len); + if (st != UCS_OK) goto fail_worker; + + /* buffers */ + s->eps = (ucp_ep_h *)calloc(world_size, sizeof(ucp_ep_h)); + s->recv_bufs = (uint8_t **)calloc(world_size, sizeof(uint8_t *)); + s->send_staging = (uint8_t *)malloc(max_bytes); + for (int i = 0; i < world_size; i++) { + if (i != rank) + s->recv_bufs[i] = (uint8_t *)malloc(max_bytes); + } + + *state_out = s; + *addr_out = addr; /* caller copies, then calls release */ + *addr_len_out = addr_len; + return 0; + +fail_worker: ucp_worker_destroy(s->worker); +fail_ctx: ucp_cleanup(s->ctx); +fail_s: free(s); + return -1; +} + +void ucx_dp_release_address(void *state, void *addr) { + ucx_dp_state_t *s = (ucx_dp_state_t *)state; + ucp_worker_release_address(s->worker, (ucp_address_t *)addr); +} + +int ucx_dp_connect(void *state, int peer, const void *addr, size_t len) { + ucx_dp_state_t *s = (ucx_dp_state_t *)state; + + ucp_ep_params_t ep; + memset(&ep, 0, sizeof(ep)); + ep.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS; + ep.address = (const ucp_address_t *)addr; + + ucs_status_t st = ucp_ep_create(s->worker, &ep, &s->eps[peer]); + return (st == UCS_OK) ? 0 : -1; +} + +int ucx_dp_allreduce_inplace(void *state, void *buf, size_t nbytes) { + ucx_dp_state_t *s = (ucx_dp_state_t *)state; + if (nbytes > s->max_bytes) return -1; + + uint64_t round = s->round++; + int ws = s->world_size; + + memcpy(s->send_staging, buf, nbytes); + + ucs_status_ptr_t recv_reqs[MAX_RANKS]; + ucs_status_ptr_t send_reqs[MAX_RANKS]; + + /* post all receives */ + for (int i = 0; i < ws; i++) { + if (i == s->rank) { recv_reqs[i] = NULL; continue; } + + ucp_request_param_t p; + memset(&p, 0, sizeof(p)); + p.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK + | UCP_OP_ATTR_FIELD_FLAGS; + p.cb.recv = recv_cb; + p.flags = UCP_OP_ATTR_FLAG_NO_IMM_CMPL; + + ucp_tag_t tag = (round << 16) | (uint64_t)i; + recv_reqs[i] = ucp_tag_recv_nbx(s->worker, s->recv_bufs[i], + nbytes, tag, ~(ucp_tag_t)0, &p); + } + + /* post all sends */ + ucp_tag_t my_tag = (round << 16) | (uint64_t)s->rank; + for (int i = 0; i < ws; i++) { + if (i == s->rank) { send_reqs[i] = NULL; continue; } + + ucp_request_param_t p; + memset(&p, 0, sizeof(p)); + p.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK; + p.cb.send = send_cb; + + send_reqs[i] = ucp_tag_send_nbx(s->eps[i], s->send_staging, + nbytes, my_tag, &p); + } + + /* progress until every request completes */ + for (;;) { + ucp_worker_progress(s->worker); + int done = 1; + for (int i = 0; i < ws; i++) { + if (i == s->rank) continue; + if (recv_reqs[i] && !UCS_PTR_IS_ERR(recv_reqs[i])) { + if (ucp_request_check_status(recv_reqs[i]) == UCS_INPROGRESS) { + done = 0; break; + } + } + if (send_reqs[i] && !UCS_PTR_IS_ERR(send_reqs[i])) { + if (ucp_request_check_status(send_reqs[i]) == UCS_INPROGRESS) { + done = 0; break; + } + } + } + if (done) break; + } + + /* free requests */ + for (int i = 0; i < ws; i++) { + if (i == s->rank) continue; + if (recv_reqs[i] && !UCS_PTR_IS_ERR(recv_reqs[i])) + ucp_request_free(recv_reqs[i]); + if (send_reqs[i] && !UCS_PTR_IS_ERR(send_reqs[i])) + ucp_request_free(send_reqs[i]); + } + + /* + * Memory fence: ensure all recv buffer writes from UCX are + * visible before we read them for the reduction. On ARM + * (aarch64) the store from the transport thread and our read + * may not be ordered without this. + */ + __sync_synchronize(); + + /* local reduce: buf = local + sum(received) */ + int32_t *out = (int32_t *)buf; + int32_t *local = (int32_t *)s->send_staging; + int count = (int)(nbytes / sizeof(int32_t)); + memcpy(out, local, nbytes); + for (int i = 0; i < ws; i++) { + if (i == s->rank) continue; + int32_t *peer = (int32_t *)s->recv_bufs[i]; + for (int j = 0; j < count; j++) + out[j] += peer[j]; + } + + return 0; +} + +void ucx_dp_finalize(void *state) { + ucx_dp_state_t *s = (ucx_dp_state_t *)state; + if (!s) return; + + /* flush pending endpoint ops */ + for (int i = 0; i < 64; i++) + ucp_worker_progress(s->worker); + + if (s->worker) ucp_worker_destroy(s->worker); + if (s->ctx) ucp_cleanup(s->ctx); + + for (int i = 0; i < s->world_size; i++) { + if (s->recv_bufs && s->recv_bufs[i]) free(s->recv_bufs[i]); + } + free(s->eps); + free(s->recv_bufs); + free(s->send_staging); + free(s); +} diff --git a/vllm/v1/worker/dp_utils.py b/vllm/v1/worker/dp_utils.py index e7c6d81a992..63e568c788d 100644 --- a/vllm/v1/worker/dp_utils.py +++ b/vllm/v1/worker/dp_utils.py @@ -1,6 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import threading +import time + import torch import torch.distributed as dist @@ -14,6 +17,53 @@ from vllm.v1.worker.ubatch_utils import ( logger = init_logger(__name__) +_dp_sync_stats_lock = threading.Lock() +_dp_sync_stats: list[float] = [] + +_ucx_init_attempted = False + + +def _maybe_init_ucx(parallel_config: ParallelConfig) -> None: + """Lazily initialize the UCX DP communicator on first use. + + Gated on VLLM_DP_SYNC_BACKEND=ucx. Falls back to Gloo on + failure. + """ + global _ucx_init_attempted + if _ucx_init_attempted: + return + _ucx_init_attempted = True + + import os + + if os.environ.get("VLLM_DP_SYNC_BACKEND", "").lower() != "ucx": + return + + try: + from vllm.distributed.device_communicators.ucx_dp_communicator import ( + try_init_ucx_dp, + ) + + gloo_group = get_dp_group().cpu_group + try_init_ucx_dp( + rank=parallel_config.data_parallel_rank, + world_size=parallel_config.data_parallel_size, + gloo_group=gloo_group, + max_msg_bytes=1024, + ) + except Exception: + logger.warning("UCX DP init failed, using Gloo", exc_info=True) + + +def get_dp_sync_stats() -> list[float] | None: + """Return and clear the list of DP sync latencies.""" + with _dp_sync_stats_lock: + if not _dp_sync_stats: + return None + result = list(_dp_sync_stats) + _dp_sync_stats.clear() + return result + def _get_device_and_group(parallel_config: ParallelConfig): # Use the actual device assigned to the DP group, not just the device type @@ -40,17 +90,37 @@ def _run_ar( cudagraph_mode: int, parallel_config: ParallelConfig, ) -> torch.Tensor: + _maybe_init_ucx(parallel_config) + dp_size = parallel_config.data_parallel_size dp_rank = parallel_config.data_parallel_rank - device, group = _get_device_and_group(parallel_config) + # Populate this rank's contribution on CPU to reduce GPU syncs. tensor_cpu = torch.zeros(4, dp_size, dtype=torch.int32) tensor_cpu[0][dp_rank] = orig_num_tokens_per_ubatch tensor_cpu[1][dp_rank] = padded_num_tokens_per_ubatch tensor_cpu[2][dp_rank] = 1 if should_ubatch else 0 tensor_cpu[3][dp_rank] = cudagraph_mode - tensor = tensor_cpu.to(device, non_blocking=True) - dist.all_reduce(tensor, group=group) + + t0 = time.monotonic() + + from vllm.distributed.device_communicators.ucx_dp_communicator import ( + get_ucx_dp_communicator, + ) + + ucx = get_ucx_dp_communicator() + if ucx is not None: + ucx.allreduce_inplace(tensor_cpu) + tensor = tensor_cpu + else: + device, group = _get_device_and_group(parallel_config) + tensor = tensor_cpu.to(device, non_blocking=True) + dist.all_reduce(tensor, group=group) + + elapsed = time.monotonic() - t0 + with _dp_sync_stats_lock: + _dp_sync_stats.append(elapsed) + return tensor diff --git a/vllm/v1/worker/gpu/dp_utils.py b/vllm/v1/worker/gpu/dp_utils.py index b3c172738c3..6c1a14767ca 100644 --- a/vllm/v1/worker/gpu/dp_utils.py +++ b/vllm/v1/worker/gpu/dp_utils.py @@ -2,16 +2,55 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from __future__ import annotations +import os +import time + import torch import torch.distributed as dist from vllm.config.compilation import CUDAGraphMode from vllm.distributed.parallel_state import get_dp_group +from vllm.logger import init_logger +from vllm.v1.worker.dp_utils import _dp_sync_stats, _dp_sync_stats_lock from vllm.v1.worker.gpu.cudagraph_utils import ( BatchExecutionDescriptor, CudaGraphManager, ) +logger = init_logger(__name__) + +_ucx_init_attempted = False + + +def _maybe_init_ucx(dp_rank: int, dp_size: int) -> None: + """Lazily initialize the UCX DP communicator on first use. + + Gated on VLLM_DP_SYNC_BACKEND=ucx. Falls back to Gloo on + failure. + """ + global _ucx_init_attempted + if _ucx_init_attempted: + return + _ucx_init_attempted = True + + if os.environ.get("VLLM_DP_SYNC_BACKEND", "").lower() != "ucx": + return + + try: + from vllm.distributed.device_communicators.ucx_dp_communicator import ( + try_init_ucx_dp, + ) + + gloo_group = get_dp_group().cpu_group + try_init_ucx_dp( + rank=dp_rank, + world_size=dp_size, + gloo_group=gloo_group, + max_msg_bytes=1024, + ) + except Exception: + logger.warning("UCX DP init failed, using Gloo", exc_info=True) + def sync_cudagraph_and_dp_padding( cudagraph_manager: CudaGraphManager | None, @@ -28,12 +67,30 @@ def sync_cudagraph_and_dp_padding( Returns (synced_batch_desc, num_tokens_across_dp). """ assert dp_size > 1, "DP size must be greater than 1" - group = get_dp_group().cpu_group + + _maybe_init_ucx(dp_rank, dp_size) + tensor = torch.zeros(3, dp_size, dtype=torch.int32, device="cpu") tensor[0][dp_rank] = num_tokens tensor[1][dp_rank] = desired_batch_desc.cg_mode.value tensor[2][dp_rank] = uniform_token_count or 0 # (0 means None) - dist.all_reduce(tensor, group=group) + + t0 = time.monotonic() + + from vllm.distributed.device_communicators.ucx_dp_communicator import ( + get_ucx_dp_communicator, + ) + + ucx = get_ucx_dp_communicator() + if ucx is not None: + ucx.allreduce_inplace(tensor) + else: + group = get_dp_group().cpu_group + dist.all_reduce(tensor, group=group) + + elapsed = time.monotonic() - t0 + with _dp_sync_stats_lock: + _dp_sync_stats.append(elapsed) num_tokens_across_dp = tensor[0] cg_mode_across_dp = tensor[1]