mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
use split_group for pytorch process group creation (#41980)
Signed-off-by: Tushar Jain <tushar00jain@users.noreply.github.com> Co-authored-by: Tushar Jain <tushar00jain@users.noreply.github.com>
This commit is contained in:
@@ -15,6 +15,7 @@ import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
@@ -379,7 +380,13 @@ def _distributed_packed_a2a_worker(env: dict[str, str]) -> None:
|
||||
update_environment_variables(env)
|
||||
local_rank = int(env["LOCAL_RANK"])
|
||||
torch.accelerator.set_device_index(local_rank)
|
||||
dist.init_process_group(backend="nccl")
|
||||
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
|
||||
dist.init_process_group(
|
||||
backend="cpu:gloo,cuda:nccl",
|
||||
device_id=torch.device(f"cuda:{local_rank}"),
|
||||
)
|
||||
else:
|
||||
dist.init_process_group(backend="nccl")
|
||||
use_workspace = env.get("USE_WORKSPACE") == "1"
|
||||
if use_workspace:
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
|
||||
@@ -9,6 +9,7 @@ import pytest
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
import vllm.envs as envs
|
||||
from tests.utils import ensure_current_vllm_config
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
@@ -82,11 +83,18 @@ def test_pynccl():
|
||||
@worker_fn_wrapper
|
||||
def multiple_allreduce_worker_fn():
|
||||
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
|
||||
groups = [
|
||||
torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
|
||||
torch.distributed.new_group(ranks=[2, 3], backend="gloo"),
|
||||
]
|
||||
group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
|
||||
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
|
||||
# Eager-init path: parent PG has bound_device_id + a CPU backend,
|
||||
# so split_group is supported.
|
||||
group = torch.distributed.split_group(
|
||||
split_ranks=[[0, 1], [2, 3]], backend="cpu:gloo,cuda:nccl"
|
||||
)
|
||||
else:
|
||||
groups = [
|
||||
torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
|
||||
torch.distributed.new_group(ranks=[2, 3], backend="gloo"),
|
||||
]
|
||||
group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
|
||||
pynccl_comm = PyNcclCommunicator(group=group, device=device)
|
||||
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
|
||||
# two groups can communicate independently
|
||||
@@ -339,11 +347,16 @@ def test_pynccl_send_recv():
|
||||
@worker_fn_wrapper
|
||||
def multiple_send_recv_worker_fn():
|
||||
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
|
||||
groups = [
|
||||
torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
|
||||
torch.distributed.new_group(ranks=[1, 3], backend="gloo"),
|
||||
]
|
||||
group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1]
|
||||
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
|
||||
group = torch.distributed.split_group(
|
||||
split_ranks=[[0, 2], [1, 3]], backend="cpu:gloo,cuda:nccl"
|
||||
)
|
||||
else:
|
||||
groups = [
|
||||
torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
|
||||
torch.distributed.new_group(ranks=[1, 3], backend="gloo"),
|
||||
]
|
||||
group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1]
|
||||
pynccl_comm = PyNcclCommunicator(group=group, device=device)
|
||||
if torch.distributed.get_rank() == 0:
|
||||
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
|
||||
|
||||
@@ -9,6 +9,7 @@ import ray
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
|
||||
from vllm.distributed.device_communicators.quick_all_reduce import (
|
||||
@@ -397,13 +398,27 @@ def qr_variable_input(rank, world_size):
|
||||
ranks = []
|
||||
for i in range(world_size):
|
||||
ranks.append(i)
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
init_method="tcp://127.0.0.1:29500",
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
cpu_group = torch.distributed.new_group(ranks, backend="nccl")
|
||||
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
|
||||
dist.init_process_group(
|
||||
backend="cpu:gloo,cuda:nccl",
|
||||
init_method="tcp://127.0.0.1:29500",
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
device_id=device,
|
||||
)
|
||||
else:
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
init_method="tcp://127.0.0.1:29500",
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
|
||||
cpu_group = torch.distributed.split_group(
|
||||
split_ranks=[ranks], backend="cpu:gloo,cuda:nccl"
|
||||
)
|
||||
else:
|
||||
cpu_group = torch.distributed.new_group(ranks, backend="nccl")
|
||||
|
||||
handle = ops.qr_get_handle(_ptr)
|
||||
world_size = dist.get_world_size(group=cpu_group)
|
||||
|
||||
@@ -0,0 +1,233 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for split_group in GroupCoordinator.
|
||||
|
||||
These tests verify that:
|
||||
1. split_group is used for both device and CPU group creation.
|
||||
2. Multiple subgroups work correctly with split_group.
|
||||
3. Both GPU and CPU all-reduce work on split groups.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import multiprocess as mp
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.parallel_state import (
|
||||
GroupCoordinator,
|
||||
init_distributed_environment,
|
||||
)
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
|
||||
# The whole module exercises the split_group code path, which is opt-in
|
||||
# behind VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1.
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP,
|
||||
reason=("VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1 not set; split_group path is opt-in."),
|
||||
)
|
||||
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
|
||||
def distributed_run(fn, world_size):
|
||||
number_of_processes = world_size
|
||||
processes: list[mp.Process] = []
|
||||
for i in range(number_of_processes):
|
||||
env: dict[str, str] = {}
|
||||
env["RANK"] = str(i)
|
||||
env["LOCAL_RANK"] = str(i)
|
||||
env["WORLD_SIZE"] = str(number_of_processes)
|
||||
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
|
||||
env["MASTER_ADDR"] = "localhost"
|
||||
env["MASTER_PORT"] = "12346"
|
||||
# propagate the opt-in flag to the spawned child workers
|
||||
env["VLLM_DISTRIBUTED_USE_SPLIT_GROUP"] = "1"
|
||||
p = mp.Process(target=fn, args=(env,))
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
for p in processes:
|
||||
assert p.exitcode == 0
|
||||
|
||||
|
||||
def worker_fn_wrapper(fn):
|
||||
def wrapped_fn(env):
|
||||
update_environment_variables(env)
|
||||
local_rank = os.environ["LOCAL_RANK"]
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
torch.accelerator.set_device_index(device)
|
||||
init_distributed_environment()
|
||||
fn()
|
||||
|
||||
return wrapped_fn
|
||||
|
||||
|
||||
def _verify_device_group(coordinator: GroupCoordinator):
|
||||
"""Verify device group works via all-reduce."""
|
||||
local_rank = torch.distributed.get_rank()
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
tensor = torch.ones(16, 16, dtype=torch.float32, device=device)
|
||||
torch.distributed.all_reduce(tensor, group=coordinator.device_group)
|
||||
torch.accelerator.synchronize()
|
||||
expected = coordinator.world_size
|
||||
assert torch.all(tensor == expected).cpu().item(), (
|
||||
f"Device group all-reduce failed: expected {expected}, "
|
||||
f"got {tensor.flatten()[0].item()}"
|
||||
)
|
||||
|
||||
|
||||
def _verify_cpu_group(coordinator: GroupCoordinator):
|
||||
"""Verify CPU group works via all-reduce."""
|
||||
tensor = torch.ones(16, dtype=torch.float32)
|
||||
torch.distributed.all_reduce(tensor, group=coordinator.cpu_group)
|
||||
expected = coordinator.world_size
|
||||
assert torch.all(tensor == expected).cpu().item(), (
|
||||
f"CPU group all-reduce failed: expected {expected}, "
|
||||
f"got {tensor.flatten()[0].item()}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 1: Basic split_group path with 2 GPUs
|
||||
# ---------------------------------------------------------------------------
|
||||
@worker_fn_wrapper
|
||||
def split_group_basic_worker():
|
||||
rank = torch.distributed.get_rank()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
group_ranks = [list(range(world_size))]
|
||||
|
||||
coordinator = GroupCoordinator(
|
||||
group_ranks=group_ranks,
|
||||
local_rank=rank,
|
||||
torch_distributed_backend="nccl",
|
||||
use_device_communicator=False,
|
||||
group_name="test_split_basic",
|
||||
)
|
||||
|
||||
_verify_device_group(coordinator)
|
||||
_verify_cpu_group(coordinator)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
torch.accelerator.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.",
|
||||
)
|
||||
def test_split_group_basic():
|
||||
"""Test basic GroupCoordinator creation with split_group."""
|
||||
distributed_run(split_group_basic_worker, 2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 2: Multiple subgroups with split_group (4 GPUs)
|
||||
# ---------------------------------------------------------------------------
|
||||
@worker_fn_wrapper
|
||||
def split_group_multiple_subgroups_worker():
|
||||
rank = torch.distributed.get_rank()
|
||||
group_ranks = [[0, 1], [2, 3]]
|
||||
|
||||
coordinator = GroupCoordinator(
|
||||
group_ranks=group_ranks,
|
||||
local_rank=rank,
|
||||
torch_distributed_backend="nccl",
|
||||
use_device_communicator=False,
|
||||
group_name="test_split_multi",
|
||||
)
|
||||
|
||||
assert coordinator.world_size == 2
|
||||
|
||||
_verify_device_group(coordinator)
|
||||
_verify_cpu_group(coordinator)
|
||||
|
||||
if rank in [0, 1]:
|
||||
assert coordinator.ranks == [0, 1]
|
||||
else:
|
||||
assert coordinator.ranks == [2, 3]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
torch.accelerator.device_count() < 4,
|
||||
reason="Need at least 4 GPUs to run the test.",
|
||||
)
|
||||
def test_split_group_multiple_subgroups():
|
||||
"""Test GroupCoordinator with multiple independent subgroups."""
|
||||
distributed_run(split_group_multiple_subgroups_worker, 4)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 3: split_group contract — every parent rank must enter with the same
|
||||
# ``split_ranks``. NCCL happens to produce
|
||||
# correct subgroups for disjoint partitions because the wrapper hashes
|
||||
# ``my_group`` to derive the comm-split color, but the contract violation is
|
||||
# real and would break under non-partition / non-NCCL backends. This test
|
||||
# captures the actual ``split_ranks`` argument passed on every rank and
|
||||
# asserts they match.
|
||||
# ---------------------------------------------------------------------------
|
||||
@worker_fn_wrapper
|
||||
def split_group_contract_worker():
|
||||
rank = torch.distributed.get_rank()
|
||||
group_ranks = [[0, 1], [2, 3]]
|
||||
|
||||
captured: list[list[list[int]]] = []
|
||||
original_split_group = torch.distributed.split_group
|
||||
|
||||
def capturing_split_group(*args, split_ranks=None, **kwargs):
|
||||
captured.append([list(g) for g in split_ranks])
|
||||
return original_split_group(*args, split_ranks=split_ranks, **kwargs)
|
||||
|
||||
torch.distributed.split_group = capturing_split_group
|
||||
try:
|
||||
GroupCoordinator(
|
||||
group_ranks=group_ranks,
|
||||
local_rank=rank,
|
||||
torch_distributed_backend="nccl",
|
||||
use_device_communicator=False,
|
||||
group_name="test_split_contract",
|
||||
)
|
||||
finally:
|
||||
torch.distributed.split_group = original_split_group
|
||||
|
||||
# GroupCoordinator builds two subgroups (device + cpu) per coordinator,
|
||||
# so every rank must have made exactly two split_group calls.
|
||||
if len(captured) != 2:
|
||||
raise AssertionError(
|
||||
f"rank {rank} expected 2 split_group calls (device + cpu), "
|
||||
f"got {len(captured)}: {captured}"
|
||||
)
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
for call_idx in range(2):
|
||||
gathered: list[Any] = [None] * world_size
|
||||
torch.distributed.all_gather_object(gathered, captured[call_idx])
|
||||
# Normalize for stable comparison (sort each subgroup and the outer list).
|
||||
norm = [
|
||||
sorted([sorted(sg) for sg in per_rank_args]) for per_rank_args in gathered
|
||||
]
|
||||
reference = norm[0]
|
||||
for r, args in enumerate(norm):
|
||||
if args != reference:
|
||||
raise AssertionError(
|
||||
f"split_group contract violation on call #{call_idx}: "
|
||||
f"rank {r} passed split_ranks={gathered[r]}, but rank 0 "
|
||||
f"passed split_ranks={gathered[0]}. PyTorch requires every "
|
||||
"parent rank to enter split_group with the same split_ranks."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
torch.accelerator.device_count() < 4,
|
||||
reason="Need at least 4 GPUs to run the test.",
|
||||
)
|
||||
def test_split_group_contract_same_split_ranks_on_all_ranks():
|
||||
"""All parent ranks must call torch.distributed.split_group with the same
|
||||
``split_ranks`` argument. This catches the bug where each rank passed
|
||||
only its own subgroup (``split_ranks=[ranks]``), which NCCL forgives for
|
||||
disjoint partitions but is a documented contract violation.
|
||||
"""
|
||||
distributed_run(split_group_contract_worker, 4)
|
||||
@@ -5,13 +5,26 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
|
||||
# Let PyTorch choose the WORLD backend for the current device type.
|
||||
dist.init_process_group()
|
||||
# By default, let PyTorch choose the WORLD backend for the current device
|
||||
# type (legacy lazy-init path). When VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1,
|
||||
# use the explicit eager-init pattern required by `split_group` (mixed
|
||||
# cpu:gloo,cuda:nccl backend + device_id binding).
|
||||
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
torch.accelerator.set_device_index(local_rank)
|
||||
dist.init_process_group(
|
||||
backend="cpu:gloo,cuda:nccl",
|
||||
device_id=torch.device(f"cuda:{local_rank}"),
|
||||
)
|
||||
else:
|
||||
dist.init_process_group()
|
||||
|
||||
# Create prompts
|
||||
prompts = [
|
||||
|
||||
@@ -5,13 +5,26 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.distributed.parallel_state import get_tp_group, get_world_group
|
||||
|
||||
# Let PyTorch choose the WORLD backend for the current device type.
|
||||
dist.init_process_group()
|
||||
# By default, let PyTorch choose the WORLD backend for the current device
|
||||
# type (legacy lazy-init path). When VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1,
|
||||
# use the explicit eager-init pattern required by `split_group` (mixed
|
||||
# cpu:gloo,cuda:nccl backend + device_id binding).
|
||||
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
torch.accelerator.set_device_index(local_rank)
|
||||
dist.init_process_group(
|
||||
backend="cpu:gloo,cuda:nccl",
|
||||
device_id=torch.device(f"cuda:{local_rank}"),
|
||||
)
|
||||
else:
|
||||
dist.init_process_group()
|
||||
|
||||
# Create prompts
|
||||
prompts = [
|
||||
|
||||
@@ -10,6 +10,7 @@ import torch
|
||||
from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage]
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed import (
|
||||
cleanup_dist_env_and_memory,
|
||||
@@ -60,7 +61,15 @@ def _set_vllm_config(
|
||||
tensor_model_parallel_size=vllm_config.parallel_config.tensor_parallel_size,
|
||||
pipeline_model_parallel_size=vllm_config.parallel_config.pipeline_parallel_size,
|
||||
)
|
||||
cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo")
|
||||
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
|
||||
cpu_group = torch.distributed.split_group(
|
||||
split_ranks=[list(range(world_size))],
|
||||
group_desc="moe_test_cpu",
|
||||
)
|
||||
else:
|
||||
cpu_group = torch.distributed.new_group(
|
||||
list(range(world_size)), backend="gloo"
|
||||
)
|
||||
return cpu_group
|
||||
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import torch.distributed
|
||||
from torch.distributed import ProcessGroup
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
@@ -375,7 +376,13 @@ def _test_deepep_deepgemm_moe(
|
||||
w1_scale = w1_scale.to(device=device)
|
||||
w2_scale = w2_scale.to(device=device)
|
||||
|
||||
pg = torch.distributed.new_group(list(range(pgi.world_size)))
|
||||
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
|
||||
pg = torch.distributed.split_group(
|
||||
split_ranks=[list(range(pgi.world_size))],
|
||||
group_desc="deepep_deepgemm_test",
|
||||
)
|
||||
else:
|
||||
pg = torch.distributed.new_group(list(range(pgi.world_size)))
|
||||
test_tensors = TestTensors.make(config, pgi.rank)
|
||||
block_shape = [w1.size(1) // w1_scale.size(1), w1.size(2) // w1_scale.size(2)]
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import pytest
|
||||
import torch.distributed
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
import vllm.envs as envs
|
||||
from tests.kernels.moe.utils import make_dummy_moe_config
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
@@ -375,7 +376,13 @@ def _deep_ep_moe(
|
||||
w1_scale = w1_scale.to(device=device_idx)
|
||||
w2_scale = w2_scale.to(device=device_idx)
|
||||
|
||||
pg = torch.distributed.new_group(list(range(pgi.world_size)))
|
||||
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
|
||||
pg = torch.distributed.split_group(
|
||||
split_ranks=[list(range(pgi.world_size))],
|
||||
group_desc="deepep_test",
|
||||
)
|
||||
else:
|
||||
pg = torch.distributed.new_group(list(range(pgi.world_size)))
|
||||
test_tensors = TestTensors.make(config, low_latency_mode)
|
||||
|
||||
with set_current_vllm_config(VllmConfig()):
|
||||
|
||||
@@ -227,6 +227,67 @@ def patched_fused_scaled_matmul_reduce_scatter_fake(
|
||||
return res
|
||||
|
||||
|
||||
def _platform_device_type() -> str:
|
||||
"""Return the device-type string (e.g. ``"cuda"``, ``"xpu"``, ``"cpu"``)
|
||||
for the current platform, in the form expected by
|
||||
``torch.distributed.init_process_group(backend=...)``.
|
||||
"""
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
return "cuda"
|
||||
elif current_platform.is_xpu():
|
||||
return "xpu"
|
||||
elif current_platform.is_out_of_tree():
|
||||
return current_platform.device_name
|
||||
else:
|
||||
return "cpu"
|
||||
|
||||
|
||||
def _device_backend_str(torch_distributed_backend: str | Backend) -> str:
|
||||
"""Normalize ``torch_distributed_backend`` to the ``"<device>:<backend>"``
|
||||
format required by ``split_group``'s ``backend`` argument.
|
||||
|
||||
Accepts either a bare backend name (e.g. ``"nccl"``) or an already-prefixed
|
||||
string (e.g. ``"cuda:nccl"``).
|
||||
"""
|
||||
backend_str = str(torch_distributed_backend)
|
||||
if ":" in backend_str:
|
||||
return backend_str
|
||||
return f"{_platform_device_type()}:{backend_str}"
|
||||
|
||||
|
||||
def _create_subgroups_split_group(
|
||||
group_ranks: list[list[int]],
|
||||
group_name: str,
|
||||
torch_distributed_backend: str | Backend,
|
||||
) -> tuple[ProcessGroup, ProcessGroup]:
|
||||
"""Create the device + CPU subgroups for ``GroupCoordinator`` via
|
||||
``torch.distributed.split_group``.
|
||||
|
||||
``split_group`` is collective on the parent group, so every parent rank
|
||||
must enter with the same ``split_ranks`` definition. Each rank receives
|
||||
the subgroup it belongs to.
|
||||
"""
|
||||
device_backend_str = _device_backend_str(torch_distributed_backend)
|
||||
self_device_group = torch.distributed.split_group(
|
||||
split_ranks=group_ranks,
|
||||
group_desc=f"{group_name}:device",
|
||||
backend=device_backend_str,
|
||||
)
|
||||
# CPU subgroup: split_group requires the requested backend filter to
|
||||
# include the parent's default device type (= the device the parent PG
|
||||
# was bound to via ``device_id``), so a cpu-only filter is rejected.
|
||||
# Include the device backend in the filter; only the gloo backend is
|
||||
# actually used for CPU collectives on this group.
|
||||
self_cpu_group = torch.distributed.split_group(
|
||||
split_ranks=group_ranks,
|
||||
group_desc=f"{group_name}:cpu",
|
||||
backend=f"cpu:gloo,{device_backend_str}",
|
||||
)
|
||||
return self_device_group, self_cpu_group
|
||||
|
||||
|
||||
def patched_fused_scaled_matmul_reduce_scatter(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
@@ -335,26 +396,39 @@ class GroupCoordinator:
|
||||
self_device_group = None
|
||||
self_cpu_group = None
|
||||
|
||||
from vllm.distributed.utils import get_cpu_distributed_timeout_or_none
|
||||
|
||||
timeout = get_cpu_distributed_timeout_or_none()
|
||||
|
||||
for ranks in group_ranks:
|
||||
device_group = torch.distributed.new_group(
|
||||
ranks, backend=torch_distributed_backend
|
||||
# VLLM_DISTRIBUTED_USE_SPLIT_GROUP gates the new ``split_group``
|
||||
# codepath. Default (False) preserves the legacy ``new_group`` path.
|
||||
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
|
||||
self_device_group, self_cpu_group = _create_subgroups_split_group(
|
||||
group_ranks, group_name, torch_distributed_backend
|
||||
)
|
||||
# a group with `gloo` backend, to allow direct coordination between
|
||||
# processes through the CPU.
|
||||
with suppress_stdout():
|
||||
cpu_group = torch.distributed.new_group(
|
||||
ranks, backend="gloo", timeout=timeout
|
||||
for ranks in group_ranks:
|
||||
if self.rank in ranks:
|
||||
self.ranks = ranks
|
||||
self.world_size = len(ranks)
|
||||
self.rank_in_group = ranks.index(self.rank)
|
||||
break
|
||||
else:
|
||||
from vllm.distributed.utils import get_cpu_distributed_timeout_or_none
|
||||
|
||||
timeout = get_cpu_distributed_timeout_or_none()
|
||||
|
||||
for ranks in group_ranks:
|
||||
device_group = torch.distributed.new_group(
|
||||
ranks, backend=torch_distributed_backend
|
||||
)
|
||||
if self.rank in ranks:
|
||||
self.ranks = ranks
|
||||
self.world_size = len(ranks)
|
||||
self.rank_in_group = ranks.index(self.rank)
|
||||
self_device_group = device_group
|
||||
self_cpu_group = cpu_group
|
||||
# a group with `gloo` backend, to allow direct coordination between
|
||||
# processes through the CPU.
|
||||
with suppress_stdout():
|
||||
cpu_group = torch.distributed.new_group(
|
||||
ranks, backend="gloo", timeout=timeout
|
||||
)
|
||||
if self.rank in ranks:
|
||||
self.ranks = ranks
|
||||
self.world_size = len(ranks)
|
||||
self.rank_in_group = ranks.index(self.rank)
|
||||
self_device_group = device_group
|
||||
self_cpu_group = cpu_group
|
||||
|
||||
assert self_cpu_group is not None
|
||||
assert self_device_group is not None
|
||||
@@ -1348,6 +1422,62 @@ def set_custom_all_reduce(enable: bool):
|
||||
_ENABLE_CUSTOM_ALL_REDUCE = enable
|
||||
|
||||
|
||||
def _init_process_group_for_split_group(
|
||||
*,
|
||||
backend: str,
|
||||
distributed_init_method: str,
|
||||
world_size: int,
|
||||
rank: int,
|
||||
local_rank: int,
|
||||
timeout: timedelta | None,
|
||||
) -> None:
|
||||
"""Initialize the default PG with both CPU (gloo) and device (e.g. nccl)
|
||||
backends and an eager ``device_id`` binding so that subgroups can be
|
||||
created via ``split_group`` (which requires the parent communicator to
|
||||
be eagerly initialized). Falls back to ``gloo`` on CPU-only systems.
|
||||
"""
|
||||
if torch.accelerator.is_available() and backend != "gloo":
|
||||
init_backend = "cpu:gloo,cuda:nccl"
|
||||
device_id: torch.device | None = torch.device(f"cuda:{local_rank}")
|
||||
else:
|
||||
init_backend = "gloo"
|
||||
device_id = None
|
||||
torch.distributed.init_process_group(
|
||||
backend=init_backend,
|
||||
init_method=distributed_init_method,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
timeout=timeout,
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
|
||||
def _validate_default_pg_for_split_group() -> None:
|
||||
"""When an external launcher (e.g. ``torchrun``) initialized the default
|
||||
PG, ``GroupCoordinator`` cannot patch in additional backends or change
|
||||
the eager-init behavior — ``split_group`` only selects subsets of an
|
||||
existing parent. Validate that the parent has both ``device_id`` and a
|
||||
CPU (gloo) backend, and emit a descriptive error pointing at the exact
|
||||
init call to update otherwise.
|
||||
"""
|
||||
default_pg = torch.distributed.distributed_c10d._get_default_group()
|
||||
assert default_pg.bound_device_id is not None, (
|
||||
"External launcher initialized the default process group "
|
||||
"without device_id. vLLM requires the default PG to be device-"
|
||||
"bound for split_group. Pass device_id=torch.device(f'cuda:"
|
||||
"{local_rank}') to torch.distributed.init_process_group()."
|
||||
)
|
||||
try:
|
||||
default_pg._get_backend(torch.device("cpu"))
|
||||
except RuntimeError as e:
|
||||
raise RuntimeError(
|
||||
"External launcher initialized the default process group "
|
||||
"without a CPU (gloo) backend. vLLM requires both CPU and "
|
||||
"device backends. Pass backend='cpu:gloo,cuda:nccl' to "
|
||||
"torch.distributed.init_process_group()."
|
||||
) from e
|
||||
|
||||
|
||||
def _init_elastic_ep_world(
|
||||
config, local_rank: int, backend: str, rank: int, world_size: int
|
||||
) -> None:
|
||||
@@ -1456,14 +1586,33 @@ def init_distributed_environment(
|
||||
"Fallback Gloo backend is not available."
|
||||
)
|
||||
backend = "gloo"
|
||||
# this backend is used for WORLD
|
||||
torch.distributed.init_process_group(
|
||||
backend=backend,
|
||||
init_method=distributed_init_method,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
timeout=timeout,
|
||||
)
|
||||
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
|
||||
# split_group needs local_rank early to compute device_id for
|
||||
# the eager init. local_rank is not available in torch
|
||||
# ProcessGroup, see https://github.com/pytorch/pytorch/issues/122816
|
||||
if local_rank == -1:
|
||||
local_rank = (
|
||||
int(envs.LOCAL_RANK)
|
||||
if distributed_init_method == "env://"
|
||||
else rank
|
||||
)
|
||||
_init_process_group_for_split_group(
|
||||
backend=backend,
|
||||
distributed_init_method=distributed_init_method,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
local_rank=local_rank,
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
# this backend is used for WORLD
|
||||
torch.distributed.init_process_group(
|
||||
backend=backend,
|
||||
init_method=distributed_init_method,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
timeout=timeout,
|
||||
)
|
||||
if enable_elastic_ep:
|
||||
tp_pp_cpu_group = torch.distributed.new_group(
|
||||
backend="gloo", timeout=timeout
|
||||
@@ -1476,6 +1625,9 @@ def init_distributed_environment(
|
||||
"Elastic EP is not yet supported with multi-node TP/PP"
|
||||
)
|
||||
|
||||
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP and torch.accelerator.is_available():
|
||||
_validate_default_pg_for_split_group()
|
||||
|
||||
# set the local rank
|
||||
# local_rank is not available in torch ProcessGroup,
|
||||
# see https://github.com/pytorch/pytorch/issues/122816
|
||||
|
||||
@@ -63,6 +63,7 @@ if TYPE_CHECKING:
|
||||
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
|
||||
VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True
|
||||
VLLM_USE_RAY_V2_EXECUTOR_BACKEND: bool = False
|
||||
VLLM_DISTRIBUTED_USE_SPLIT_GROUP: bool = False
|
||||
VLLM_XLA_USE_SPMD: bool = False
|
||||
VLLM_WORKER_MULTIPROC_METHOD: Literal["fork", "spawn"] = "fork"
|
||||
VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
|
||||
@@ -877,6 +878,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_USE_RAY_V2_EXECUTOR_BACKEND": lambda: bool(
|
||||
int(os.getenv("VLLM_USE_RAY_V2_EXECUTOR_BACKEND", "1"))
|
||||
),
|
||||
# When True, GroupCoordinator constructs its CPU/device subgroups via
|
||||
# ``torch.distributed.split_group(backend=...)``
|
||||
# and ``init_distributed_environment`` initializes the default PG with
|
||||
# mixed ``cpu:gloo,cuda:nccl`` backend + eager ``device_id`` binding.
|
||||
"VLLM_DISTRIBUTED_USE_SPLIT_GROUP": lambda: bool(
|
||||
int(os.getenv("VLLM_DISTRIBUTED_USE_SPLIT_GROUP", "0"))
|
||||
),
|
||||
# Use dedicated multiprocess context for workers.
|
||||
# Both spawn and fork work
|
||||
"VLLM_WORKER_MULTIPROC_METHOD": env_with_choices(
|
||||
|
||||
Reference in New Issue
Block a user