[Distributed] Add UCX one-shot AllReduce for DP metadata sync

Replace Gloo TCP AllReduce with UCX tag-matching over IB RDMA for the
per-iteration DP metadata synchronization in data-parallel mode.

The one-shot allreduce sends each rank's full tensor to every other rank
and locally sums, staying well within the RDMA eager threshold (no
rendezvous handshake). This reduces P99 DP sync latency from ~100ms
(Gloo/TCP) to ~0.1ms (UCX/RDMA) on InfiniBand clusters.

Enabled by setting VLLM_DP_SYNC_BACKEND=ucx. Falls back to Gloo
automatically if UCX is unavailable or initialization fails.

New files:
- ucx_dp_communicator.py: Python wrapper with auto-compile of C library
- ucx_dp_sync.c: C-level UCX one-shot allreduce primitive
- test_ucx_allreduce.py: Latency benchmark (Gloo vs UCX)
- test_ucx_correctness.py: Multi-pattern correctness tests

Modified files:
- vllm/v1/worker/dp_utils.py: UCX path for V1 model runner
- vllm/v1/worker/gpu/dp_utils.py: UCX path for V2 model runner

Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
This commit is contained in:
Tyler Michael Smith
2026-06-03 15:46:10 -04:00
parent 271328e256
commit b0ed553028
6 changed files with 989 additions and 5 deletions
@@ -0,0 +1,143 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Standalone test for UCX DP allreduce.
Single-node quick test with 4 ranks:
torchrun --nproc-per-node=4 \
vllm/distributed/device_communicators/test_ucx_allreduce.py
Cross-node test (4 nodes x 4 ranks):
torchrun --nproc-per-node=4 --nnodes=4 \
--master-addr=$MASTER_ADDR --master-port=29500 \
--node-rank=$NODE_RANK \
vllm/distributed/device_communicators/test_ucx_allreduce.py
"""
import os
import sys
import time
import torch
import torch.distributed as dist
def test_gloo_baseline(rank, world_size, group, n_iters=100):
"""Measure Gloo TCP allreduce latency."""
# Times dist.all_reduce over `group` on a (4, world_size) int32 CPU tensor and
# returns one time.monotonic() delta (in seconds) per timed call.
tensor = torch.zeros(4, world_size, dtype=torch.int32)
# warmup
# Ten untimed all-reduces before any sample is recorded.
for _ in range(10):
tensor.zero_()
tensor[0][rank] = 1
dist.all_reduce(tensor, group=group)
latencies = []
for _ in range(n_iters):
# Each rank fills only its own column; after the sum, column i holds
# exactly what rank i wrote.
tensor.zero_()
tensor[0][rank] = rank + 1
tensor[1][rank] = (rank + 1) * 10
tensor[2][rank] = 1
tensor[3][rank] = 2
# Only the all_reduce call itself is inside the timed window.
t0 = time.monotonic()
dist.all_reduce(tensor, group=group)
latencies.append(time.monotonic() - t0)
# Sanity-check rows 0 and 1 of the reduced tensor against the per-rank values.
# NOTE(review): indentation is not visible in this view, so it is unclear
# whether this check runs once after the timed loop or on every iteration —
# confirm against the original file.
for i in range(world_size):
assert tensor[0][i].item() == i + 1, f"Gloo: rank {i} col 0 wrong"
assert tensor[1][i].item() == (i + 1) * 10, f"Gloo: rank {i} col 1 wrong"
return latencies
def test_ucx_allreduce(rank, world_size, gloo_group, n_iters=100):
"""Measure UCX RDMA allreduce latency."""
# Returns one time.monotonic() delta (in seconds) per timed
# comm.allreduce_inplace call.
# Make the sibling module importable when this file is run as a script:
# the script's own directory is pushed to the front of sys.path.
here = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, here)
from ucx_dp_communicator import UCXDPCommunicator
comm = UCXDPCommunicator(rank, world_size, max_msg_bytes=1024)
# bootstrap() is handed the Gloo group; the communicator uses it to set up
# its peer connections.
comm.bootstrap(gloo_group)
print(f"[rank {rank}] UCX communicator initialized")
tensor = torch.zeros(4, world_size, dtype=torch.int32)
# warmup
# Ten untimed all-reduces before any sample is recorded.
for _ in range(10):
tensor.zero_()
tensor[0][rank] = 1
comm.allreduce_inplace(tensor)
latencies = []
for _ in range(n_iters):
# Same per-rank column pattern as the Gloo baseline in this file.
tensor.zero_()
tensor[0][rank] = rank + 1
tensor[1][rank] = (rank + 1) * 10
tensor[2][rank] = 1
tensor[3][rank] = 2
# Only the allreduce call itself is inside the timed window.
t0 = time.monotonic()
comm.allreduce_inplace(tensor)
latencies.append(time.monotonic() - t0)
# Sanity-check rows 0 and 1 of the reduced tensor against the per-rank values.
# NOTE(review): indentation is not visible in this view, so it is unclear
# whether this check runs once after the timed loop or on every iteration —
# confirm against the original file.
for i in range(world_size):
assert tensor[0][i].item() == i + 1, (
f"UCX: rank {i} col 0 = {tensor[0][i].item()}, expected {i + 1}"
)
assert tensor[1][i].item() == (i + 1) * 10, (
f"UCX: rank {i} col 1 = {tensor[1][i].item()}, expected {(i + 1) * 10}"
)
# NOTE(review): finalize() is not reached when an assertion above fails;
# there is no try/finally around the measurement.
comm.finalize()
return latencies
def percentile(data, p):
    """Return the p-th percentile of ``data`` by direct index lookup.

    The values are sorted ascending and the element at index
    ``int(len(data) * p / 100)`` is returned; an index past the end is
    clamped to the last element (so ``p=100`` yields the maximum). No
    interpolation is performed.
    """
    ordered = sorted(data)
    last = len(ordered) - 1
    position = int(len(ordered) * p / 100)
    if position > last:
        position = last
    return ordered[position]
def print_stats(name, latencies, rank):
    """Print a one-line latency summary for ``name`` on this rank.

    ``latencies`` is a sequence of durations in seconds, or ``None`` to
    report the measurement as skipped. Values are converted to
    microseconds before p50/p95/p99/mean/min/max are printed.
    """
    if latencies is not None:
        us = list(map(lambda seconds: seconds * 1e6, latencies))
        print(
            f"[rank {rank}] {name}: "
            f"p50={percentile(us, 50):.1f}us "
            f"p95={percentile(us, 95):.1f}us "
            f"p99={percentile(us, 99):.1f}us "
            f"mean={sum(us) / len(us):.1f}us "
            f"min={min(us):.1f}us "
            f"max={max(us):.1f}us"
        )
    else:
        print(f"[rank {rank}] {name}: SKIPPED")
def main():
    """Run the Gloo baseline and the UCX benchmark, then print a comparison.

    The iteration count comes from the ``TEST_ITERS`` environment variable
    (default 200). Every rank prints its own statistics; rank 0 additionally
    prints the P50/P99 speedup when UCX latencies are available.
    """
    dist.init_process_group(backend="gloo")
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    gloo_group = dist.group.WORLD
    print(f"[rank {rank}] world_size={world_size}")

    n_iters = int(os.environ.get("TEST_ITERS", "200"))

    gloo_latencies = test_gloo_baseline(rank, world_size, gloo_group, n_iters)
    print_stats("Gloo TCP", gloo_latencies, rank)

    ucx_latencies = test_ucx_allreduce(rank, world_size, gloo_group, n_iters)
    print_stats("UCX RDMA", ucx_latencies, rank)

    report_speedup = bool(ucx_latencies) and rank == 0
    if report_speedup:
        # Percentiles are computed on seconds and scaled to microseconds.
        gloo_p50 = percentile(gloo_latencies, 50) * 1e6
        gloo_p99 = percentile(gloo_latencies, 99) * 1e6
        ucx_p50 = percentile(ucx_latencies, 50) * 1e6
        ucx_p99 = percentile(ucx_latencies, 99) * 1e6
        print("\n=== Speedup ===")
        print(f" P50: {gloo_p50:.1f}us -> {ucx_p50:.1f}us ({gloo_p50 / ucx_p50:.1f}x)")
        print(f" P99: {gloo_p99:.1f}us -> {ucx_p99:.1f}us ({gloo_p99 / ucx_p99:.1f}x)")

    dist.destroy_process_group()
if __name__ == "__main__":
main()
@@ -0,0 +1,200 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Correctness test for UCX allreduce — exercises cross-node paths.
Tests multiple patterns:
1. Each rank writes its column (the actual DP sync pattern)
2. Each rank writes a unique value — verifies every peer's data arrives
3. Sequential rounds — verifies round counter stays in sync
4. Stress test with many rapid calls
Single-node quick test:
torchrun --nproc-per-node=4 test_ucx_correctness.py
Cross-node (4 nodes x 4 ranks):
torchrun --nproc-per-node=4 --nnodes=4 \
--master-addr=$MASTER_ADDR --master-port=29500 \
--node-rank=$NODE_RANK test_ucx_correctness.py
"""
import os
import sys
import time
import torch
import torch.distributed as dist
def load_ucx_communicator(rank, world_size, gloo_group):
    """Build and bootstrap a ``UCXDPCommunicator`` for this rank.

    The directory containing this script is put at the front of
    ``sys.path`` so that ``ucx_dp_communicator`` can be imported as a
    top-level module. The communicator is created with a 1024-byte message
    limit and bootstrapped over ``gloo_group`` before being returned.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    sys.path.insert(0, script_dir)

    from ucx_dp_communicator import UCXDPCommunicator

    communicator = UCXDPCommunicator(rank, world_size, max_msg_bytes=1024)
    communicator.bootstrap(gloo_group)
    return communicator
def test_column_pattern(comm, rank, world_size, n_iters=50):
    """The actual DP sync pattern: each rank fills its column.

    Every rank writes four values into its own column of a zeroed
    (4, world_size) int32 tensor, all-reduces it, and checks that each
    column holds what the owning rank wrote.

    Returns the number of mismatching columns seen over all iterations.

    The loop always runs all ``n_iters`` iterations. Returning early on one
    rank would stop it from calling ``allreduce_inplace`` while its peers
    keep calling it, leaving them waiting forever; only the log output is
    capped instead.
    """
    max_reported = 6  # mismatches printed before output is suppressed
    failures = 0
    for it in range(n_iters):
        tensor = torch.zeros(4, world_size, dtype=torch.int32)
        tensor[0][rank] = rank + 1
        tensor[1][rank] = (rank + 1) * 10
        tensor[2][rank] = 1
        tensor[3][rank] = it % 3
        comm.allreduce_inplace(tensor)
        for r in range(world_size):
            expected = [r + 1, (r + 1) * 10, 1, it % 3]
            actual = [tensor[row][r].item() for row in range(4)]
            if actual == expected:
                continue
            failures += 1
            if failures <= max_reported:
                print(
                    f"[rank {rank}] FAIL iter={it} col={r}: "
                    f"expected {expected} got {actual}"
                )
    return failures
def test_unique_values(comm, rank, world_size, n_iters=50):
    """Each rank writes a unique value in every cell of its column.

    Cell (row, rank) is set to ``rank * 1000 + it + row``; after the
    all-reduce every cell of every column is compared with the value its
    owning rank must have written.

    Returns the number of mismatching cells seen over all iterations.

    The loop always runs all ``n_iters`` iterations. Returning early on one
    rank would stop it from calling ``allreduce_inplace`` while its peers
    keep calling it, leaving them waiting forever; only the log output is
    capped instead.
    """
    max_reported = 6  # mismatches printed before output is suppressed
    failures = 0
    for it in range(n_iters):
        tensor = torch.zeros(4, world_size, dtype=torch.int32)
        val = rank * 1000 + it
        for row in range(4):
            tensor[row][rank] = val + row
        comm.allreduce_inplace(tensor)
        for r in range(world_size):
            base = r * 1000 + it
            for row in range(4):
                expected = base + row
                actual = tensor[row][r].item()
                if actual == expected:
                    continue
                failures += 1
                if failures <= max_reported:
                    print(
                        f"[rank {rank}] FAIL iter={it} "
                        f"row={row} col={r}: "
                        f"expected {expected} got {actual}"
                    )
    return failures
def test_gloo_comparison(comm, rank, world_size, gloo_group, n_iters=50):
    """Run same allreduce via both UCX and Gloo, compare.

    Two identical tensors are reduced, one through
    ``comm.allreduce_inplace`` and one through ``dist.all_reduce`` on
    ``gloo_group``; any element-wise difference counts as one failure.

    Returns the number of iterations whose results differed.

    The loop always runs all ``n_iters`` iterations. Returning early on one
    rank would make it skip collectives its peers still issue on both
    backends, leaving them waiting forever; only the log output is capped
    instead.
    """
    max_reported = 6  # mismatches printed before output is suppressed
    failures = 0
    for it in range(n_iters):
        tensor_ucx = torch.zeros(4, world_size, dtype=torch.int32)
        tensor_gloo = torch.zeros(4, world_size, dtype=torch.int32)
        val = rank * 100 + it
        for row in range(4):
            tensor_ucx[row][rank] = val + row
            tensor_gloo[row][rank] = val + row
        comm.allreduce_inplace(tensor_ucx)
        dist.all_reduce(tensor_gloo, group=gloo_group)
        if torch.equal(tensor_ucx, tensor_gloo):
            continue
        failures += 1
        if failures <= max_reported:
            diff_mask = tensor_ucx != tensor_gloo
            print(
                f"[rank {rank}] FAIL iter={it}: "
                f"UCX != Gloo\n"
                f" UCX: {tensor_ucx.tolist()}\n"
                f" Gloo: {tensor_gloo.tolist()}\n"
                f" Diff: {diff_mask.nonzero().tolist()}"
            )
    return failures
def test_rapid_fire(comm, rank, world_size, n_iters=200):
    """Many rapid consecutive allreduces — stress the round counter.

    In iteration ``it`` every rank writes ``it + 1`` into its own slot of
    row 0, so the reduced row must sum to ``(it + 1) * world_size``.

    Returns the number of iterations whose sum was wrong.

    The loop always runs all ``n_iters`` iterations. Returning early on one
    rank would stop it from calling ``allreduce_inplace`` while its peers
    keep calling it, leaving them waiting forever; only the log output is
    capped instead.
    """
    max_reported = 6  # mismatches printed before output is suppressed
    failures = 0
    for it in range(n_iters):
        tensor = torch.zeros(4, world_size, dtype=torch.int32)
        tensor[0][rank] = it + 1
        comm.allreduce_inplace(tensor)
        total = tensor[0].sum().item()
        expected = (it + 1) * world_size
        if total == expected:
            continue
        failures += 1
        if failures <= max_reported:
            print(
                f"[rank {rank}] FAIL iter={it}: "
                f"sum={total} expected={expected} "
                f"row0={tensor[0].tolist()}"
            )
    return failures
def main():
    """Run every correctness test and report a result for the whole job.

    Each test is preceded by a barrier on the Gloo group so that all ranks
    enter it together. After the tests, the per-rank pass/fail flags are
    summed across ranks so that rank 0 reports the outcome of the whole job
    rather than only its own, and every rank exits with status 1 if any
    rank saw a failure.
    """
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    gloo_group = dist.group.WORLD
    print(f"[rank {rank}] world_size={world_size}")

    comm = load_ucx_communicator(rank, world_size, gloo_group)
    print(f"[rank {rank}] UCX communicator ready")

    tests = [
        (
            "column_pattern",
            lambda: test_column_pattern(comm, rank, world_size),
        ),
        (
            "unique_values",
            lambda: test_unique_values(comm, rank, world_size),
        ),
        (
            "gloo_comparison",
            lambda: test_gloo_comparison(comm, rank, world_size, gloo_group),
        ),
        (
            "rapid_fire",
            lambda: test_rapid_fire(comm, rank, world_size),
        ),
    ]

    all_pass = True
    for name, fn in tests:
        dist.barrier(group=gloo_group)
        t0 = time.monotonic()
        failures = fn()
        elapsed = time.monotonic() - t0
        status = "PASS" if failures == 0 else f"FAIL ({failures} failures)"
        print(f"[rank {rank}] {name}: {status} ({elapsed:.2f}s)")
        if failures > 0:
            all_pass = False

    comm.finalize()
    dist.barrier(group=gloo_group)

    # Sum the per-rank failure flags: a non-zero total means some rank failed.
    failed_ranks = torch.tensor([0 if all_pass else 1], dtype=torch.int32)
    dist.all_reduce(failed_ranks, group=gloo_group)
    job_passed = failed_ranks.item() == 0

    if rank == 0:
        result = "ALL TESTS PASSED" if job_passed else "SOME TESTS FAILED"
        print(f"\n{result}")

    dist.destroy_process_group()

    # A failing run must be visible to the launcher / CI through the exit code.
    if not job_passed:
        sys.exit(1)
if __name__ == "__main__":
main()
@@ -0,0 +1,263 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
UCX-based one-shot AllReduce for DP metadata synchronization.
Replaces Gloo's TCP AllReduce with UCX tag-matching over IB RDMA.
Falls back to Gloo if UCX is unavailable or init fails.
Usage:
from vllm.distributed.device_communicators.ucx_dp_communicator import (
try_init_ucx_dp,
get_ucx_dp_communicator,
)
# Once, after DP group is created:
try_init_ucx_dp(dp_rank, dp_size, gloo_group)
# Per-iteration:
comm = get_ucx_dp_communicator()
if comm is not None:
comm.allreduce_inplace(tensor) # UCX/RDMA path
else:
dist.all_reduce(tensor, group=gloo_group) # fallback
"""
import ctypes
import ctypes.util
import logging
import os
import subprocess
import threading
import torch
import torch.distributed as dist
logger = logging.getLogger(__name__)
_communicator: "UCXDPCommunicator | None" = None
_init_lock = threading.Lock()
def get_ucx_dp_communicator() -> "UCXDPCommunicator | None":
    """Return the module-level UCX communicator, or ``None`` if none is set."""
    active = _communicator
    return active
def _all_ranks_ok(local_ok: bool, world_size: int, gloo_group) -> bool:
    """Return True only if every rank of ``gloo_group`` reports success."""
    flags: list = [None] * world_size
    dist.all_gather_object(flags, bool(local_ok), group=gloo_group)
    return all(flags)


def try_init_ucx_dp(
    rank: int,
    world_size: int,
    gloo_group,
    max_msg_bytes: int = 1024,
) -> bool:
    """Create, bootstrap and install the module-level UCX communicator.

    Must be called by every rank of ``gloo_group``. The decision to use UCX
    is taken collectively: the ranks exchange a success flag over the Gloo
    group after local initialization and again after connecting, and UCX is
    enabled only if every rank succeeded. Without this agreement a failure
    on one rank would leave it on the Gloo fallback while its peers use UCX
    (or block in the bootstrap exchange), and the job would hang.

    Args:
        rank: This process's rank within the DP group.
        world_size: Number of ranks in the DP group.
        gloo_group: Process group used for the address exchange and for the
            success-flag agreement.
        max_msg_bytes: Largest tensor payload the communicator will accept.

    Returns:
        True if the communicator is installed (now or by an earlier call),
        False if UCX is unavailable and callers must use the fallback.
    """
    global _communicator
    with _init_lock:
        if _communicator is not None:
            return True

        comm = None
        try:
            comm = UCXDPCommunicator(rank, world_size, max_msg_bytes)
        except Exception:
            logger.warning(
                "UCX DP communicator unavailable, using Gloo fallback",
                exc_info=True,
            )

        ready = False
        try:
            # Agree that every rank managed to create its UCX worker before
            # anyone enters the address exchange.
            ready = _all_ranks_ok(comm is not None, world_size, gloo_group)
            if ready:
                connected = True
                try:
                    comm.bootstrap(gloo_group)
                except Exception:
                    connected = False
                    logger.warning(
                        "UCX DP communicator unavailable, using Gloo fallback",
                        exc_info=True,
                    )
                # Agree again: a single failed connect disables UCX everywhere.
                ready = _all_ranks_ok(connected, world_size, gloo_group)
        except Exception:
            ready = False
            logger.warning(
                "UCX DP communicator unavailable, using Gloo fallback",
                exc_info=True,
            )

        if not ready:
            if comm is not None:
                # Release the UCX worker that will never be used.
                comm.finalize()
                logger.warning(
                    "UCX DP communicator disabled on rank %d because "
                    "initialization did not succeed on every DP rank",
                    rank,
                )
            return False

        _communicator = comm
        logger.info(
            "UCX DP communicator ready (rank %d/%d, RDMA allreduce)",
            rank,
            world_size,
        )
        return True
class UCXDPCommunicator:
    """ctypes wrapper around the ``ucx_dp_*`` functions of the native library.

    Life cycle: construct (creates the native state and captures this
    worker's address), ``bootstrap`` (exchanges addresses and connects to
    every peer), any number of ``allreduce_inplace`` calls, then
    ``finalize``. ``finalize`` is idempotent and is also invoked from
    ``__del__``.
    """

    def __init__(
        self,
        rank: int,
        world_size: int,
        max_msg_bytes: int = 1024,
    ):
        """Create the native state and copy out this worker's address.

        Raises:
            RuntimeError: if the library cannot be loaded, ``ucx_dp_init``
                fails, or no worker address is returned.
        """
        # Assigned before anything that can raise, so finalize()/__del__ stay
        # safe on a partially constructed instance.
        self._state = ctypes.c_void_p()
        self._lib = None
        self.rank = rank
        self.world_size = world_size
        self._lib = _load_library()
        addr_ptr = ctypes.c_void_p()
        addr_len = ctypes.c_size_t()
        rc = self._lib.ucx_dp_init(
            rank,
            world_size,
            max_msg_bytes,
            ctypes.byref(self._state),
            ctypes.byref(addr_ptr),
            ctypes.byref(addr_len),
        )
        if rc != 0:
            raise RuntimeError("ucx_dp_init failed")
        if not addr_ptr.value:
            # Explicit check rather than assert: asserts vanish under -O.
            self.finalize()
            raise RuntimeError("ucx_dp_init returned no worker address")
        # Copy the address bytes into Python, then hand the buffer back.
        self._address = ctypes.string_at(addr_ptr.value, addr_len.value)
        self._lib.ucx_dp_release_address(self._state, addr_ptr)

    def bootstrap(self, gloo_group) -> None:
        """Exchange UCX worker addresses via existing Gloo group.

        Every rank contributes its address through ``all_gather_object`` and
        then connects to each peer except itself.

        Raises:
            RuntimeError: if connecting to any peer fails.
        """
        all_addrs: list = [None] * self.world_size
        dist.all_gather_object(all_addrs, self._address, group=gloo_group)
        for i, addr in enumerate(all_addrs):
            if i == self.rank:
                continue
            rc = self._lib.ucx_dp_connect(
                self._state,
                i,
                addr,
                len(addr),
            )
            if rc != 0:
                raise RuntimeError(f"ucx_dp_connect to rank {i} failed")

    def allreduce_inplace(self, tensor: torch.Tensor) -> None:
        """In-place sum allreduce of a contiguous CPU int32 tensor.

        Raises:
            RuntimeError: if the communicator is not initialized or was
                finalized, or if the native call reports failure.
            ValueError: if ``tensor`` is not a contiguous CPU int32 tensor.
        """
        if not getattr(self, "_state", None):
            # Calling into the library with a NULL state would crash.
            raise RuntimeError(
                "UCX DP communicator is not initialized or already finalized"
            )
        if tensor.device.type != "cpu" or not tensor.is_contiguous():
            raise ValueError("tensor must be a contiguous CPU tensor")
        if tensor.dtype != torch.int32:
            # The native reduction sums the payload as 32-bit integers.
            raise ValueError(f"tensor must be int32, got {tensor.dtype}")
        nbytes = tensor.nelement() * tensor.element_size()
        rc = self._lib.ucx_dp_allreduce_inplace(
            self._state,
            ctypes.c_void_p(tensor.data_ptr()),
            nbytes,
        )
        if rc != 0:
            raise RuntimeError("ucx_dp_allreduce_inplace failed")

    def finalize(self) -> None:
        """Release the native state; safe to call repeatedly."""
        state = getattr(self, "_state", None)
        lib = getattr(self, "_lib", None)
        if state and lib is not None:
            # Clear the handle first so a second call cannot free it again.
            self._state = ctypes.c_void_p()
            lib.ucx_dp_finalize(state)

    def __del__(self):
        self.finalize()
# ---- library loading ----
_lib_cache: ctypes.CDLL | None = None
def _load_library() -> ctypes.CDLL:
    """Load the native library once and return the cached handle.

    On first use the shared object is located (or compiled), opened with
    ``ctypes.CDLL`` and given its function signatures; later calls return
    the cached handle.

    Raises:
        RuntimeError: if no shared object can be found or compiled.
    """
    global _lib_cache
    if _lib_cache is None:
        located = _find_or_compile()
        if located is None:
            raise RuntimeError(
                "Cannot find or compile _ucx_dp_sync.so. "
                "Place ucx_dp_sync.c next to this file, or set "
                "VLLM_UCX_DP_LIB to the .so path."
            )
        handle = ctypes.CDLL(located)
        _setup_signatures(handle)
        _lib_cache = handle
    return _lib_cache
def _setup_signatures(lib: ctypes.CDLL) -> None:
    """Declare ``restype``/``argtypes`` for every ``ucx_dp_*`` entry point."""
    c_int = ctypes.c_int
    size_t = ctypes.c_size_t
    void_p = ctypes.c_void_p
    ptr = ctypes.POINTER

    # (symbol, return type, argument types), in declaration order.
    prototypes = (
        (
            "ucx_dp_init",
            c_int,
            [c_int, c_int, size_t, ptr(void_p), ptr(void_p), ptr(size_t)],
        ),
        ("ucx_dp_release_address", None, [void_p, void_p]),
        ("ucx_dp_connect", c_int, [void_p, c_int, void_p, size_t]),
        ("ucx_dp_allreduce_inplace", c_int, [void_p, void_p, size_t]),
        ("ucx_dp_finalize", None, [void_p]),
    )
    for symbol, restype, argtypes in prototypes:
        func = getattr(lib, symbol)
        func.restype = restype
        func.argtypes = argtypes
def _find_or_compile() -> str | None:
env_path = os.environ.get("VLLM_UCX_DP_LIB")
if env_path and os.path.isfile(env_path):
return env_path
here = os.path.dirname(os.path.abspath(__file__))
so_path = os.path.join(here, "_ucx_dp_sync.so")
if os.path.isfile(so_path):
return so_path
c_path = os.path.join(here, "ucx_dp_sync.c")
if not os.path.isfile(c_path):
return None
if not _has_ucp():
logger.warning("libucp.so not found — cannot compile UCX DP sync")
return None
logger.info("Compiling _ucx_dp_sync.so …")
cflags: list[str] = []
ldflags: list[str] = []
try:
pc = subprocess.run(
["pkg-config", "--cflags", "--libs", "ucx"],
capture_output=True,
text=True,
timeout=5,
)
if pc.returncode == 0:
for tok in pc.stdout.strip().split():
if tok.startswith(("-I", "-D")):
cflags.append(tok)
else:
ldflags.append(tok)
except Exception:
pass
cmd = [
"gcc",
"-shared",
"-fPIC",
"-O2",
*cflags,
"-o",
so_path,
c_path,
"-lucp",
"-lucs",
*ldflags,
]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
if result.returncode != 0:
logger.warning("Compile failed: %s", result.stderr)
return None
return so_path
def _has_ucp() -> bool:
    """Report whether ``ctypes.util.find_library`` can locate the ucp library."""
    located = ctypes.util.find_library("ucp")
    return located is not None
@@ -0,0 +1,251 @@
/*
* ucx_dp_sync.c — One-shot AllReduce for vLLM DP metadata over UCX/RDMA
*
* SPDX-License-Identifier: Apache-2.0
* SPDX-FileCopyrightText: Copyright contributors to the vLLM project
*
* Replaces Gloo TCP AllReduce (~100ms P99) with UCX tag-matching over
* InfiniBand RDMA (~0.1ms P99) for the per-iteration DP metadata sync.
*
* Build:
* gcc -shared -fPIC -O2 -o _ucx_dp_sync.so ucx_dp_sync.c -lucp -lucs
*
* The allreduce is one-shot: every rank sends its full tensor to every
* other rank, then locally sums. For 256 bytes x 15 peers this is 3.8KB
* total — well within RDMA eager threshold, so all sends complete in a
* single round with no rendezvous handshake.
*/
#include <ucp/api/ucp.h>
#include <ucs/type/status.h>
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
#include <stdio.h>
#define MAX_RANKS 256
/*
 * Per-process communicator state. Allocated by ucx_dp_init(), handed to
 * callers as an opaque void *, and released by ucx_dp_finalize().
 */
typedef struct {
ucp_context_h ctx; /* UCP context created by ucp_init() */
ucp_worker_h worker; /* worker created on ctx; progressed in allreduce */
ucp_ep_h *eps; /* [world_size], NULL for self; filled by ucx_dp_connect() */
int rank; /* this process's index in [0, world_size) */
int world_size; /* number of participating ranks */
uint64_t round; /* monotonic counter for tag uniqueness; +1 per allreduce */
uint8_t **recv_bufs; /* [world_size] pre-allocated, max_bytes each; NULL for self */
uint8_t *send_staging; /* max_bytes; copy of the caller's buffer used as send source */
size_t max_bytes; /* largest payload ucx_dp_allreduce_inplace() accepts */
} ucx_dp_state_t;
/* ---- request completion ---- */
/* Request initializer registered with ucp_init(): clears the int-sized
 * per-request flag. */
static void req_init(void *request) {
    int *done_flag = (int *)request;
    *done_flag = 0;
}
/* Send-completion callback: sets the per-request flag to 1. The completion
 * status and user data are not inspected here. */
static void send_cb(void *request, ucs_status_t status, void *user_data) {
    int *done_flag = (int *)request;
    (void)status;
    (void)user_data;
    *done_flag = 1;
}
/* Receive-completion callback: sets the per-request flag to 1. The status,
 * tag info and user data are not inspected here. */
static void recv_cb(void *request, ucs_status_t status,
                    const ucp_tag_recv_info_t *info, void *user_data) {
    int *done_flag = (int *)request;
    (void)status;
    (void)info;
    (void)user_data;
    *done_flag = 1;
}
/* ---- public API ---- */
/*
 * Create the communicator state for one rank.
 *
 * rank / world_size  position of this process and total number of ranks;
 *                    world_size must be in [1, MAX_RANKS] and rank in
 *                    [0, world_size).
 * max_bytes          largest payload a later allreduce may carry.
 * state_out          receives the opaque state handle.
 * addr_out / addr_len_out
 *                    receive this worker's UCX address; the caller copies
 *                    it and then calls ucx_dp_release_address().
 *
 * Returns 0 on success, -1 on invalid arguments, UCX failure or allocation
 * failure. On failure nothing is written to the output parameters and all
 * partially created resources are released.
 */
int ucx_dp_init(int rank, int world_size, size_t max_bytes,
                void **state_out, void **addr_out, size_t *addr_len_out) {
    ucs_status_t st;
    ucp_address_t *addr = NULL;
    size_t addr_len = 0;
    /* malloc(0) may legally return NULL; allocate at least one byte so a
     * NULL result always means out-of-memory. */
    size_t alloc_bytes = max_bytes ? max_bytes : 1;

    if (!state_out || !addr_out || !addr_len_out) return -1;
    if (world_size < 1 || world_size > MAX_RANKS) return -1;
    if (rank < 0 || rank >= world_size) return -1;

    ucx_dp_state_t *s = (ucx_dp_state_t *)calloc(1, sizeof(*s));
    if (!s) return -1;
    s->rank = rank;
    s->world_size = world_size;
    s->max_bytes = max_bytes;

    /* context */
    ucp_config_t *config;
    st = ucp_config_read(NULL, NULL, &config);
    if (st != UCS_OK) goto fail_s;

    ucp_params_t params;
    memset(&params, 0, sizeof(params));
    params.field_mask = UCP_PARAM_FIELD_FEATURES
                      | UCP_PARAM_FIELD_REQUEST_SIZE
                      | UCP_PARAM_FIELD_REQUEST_INIT;
    params.features = UCP_FEATURE_TAG;
    params.request_size = sizeof(int);
    params.request_init = req_init;
    st = ucp_init(&params, config, &s->ctx);
    ucp_config_release(config);
    if (st != UCS_OK) goto fail_s;

    /* worker */
    ucp_worker_params_t wp;
    memset(&wp, 0, sizeof(wp));
    wp.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
    wp.thread_mode = UCS_THREAD_MODE_SINGLE;
    st = ucp_worker_create(s->ctx, &wp, &s->worker);
    if (st != UCS_OK) goto fail_ctx;

    /* worker address */
    st = ucp_worker_get_address(s->worker, &addr, &addr_len);
    if (st != UCS_OK) goto fail_worker;

    /* buffers: one endpoint slot and one receive buffer per peer */
    s->eps = (ucp_ep_h *)calloc(world_size, sizeof(ucp_ep_h));
    s->recv_bufs = (uint8_t **)calloc(world_size, sizeof(uint8_t *));
    s->send_staging = (uint8_t *)malloc(alloc_bytes);
    if (!s->eps || !s->recv_bufs || !s->send_staging) goto fail_bufs;
    for (int i = 0; i < world_size; i++) {
        if (i == rank) continue; /* no receive buffer for self */
        s->recv_bufs[i] = (uint8_t *)malloc(alloc_bytes);
        if (!s->recv_bufs[i]) goto fail_bufs;
    }

    *state_out = s;
    *addr_out = addr; /* caller copies, then calls release */
    *addr_len_out = addr_len;
    return 0;

fail_bufs:
    if (s->recv_bufs) {
        for (int i = 0; i < world_size; i++)
            free(s->recv_bufs[i]);
    }
    free(s->recv_bufs);
    free(s->eps);
    free(s->send_staging);
    ucp_worker_release_address(s->worker, addr);
fail_worker:
    ucp_worker_destroy(s->worker);
fail_ctx:
    ucp_cleanup(s->ctx);
fail_s:
    free(s);
    return -1;
}
void ucx_dp_release_address(void *state, void *addr) {
ucx_dp_state_t *s = (ucx_dp_state_t *)state;
ucp_worker_release_address(s->worker, (ucp_address_t *)addr);
}
int ucx_dp_connect(void *state, int peer, const void *addr, size_t len) {
ucx_dp_state_t *s = (ucx_dp_state_t *)state;
ucp_ep_params_t ep;
memset(&ep, 0, sizeof(ep));
ep.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
ep.address = (const ucp_address_t *)addr;
ucs_status_t st = ucp_ep_create(s->worker, &ep, &s->eps[peer]);
return (st == UCS_OK) ? 0 : -1;
}
int ucx_dp_allreduce_inplace(void *state, void *buf, size_t nbytes) {
ucx_dp_state_t *s = (ucx_dp_state_t *)state;
if (nbytes > s->max_bytes) return -1;
uint64_t round = s->round++;
int ws = s->world_size;
memcpy(s->send_staging, buf, nbytes);
ucs_status_ptr_t recv_reqs[MAX_RANKS];
ucs_status_ptr_t send_reqs[MAX_RANKS];
/* post all receives */
for (int i = 0; i < ws; i++) {
if (i == s->rank) { recv_reqs[i] = NULL; continue; }
ucp_request_param_t p;
memset(&p, 0, sizeof(p));
p.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK
| UCP_OP_ATTR_FIELD_FLAGS;
p.cb.recv = recv_cb;
p.flags = UCP_OP_ATTR_FLAG_NO_IMM_CMPL;
ucp_tag_t tag = (round << 16) | (uint64_t)i;
recv_reqs[i] = ucp_tag_recv_nbx(s->worker, s->recv_bufs[i],
nbytes, tag, ~(ucp_tag_t)0, &p);
}
/* post all sends */
ucp_tag_t my_tag = (round << 16) | (uint64_t)s->rank;
for (int i = 0; i < ws; i++) {
if (i == s->rank) { send_reqs[i] = NULL; continue; }
ucp_request_param_t p;
memset(&p, 0, sizeof(p));
p.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK;
p.cb.send = send_cb;
send_reqs[i] = ucp_tag_send_nbx(s->eps[i], s->send_staging,
nbytes, my_tag, &p);
}
/* progress until every request completes */
for (;;) {
ucp_worker_progress(s->worker);
int done = 1;
for (int i = 0; i < ws; i++) {
if (i == s->rank) continue;
if (recv_reqs[i] && !UCS_PTR_IS_ERR(recv_reqs[i])) {
if (ucp_request_check_status(recv_reqs[i]) == UCS_INPROGRESS) {
done = 0; break;
}
}
if (send_reqs[i] && !UCS_PTR_IS_ERR(send_reqs[i])) {
if (ucp_request_check_status(send_reqs[i]) == UCS_INPROGRESS) {
done = 0; break;
}
}
}
if (done) break;
}
/* free requests */
for (int i = 0; i < ws; i++) {
if (i == s->rank) continue;
if (recv_reqs[i] && !UCS_PTR_IS_ERR(recv_reqs[i]))
ucp_request_free(recv_reqs[i]);
if (send_reqs[i] && !UCS_PTR_IS_ERR(send_reqs[i]))
ucp_request_free(send_reqs[i]);
}
/*
* Memory fence: ensure all recv buffer writes from UCX are
* visible before we read them for the reduction. On ARM
* (aarch64) the store from the transport thread and our read
* may not be ordered without this.
*/
__sync_synchronize();
/* local reduce: buf = local + sum(received) */
int32_t *out = (int32_t *)buf;
int32_t *local = (int32_t *)s->send_staging;
int count = (int)(nbytes / sizeof(int32_t));
memcpy(out, local, nbytes);
for (int i = 0; i < ws; i++) {
if (i == s->rank) continue;
int32_t *peer = (int32_t *)s->recv_bufs[i];
for (int j = 0; j < count; j++)
out[j] += peer[j];
}
return 0;
}
/*
 * Release everything owned by the communicator state. Accepts NULL and
 * tolerates a state whose worker, context or buffers were never created.
 */
void ucx_dp_finalize(void *state) {
    ucx_dp_state_t *s = (ucx_dp_state_t *)state;
    if (!s) return;

    if (s->worker) {
        /* flush pending endpoint ops: a bounded number of progress calls */
        for (int i = 0; i < 64; i++)
            ucp_worker_progress(s->worker);
        ucp_worker_destroy(s->worker);
    }
    if (s->ctx) ucp_cleanup(s->ctx);

    if (s->recv_bufs) {
        for (int i = 0; i < s->world_size; i++)
            free(s->recv_bufs[i]); /* free(NULL) is a no-op (self slot) */
    }
    free(s->eps);
    free(s->recv_bufs);
    free(s->send_staging);
    free(s);
}
+73 -3
View File
@@ -1,6 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
import time
import torch
import torch.distributed as dist
@@ -14,6 +17,53 @@ from vllm.v1.worker.ubatch_utils import (
logger = init_logger(__name__)
_dp_sync_stats_lock = threading.Lock()
_dp_sync_stats: list[float] = []
_ucx_init_attempted = False
def _maybe_init_ucx(parallel_config: ParallelConfig) -> None:
    """Attempt UCX DP communicator setup once per process.

    Only the first call does any work; the attempt flag is set before the
    setup is tried, so a failure is never retried from this module. Setup
    is attempted only when ``VLLM_DP_SYNC_BACKEND`` equals ``ucx``
    (case-insensitive). Any exception is logged and swallowed so callers
    continue with Gloo.
    """
    global _ucx_init_attempted
    if not _ucx_init_attempted:
        _ucx_init_attempted = True

        import os

        backend = os.environ.get("VLLM_DP_SYNC_BACKEND", "")
        if backend.lower() == "ucx":
            try:
                from vllm.distributed.device_communicators.ucx_dp_communicator import (
                    try_init_ucx_dp,
                )

                cpu_group = get_dp_group().cpu_group
                try_init_ucx_dp(
                    rank=parallel_config.data_parallel_rank,
                    world_size=parallel_config.data_parallel_size,
                    gloo_group=cpu_group,
                    max_msg_bytes=1024,
                )
            except Exception:
                logger.warning("UCX DP init failed, using Gloo", exc_info=True)
def get_dp_sync_stats() -> list[float] | None:
    """Drain the recorded DP sync latencies.

    Returns a copy of the recorded values and empties the shared list, or
    ``None`` when nothing has been recorded. The copy and the clearing
    happen under the stats lock.
    """
    with _dp_sync_stats_lock:
        if len(_dp_sync_stats) == 0:
            return None
        snapshot = _dp_sync_stats[:]
        del _dp_sync_stats[:]
        return snapshot
def _get_device_and_group(parallel_config: ParallelConfig):
# Use the actual device assigned to the DP group, not just the device type
@@ -40,17 +90,37 @@ def _run_ar(
cudagraph_mode: int,
parallel_config: ParallelConfig,
) -> torch.Tensor:
_maybe_init_ucx(parallel_config)
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
device, group = _get_device_and_group(parallel_config)
# Populate this rank's contribution on CPU to reduce GPU syncs.
tensor_cpu = torch.zeros(4, dp_size, dtype=torch.int32)
tensor_cpu[0][dp_rank] = orig_num_tokens_per_ubatch
tensor_cpu[1][dp_rank] = padded_num_tokens_per_ubatch
tensor_cpu[2][dp_rank] = 1 if should_ubatch else 0
tensor_cpu[3][dp_rank] = cudagraph_mode
tensor = tensor_cpu.to(device, non_blocking=True)
dist.all_reduce(tensor, group=group)
t0 = time.monotonic()
from vllm.distributed.device_communicators.ucx_dp_communicator import (
get_ucx_dp_communicator,
)
ucx = get_ucx_dp_communicator()
if ucx is not None:
ucx.allreduce_inplace(tensor_cpu)
tensor = tensor_cpu
else:
device, group = _get_device_and_group(parallel_config)
tensor = tensor_cpu.to(device, non_blocking=True)
dist.all_reduce(tensor, group=group)
elapsed = time.monotonic() - t0
with _dp_sync_stats_lock:
_dp_sync_stats.append(elapsed)
return tensor
+59 -2
View File
@@ -2,16 +2,55 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import os
import time
import torch
import torch.distributed as dist
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import init_logger
from vllm.v1.worker.dp_utils import _dp_sync_stats, _dp_sync_stats_lock
from vllm.v1.worker.gpu.cudagraph_utils import (
BatchExecutionDescriptor,
CudaGraphManager,
)
logger = init_logger(__name__)
_ucx_init_attempted = False
def _maybe_init_ucx(dp_rank: int, dp_size: int) -> None:
    """Attempt UCX DP communicator setup once per process.

    Only the first call does any work; the attempt flag is set before the
    setup is tried, so a failure is never retried from this module. Setup
    is attempted only when ``VLLM_DP_SYNC_BACKEND`` equals ``ucx``
    (case-insensitive). Any exception is logged and swallowed so callers
    continue with Gloo.
    """
    global _ucx_init_attempted
    if not _ucx_init_attempted:
        _ucx_init_attempted = True

        backend = os.environ.get("VLLM_DP_SYNC_BACKEND", "")
        if backend.lower() == "ucx":
            try:
                from vllm.distributed.device_communicators.ucx_dp_communicator import (
                    try_init_ucx_dp,
                )

                cpu_group = get_dp_group().cpu_group
                try_init_ucx_dp(
                    rank=dp_rank,
                    world_size=dp_size,
                    gloo_group=cpu_group,
                    max_msg_bytes=1024,
                )
            except Exception:
                logger.warning("UCX DP init failed, using Gloo", exc_info=True)
def sync_cudagraph_and_dp_padding(
cudagraph_manager: CudaGraphManager | None,
@@ -28,12 +67,30 @@ def sync_cudagraph_and_dp_padding(
Returns (synced_batch_desc, num_tokens_across_dp).
"""
assert dp_size > 1, "DP size must be greater than 1"
group = get_dp_group().cpu_group
_maybe_init_ucx(dp_rank, dp_size)
tensor = torch.zeros(3, dp_size, dtype=torch.int32, device="cpu")
tensor[0][dp_rank] = num_tokens
tensor[1][dp_rank] = desired_batch_desc.cg_mode.value
tensor[2][dp_rank] = uniform_token_count or 0 # (0 means None)
dist.all_reduce(tensor, group=group)
t0 = time.monotonic()
from vllm.distributed.device_communicators.ucx_dp_communicator import (
get_ucx_dp_communicator,
)
ucx = get_ucx_dp_communicator()
if ucx is not None:
ucx.allreduce_inplace(tensor)
else:
group = get_dp_group().cpu_group
dist.all_reduce(tensor, group=group)
elapsed = time.monotonic() - t0
with _dp_sync_stats_lock:
_dp_sync_stats.append(elapsed)
num_tokens_across_dp = tensor[0]
cg_mode_across_dp = tensor[1]