use split_group for pytorch process group creation (#41980)

Signed-off-by: Tushar Jain <tushar00jain@users.noreply.github.com>
Co-authored-by: Tushar Jain <tushar00jain@users.noreply.github.com>
This commit is contained in:
Tushar Jain
2026-06-04 14:36:07 -04:00
committed by GitHub
parent a947f7a420
commit 38fd2405f3
11 changed files with 528 additions and 51 deletions
+8 -1
View File
@@ -15,6 +15,7 @@ import pytest
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.config.parallel import ParallelConfig
from vllm.utils.network_utils import get_open_port
from vllm.utils.system_utils import update_environment_variables
@@ -379,7 +380,13 @@ def _distributed_packed_a2a_worker(env: dict[str, str]) -> None:
update_environment_variables(env)
local_rank = int(env["LOCAL_RANK"])
torch.accelerator.set_device_index(local_rank)
dist.init_process_group(backend="nccl")
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
dist.init_process_group(
backend="cpu:gloo,cuda:nccl",
device_id=torch.device(f"cuda:{local_rank}"),
)
else:
dist.init_process_group(backend="nccl")
use_workspace = env.get("USE_WORKSPACE") == "1"
if use_workspace:
from vllm.v1.worker.workspace import init_workspace_manager
+23 -10
View File
@@ -9,6 +9,7 @@ import pytest
import torch
import torch.distributed
import vllm.envs as envs
from tests.utils import ensure_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
@@ -82,11 +83,18 @@ def test_pynccl():
@worker_fn_wrapper
def multiple_allreduce_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
groups = [
torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
torch.distributed.new_group(ranks=[2, 3], backend="gloo"),
]
group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
# Eager-init path: parent PG has bound_device_id + a CPU backend,
# so split_group is supported.
group = torch.distributed.split_group(
split_ranks=[[0, 1], [2, 3]], backend="cpu:gloo,cuda:nccl"
)
else:
groups = [
torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
torch.distributed.new_group(ranks=[2, 3], backend="gloo"),
]
group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
pynccl_comm = PyNcclCommunicator(group=group, device=device)
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
# two groups can communicate independently
@@ -339,11 +347,16 @@ def test_pynccl_send_recv():
@worker_fn_wrapper
def multiple_send_recv_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
groups = [
torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
torch.distributed.new_group(ranks=[1, 3], backend="gloo"),
]
group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1]
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
group = torch.distributed.split_group(
split_ranks=[[0, 2], [1, 3]], backend="cpu:gloo,cuda:nccl"
)
else:
groups = [
torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
torch.distributed.new_group(ranks=[1, 3], backend="gloo"),
]
group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1]
pynccl_comm = PyNcclCommunicator(group=group, device=device)
if torch.distributed.get_rank() == 0:
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
+22 -7
View File
@@ -9,6 +9,7 @@ import ray
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
from vllm.distributed.device_communicators.quick_all_reduce import (
@@ -397,13 +398,27 @@ def qr_variable_input(rank, world_size):
ranks = []
for i in range(world_size):
ranks.append(i)
dist.init_process_group(
backend="nccl",
init_method="tcp://127.0.0.1:29500",
rank=rank,
world_size=world_size,
)
cpu_group = torch.distributed.new_group(ranks, backend="nccl")
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
dist.init_process_group(
backend="cpu:gloo,cuda:nccl",
init_method="tcp://127.0.0.1:29500",
rank=rank,
world_size=world_size,
device_id=device,
)
else:
dist.init_process_group(
backend="nccl",
init_method="tcp://127.0.0.1:29500",
rank=rank,
world_size=world_size,
)
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
cpu_group = torch.distributed.split_group(
split_ranks=[ranks], backend="cpu:gloo,cuda:nccl"
)
else:
cpu_group = torch.distributed.new_group(ranks, backend="nccl")
handle = ops.qr_get_handle(_ptr)
world_size = dist.get_world_size(group=cpu_group)
+233
View File
@@ -0,0 +1,233 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for split_group in GroupCoordinator.
These tests verify that:
1. split_group is used for both device and CPU group creation.
2. Multiple subgroups work correctly with split_group.
3. Both GPU and CPU all-reduce work on split groups.
"""
import os
from typing import Any
import multiprocess as mp
import pytest
import torch
import torch.distributed
import vllm.envs as envs
from vllm.distributed.parallel_state import (
GroupCoordinator,
init_distributed_environment,
)
from vllm.utils.system_utils import update_environment_variables
# The whole module exercises the split_group code path, which is opt-in
# behind VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1.
pytestmark = pytest.mark.skipif(
not envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP,
reason=("VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1 not set; split_group path is opt-in."),
)
mp.set_start_method("spawn", force=True)
def distributed_run(fn, world_size):
number_of_processes = world_size
processes: list[mp.Process] = []
for i in range(number_of_processes):
env: dict[str, str] = {}
env["RANK"] = str(i)
env["LOCAL_RANK"] = str(i)
env["WORLD_SIZE"] = str(number_of_processes)
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
env["MASTER_ADDR"] = "localhost"
env["MASTER_PORT"] = "12346"
# propagate the opt-in flag to the spawned child workers
env["VLLM_DISTRIBUTED_USE_SPLIT_GROUP"] = "1"
p = mp.Process(target=fn, args=(env,))
processes.append(p)
p.start()
for p in processes:
p.join()
for p in processes:
assert p.exitcode == 0
def worker_fn_wrapper(fn):
def wrapped_fn(env):
update_environment_variables(env)
local_rank = os.environ["LOCAL_RANK"]
device = torch.device(f"cuda:{local_rank}")
torch.accelerator.set_device_index(device)
init_distributed_environment()
fn()
return wrapped_fn
def _verify_device_group(coordinator: GroupCoordinator):
"""Verify device group works via all-reduce."""
local_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{local_rank}")
tensor = torch.ones(16, 16, dtype=torch.float32, device=device)
torch.distributed.all_reduce(tensor, group=coordinator.device_group)
torch.accelerator.synchronize()
expected = coordinator.world_size
assert torch.all(tensor == expected).cpu().item(), (
f"Device group all-reduce failed: expected {expected}, "
f"got {tensor.flatten()[0].item()}"
)
def _verify_cpu_group(coordinator: GroupCoordinator):
"""Verify CPU group works via all-reduce."""
tensor = torch.ones(16, dtype=torch.float32)
torch.distributed.all_reduce(tensor, group=coordinator.cpu_group)
expected = coordinator.world_size
assert torch.all(tensor == expected).cpu().item(), (
f"CPU group all-reduce failed: expected {expected}, "
f"got {tensor.flatten()[0].item()}"
)
# ---------------------------------------------------------------------------
# Test 1: Basic split_group path with 2 GPUs
# ---------------------------------------------------------------------------
@worker_fn_wrapper
def split_group_basic_worker():
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
group_ranks = [list(range(world_size))]
coordinator = GroupCoordinator(
group_ranks=group_ranks,
local_rank=rank,
torch_distributed_backend="nccl",
use_device_communicator=False,
group_name="test_split_basic",
)
_verify_device_group(coordinator)
_verify_cpu_group(coordinator)
@pytest.mark.skipif(
torch.accelerator.device_count() < 2,
reason="Need at least 2 GPUs to run the test.",
)
def test_split_group_basic():
"""Test basic GroupCoordinator creation with split_group."""
distributed_run(split_group_basic_worker, 2)
# ---------------------------------------------------------------------------
# Test 2: Multiple subgroups with split_group (4 GPUs)
# ---------------------------------------------------------------------------
@worker_fn_wrapper
def split_group_multiple_subgroups_worker():
rank = torch.distributed.get_rank()
group_ranks = [[0, 1], [2, 3]]
coordinator = GroupCoordinator(
group_ranks=group_ranks,
local_rank=rank,
torch_distributed_backend="nccl",
use_device_communicator=False,
group_name="test_split_multi",
)
assert coordinator.world_size == 2
_verify_device_group(coordinator)
_verify_cpu_group(coordinator)
if rank in [0, 1]:
assert coordinator.ranks == [0, 1]
else:
assert coordinator.ranks == [2, 3]
@pytest.mark.skipif(
torch.accelerator.device_count() < 4,
reason="Need at least 4 GPUs to run the test.",
)
def test_split_group_multiple_subgroups():
"""Test GroupCoordinator with multiple independent subgroups."""
distributed_run(split_group_multiple_subgroups_worker, 4)
# ---------------------------------------------------------------------------
# Test 3: split_group contract — every parent rank must enter with the same
# ``split_ranks``. NCCL happens to produce
# correct subgroups for disjoint partitions because the wrapper hashes
# ``my_group`` to derive the comm-split color, but the contract violation is
# real and would break under non-partition / non-NCCL backends. This test
# captures the actual ``split_ranks`` argument passed on every rank and
# asserts they match.
# ---------------------------------------------------------------------------
@worker_fn_wrapper
def split_group_contract_worker():
rank = torch.distributed.get_rank()
group_ranks = [[0, 1], [2, 3]]
captured: list[list[list[int]]] = []
original_split_group = torch.distributed.split_group
def capturing_split_group(*args, split_ranks=None, **kwargs):
captured.append([list(g) for g in split_ranks])
return original_split_group(*args, split_ranks=split_ranks, **kwargs)
torch.distributed.split_group = capturing_split_group
try:
GroupCoordinator(
group_ranks=group_ranks,
local_rank=rank,
torch_distributed_backend="nccl",
use_device_communicator=False,
group_name="test_split_contract",
)
finally:
torch.distributed.split_group = original_split_group
# GroupCoordinator builds two subgroups (device + cpu) per coordinator,
# so every rank must have made exactly two split_group calls.
if len(captured) != 2:
raise AssertionError(
f"rank {rank} expected 2 split_group calls (device + cpu), "
f"got {len(captured)}: {captured}"
)
world_size = torch.distributed.get_world_size()
for call_idx in range(2):
gathered: list[Any] = [None] * world_size
torch.distributed.all_gather_object(gathered, captured[call_idx])
# Normalize for stable comparison (sort each subgroup and the outer list).
norm = [
sorted([sorted(sg) for sg in per_rank_args]) for per_rank_args in gathered
]
reference = norm[0]
for r, args in enumerate(norm):
if args != reference:
raise AssertionError(
f"split_group contract violation on call #{call_idx}: "
f"rank {r} passed split_ranks={gathered[r]}, but rank 0 "
f"passed split_ranks={gathered[0]}. PyTorch requires every "
"parent rank to enter split_group with the same split_ranks."
)
@pytest.mark.skipif(
torch.accelerator.device_count() < 4,
reason="Need at least 4 GPUs to run the test.",
)
def test_split_group_contract_same_split_ranks_on_all_ranks():
"""All parent ranks must call torch.distributed.split_group with the same
``split_ranks`` argument. This catches the bug where each rank passed
only its own subgroup (``split_ranks=[ranks]``), which NCCL forgives for
disjoint partitions but is a documented contract violation.
"""
distributed_run(split_group_contract_worker, 4)
+15 -2
View File
@@ -5,13 +5,26 @@
import os
import random
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import get_world_group
# Let PyTorch choose the WORLD backend for the current device type.
dist.init_process_group()
# By default, let PyTorch choose the WORLD backend for the current device
# type (legacy lazy-init path). When VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1,
# use the explicit eager-init pattern required by `split_group` (mixed
# cpu:gloo,cuda:nccl backend + device_id binding).
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
local_rank = int(os.environ["LOCAL_RANK"])
torch.accelerator.set_device_index(local_rank)
dist.init_process_group(
backend="cpu:gloo,cuda:nccl",
device_id=torch.device(f"cuda:{local_rank}"),
)
else:
dist.init_process_group()
# Create prompts
prompts = [
+15 -2
View File
@@ -5,13 +5,26 @@
import os
import random
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import get_tp_group, get_world_group
# Let PyTorch choose the WORLD backend for the current device type.
dist.init_process_group()
# By default, let PyTorch choose the WORLD backend for the current device
# type (legacy lazy-init path). When VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1,
# use the explicit eager-init pattern required by `split_group` (mixed
# cpu:gloo,cuda:nccl backend + device_id binding).
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
local_rank = int(os.environ["LOCAL_RANK"])
torch.accelerator.set_device_index(local_rank)
dist.init_process_group(
backend="cpu:gloo,cuda:nccl",
device_id=torch.device(f"cuda:{local_rank}"),
)
else:
dist.init_process_group()
# Create prompts
prompts = [
@@ -10,6 +10,7 @@ import torch
from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import ParamSpec
import vllm.envs as envs
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed import (
cleanup_dist_env_and_memory,
@@ -60,7 +61,15 @@ def _set_vllm_config(
tensor_model_parallel_size=vllm_config.parallel_config.tensor_parallel_size,
pipeline_model_parallel_size=vllm_config.parallel_config.pipeline_parallel_size,
)
cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo")
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
cpu_group = torch.distributed.split_group(
split_ranks=[list(range(world_size))],
group_desc="moe_test_cpu",
)
else:
cpu_group = torch.distributed.new_group(
list(range(world_size)), backend="gloo"
)
return cpu_group
@@ -14,6 +14,7 @@ import torch.distributed
from torch.distributed import ProcessGroup
from typing_extensions import ParamSpec
import vllm.envs as envs
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
@@ -375,7 +376,13 @@ def _test_deepep_deepgemm_moe(
w1_scale = w1_scale.to(device=device)
w2_scale = w2_scale.to(device=device)
pg = torch.distributed.new_group(list(range(pgi.world_size)))
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
pg = torch.distributed.split_group(
split_ranks=[list(range(pgi.world_size))],
group_desc="deepep_deepgemm_test",
)
else:
pg = torch.distributed.new_group(list(range(pgi.world_size)))
test_tensors = TestTensors.make(config, pgi.rank)
block_shape = [w1.size(1) // w1_scale.size(1), w1.size(2) // w1_scale.size(2)]
+8 -1
View File
@@ -10,6 +10,7 @@ import pytest
import torch.distributed
from torch.distributed import ProcessGroup
import vllm.envs as envs
from tests.kernels.moe.utils import make_dummy_moe_config
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
@@ -375,7 +376,13 @@ def _deep_ep_moe(
w1_scale = w1_scale.to(device=device_idx)
w2_scale = w2_scale.to(device=device_idx)
pg = torch.distributed.new_group(list(range(pgi.world_size)))
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
pg = torch.distributed.split_group(
split_ranks=[list(range(pgi.world_size))],
group_desc="deepep_test",
)
else:
pg = torch.distributed.new_group(list(range(pgi.world_size)))
test_tensors = TestTensors.make(config, low_latency_mode)
with set_current_vllm_config(VllmConfig()):