use split_group for pytorch process group creation (#41980)

Signed-off-by: Tushar Jain <tushar00jain@users.noreply.github.com>
Co-authored-by: Tushar Jain <tushar00jain@users.noreply.github.com>
This commit is contained in:
Tushar Jain
2026-06-04 14:36:07 -04:00
committed by GitHub
parent a947f7a420
commit 38fd2405f3
11 changed files with 528 additions and 51 deletions
+8 -1
View File
@@ -15,6 +15,7 @@ import pytest
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import vllm.envs as envs
from vllm.config.parallel import ParallelConfig from vllm.config.parallel import ParallelConfig
from vllm.utils.network_utils import get_open_port from vllm.utils.network_utils import get_open_port
from vllm.utils.system_utils import update_environment_variables from vllm.utils.system_utils import update_environment_variables
@@ -379,7 +380,13 @@ def _distributed_packed_a2a_worker(env: dict[str, str]) -> None:
update_environment_variables(env) update_environment_variables(env)
local_rank = int(env["LOCAL_RANK"]) local_rank = int(env["LOCAL_RANK"])
torch.accelerator.set_device_index(local_rank) torch.accelerator.set_device_index(local_rank)
dist.init_process_group(backend="nccl") if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
dist.init_process_group(
backend="cpu:gloo,cuda:nccl",
device_id=torch.device(f"cuda:{local_rank}"),
)
else:
dist.init_process_group(backend="nccl")
use_workspace = env.get("USE_WORKSPACE") == "1" use_workspace = env.get("USE_WORKSPACE") == "1"
if use_workspace: if use_workspace:
from vllm.v1.worker.workspace import init_workspace_manager from vllm.v1.worker.workspace import init_workspace_manager
+23 -10
View File
@@ -9,6 +9,7 @@ import pytest
import torch import torch
import torch.distributed import torch.distributed
import vllm.envs as envs
from tests.utils import ensure_current_vllm_config from tests.utils import ensure_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
@@ -82,11 +83,18 @@ def test_pynccl():
@worker_fn_wrapper @worker_fn_wrapper
def multiple_allreduce_worker_fn(): def multiple_allreduce_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}") device = torch.device(f"cuda:{torch.distributed.get_rank()}")
groups = [ if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
torch.distributed.new_group(ranks=[0, 1], backend="gloo"), # Eager-init path: parent PG has bound_device_id + a CPU backend,
torch.distributed.new_group(ranks=[2, 3], backend="gloo"), # so split_group is supported.
] group = torch.distributed.split_group(
group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1] split_ranks=[[0, 1], [2, 3]], backend="cpu:gloo,cuda:nccl"
)
else:
groups = [
torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
torch.distributed.new_group(ranks=[2, 3], backend="gloo"),
]
group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
pynccl_comm = PyNcclCommunicator(group=group, device=device) pynccl_comm = PyNcclCommunicator(group=group, device=device)
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
# two groups can communicate independently # two groups can communicate independently
@@ -339,11 +347,16 @@ def test_pynccl_send_recv():
@worker_fn_wrapper @worker_fn_wrapper
def multiple_send_recv_worker_fn(): def multiple_send_recv_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}") device = torch.device(f"cuda:{torch.distributed.get_rank()}")
groups = [ if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
torch.distributed.new_group(ranks=[0, 2], backend="gloo"), group = torch.distributed.split_group(
torch.distributed.new_group(ranks=[1, 3], backend="gloo"), split_ranks=[[0, 2], [1, 3]], backend="cpu:gloo,cuda:nccl"
] )
group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1] else:
groups = [
torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
torch.distributed.new_group(ranks=[1, 3], backend="gloo"),
]
group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1]
pynccl_comm = PyNcclCommunicator(group=group, device=device) pynccl_comm = PyNcclCommunicator(group=group, device=device)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
+22 -7
View File
@@ -9,6 +9,7 @@ import ray
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
from vllm.distributed.device_communicators.quick_all_reduce import ( from vllm.distributed.device_communicators.quick_all_reduce import (
@@ -397,13 +398,27 @@ def qr_variable_input(rank, world_size):
ranks = [] ranks = []
for i in range(world_size): for i in range(world_size):
ranks.append(i) ranks.append(i)
dist.init_process_group( if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
backend="nccl", dist.init_process_group(
init_method="tcp://127.0.0.1:29500", backend="cpu:gloo,cuda:nccl",
rank=rank, init_method="tcp://127.0.0.1:29500",
world_size=world_size, rank=rank,
) world_size=world_size,
cpu_group = torch.distributed.new_group(ranks, backend="nccl") device_id=device,
)
else:
dist.init_process_group(
backend="nccl",
init_method="tcp://127.0.0.1:29500",
rank=rank,
world_size=world_size,
)
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
cpu_group = torch.distributed.split_group(
split_ranks=[ranks], backend="cpu:gloo,cuda:nccl"
)
else:
cpu_group = torch.distributed.new_group(ranks, backend="nccl")
handle = ops.qr_get_handle(_ptr) handle = ops.qr_get_handle(_ptr)
world_size = dist.get_world_size(group=cpu_group) world_size = dist.get_world_size(group=cpu_group)
+233
View File
@@ -0,0 +1,233 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for split_group in GroupCoordinator.
These tests verify that:
1. split_group is used for both device and CPU group creation.
2. Multiple subgroups work correctly with split_group.
3. Both GPU and CPU all-reduce work on split groups.
"""
import os
from typing import Any
import multiprocess as mp
import pytest
import torch
import torch.distributed
import vllm.envs as envs
from vllm.distributed.parallel_state import (
GroupCoordinator,
init_distributed_environment,
)
from vllm.utils.system_utils import update_environment_variables
# The whole module exercises the split_group code path, which is opt-in
# behind VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1.
# ``pytestmark`` applies this skip condition to every test in the module; the
# flag is read here, in the process that imports this module.
pytestmark = pytest.mark.skipif(
    not envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP,
    reason=("VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1 not set; split_group path is opt-in."),
)

# Select the "spawn" start method for the worker processes launched by the
# helpers below. NOTE(review): presumably ``force=True`` is there to override a
# start method chosen earlier in the same interpreter -- confirm.
mp.set_start_method("spawn", force=True)
def distributed_run(fn, world_size):
    """Run ``fn`` in ``world_size`` separate processes and require clean exits.

    Each child receives a dict of environment variables as its only argument.
    The dict carries the child's rank (also used as its local rank), the world
    size, the rendezvous address/port, and the split_group opt-in flag. All
    children are joined before any exit code is inspected; a non-zero exit
    code from any child fails the caller with an ``AssertionError``.
    """
    total = str(world_size)
    workers = []
    for rank_idx in range(world_size):
        rank_str = str(rank_idx)
        env = {
            "RANK": rank_str,
            "LOCAL_RANK": rank_str,
            "WORLD_SIZE": total,
            "LOCAL_WORLD_SIZE": total,
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12346",
            # propagate the opt-in flag to the spawned child workers
            "VLLM_DISTRIBUTED_USE_SPLIT_GROUP": "1",
        }
        worker = mp.Process(target=fn, args=(env,))
        workers.append(worker)
        worker.start()

    # Wait for every child first so one failure does not leave others running
    # unobserved when the assertion below fires.
    for worker in workers:
        worker.join()

    for worker in workers:
        assert worker.exitcode == 0
def worker_fn_wrapper(fn):
    """Wrap a zero-argument worker body so it can be used as a process target.

    The returned callable takes the environment dict produced by the launcher,
    installs it, selects the CUDA device named by ``LOCAL_RANK``, initializes
    the distributed environment, and then runs ``fn``.
    """

    def wrapped_fn(env):
        # The environment must be installed first: the device selection below
        # reads LOCAL_RANK from ``os.environ``.
        update_environment_variables(env)
        torch.accelerator.set_device_index(
            torch.device("cuda:" + os.environ["LOCAL_RANK"])
        )
        init_distributed_environment()
        fn()

    return wrapped_fn
def _verify_device_group(coordinator: GroupCoordinator):
    """Check the coordinator's device group with an all-reduce of ones.

    A tensor of ones is placed on the CUDA device indexed by this process's
    global rank and all-reduced over ``coordinator.device_group``. Every
    element must then equal ``coordinator.world_size``.
    """
    cuda_device = torch.device(f"cuda:{torch.distributed.get_rank()}")
    ones = torch.ones(16, 16, dtype=torch.float32, device=cuda_device)

    torch.distributed.all_reduce(ones, group=coordinator.device_group)
    torch.accelerator.synchronize()

    expected = coordinator.world_size
    every_element_matches = torch.all(ones == expected).cpu().item()
    assert every_element_matches, (
        f"Device group all-reduce failed: expected {expected}, "
        f"got {ones.flatten()[0].item()}"
    )
def _verify_cpu_group(coordinator: GroupCoordinator):
    """Check the coordinator's CPU group with an all-reduce of ones.

    A 16-element tensor of ones (created without a device argument) is
    all-reduced over ``coordinator.cpu_group``. Every element must then equal
    ``coordinator.world_size``.
    """
    ones = torch.ones(16, dtype=torch.float32)
    torch.distributed.all_reduce(ones, group=coordinator.cpu_group)

    expected = coordinator.world_size
    every_element_matches = torch.all(ones == expected).cpu().item()
    assert every_element_matches, (
        f"CPU group all-reduce failed: expected {expected}, "
        f"got {ones.flatten()[0].item()}"
    )
# ---------------------------------------------------------------------------
# Test 1: Basic split_group path with 2 GPUs
# ---------------------------------------------------------------------------
@worker_fn_wrapper
def split_group_basic_worker():
    """Build one coordinator spanning every rank and verify both of its groups."""
    every_rank = list(range(torch.distributed.get_world_size()))
    coordinator = GroupCoordinator(
        group_ranks=[every_rank],
        local_rank=torch.distributed.get_rank(),
        torch_distributed_backend="nccl",
        use_device_communicator=False,
        group_name="test_split_basic",
    )

    _verify_device_group(coordinator)
    _verify_cpu_group(coordinator)
@pytest.mark.skipif(
    torch.accelerator.device_count() < 2,
    reason="Need at least 2 GPUs to run the test.",
)
def test_split_group_basic():
    """Run the single-group worker across two processes.

    Skipped when fewer than two accelerator devices are reported.
    """
    distributed_run(split_group_basic_worker, world_size=2)
# ---------------------------------------------------------------------------
# Test 2: Multiple subgroups with split_group (4 GPUs)
# ---------------------------------------------------------------------------
@worker_fn_wrapper
def split_group_multiple_subgroups_worker():
    """Split four ranks into ``[0, 1]`` and ``[2, 3]`` and verify this rank's half.

    The coordinator must report a world size of 2, both of its groups must
    all-reduce correctly, and ``coordinator.ranks`` must be the subgroup that
    contains the calling rank.
    """
    rank = torch.distributed.get_rank()
    coordinator = GroupCoordinator(
        group_ranks=[[0, 1], [2, 3]],
        local_rank=rank,
        torch_distributed_backend="nccl",
        use_device_communicator=False,
        group_name="test_split_multi",
    )

    assert coordinator.world_size == 2
    _verify_device_group(coordinator)
    _verify_cpu_group(coordinator)

    expected_ranks = [0, 1] if rank in [0, 1] else [2, 3]
    assert coordinator.ranks == expected_ranks
@pytest.mark.skipif(
    torch.accelerator.device_count() < 4,
    reason="Need at least 4 GPUs to run the test.",
)
def test_split_group_multiple_subgroups():
    """Run the two-subgroup worker across four processes.

    Skipped when fewer than four accelerator devices are reported.
    """
    distributed_run(split_group_multiple_subgroups_worker, world_size=4)
# ---------------------------------------------------------------------------
# Test 3: split_group contract — every parent rank must enter with the same
# ``split_ranks``. NCCL happens to produce
# correct subgroups for disjoint partitions because the wrapper hashes
# ``my_group`` to derive the comm-split color, but the contract violation is
# real and would break under non-partition / non-NCCL backends. This test
# captures the actual ``split_ranks`` argument passed on every rank and
# asserts they match.
# ---------------------------------------------------------------------------
@worker_fn_wrapper
def split_group_contract_worker():
    """Record the ``split_ranks`` each rank passes to ``split_group`` and compare.

    ``torch.distributed.split_group`` is temporarily replaced by a recording
    wrapper while a ``GroupCoordinator`` is constructed. Afterwards each rank
    must have recorded exactly two calls, and for each call the (order-
    normalized) ``split_ranks`` gathered from all ranks must be identical.
    """
    rank = torch.distributed.get_rank()
    captured: list[list[list[int]]] = []
    original_split_group = torch.distributed.split_group

    def capturing_split_group(*args, split_ranks=None, **kwargs):
        # Store a plain-list snapshot of the argument, then delegate unchanged.
        captured.append([list(subgroup) for subgroup in split_ranks])
        return original_split_group(*args, split_ranks=split_ranks, **kwargs)

    torch.distributed.split_group = capturing_split_group
    try:
        GroupCoordinator(
            group_ranks=[[0, 1], [2, 3]],
            local_rank=rank,
            torch_distributed_backend="nccl",
            use_device_communicator=False,
            group_name="test_split_contract",
        )
    finally:
        # Always restore the real function, even if construction failed.
        torch.distributed.split_group = original_split_group

    # GroupCoordinator builds two subgroups (device + cpu) per coordinator,
    # so every rank must have made exactly two split_group calls.
    if len(captured) != 2:
        raise AssertionError(
            f"rank {rank} expected 2 split_group calls (device + cpu), "
            f"got {len(captured)}: {captured}"
        )

    world_size = torch.distributed.get_world_size()
    for call_idx, local_args in enumerate(captured):
        gathered: list[Any] = [None] * world_size
        torch.distributed.all_gather_object(gathered, local_args)

        # Normalize for stable comparison (sort each subgroup and the outer list).
        normalized = [
            sorted(sorted(subgroup) for subgroup in per_rank_args)
            for per_rank_args in gathered
        ]
        reference = normalized[0]
        for r, rank_args in enumerate(normalized):
            if rank_args == reference:
                continue
            raise AssertionError(
                f"split_group contract violation on call #{call_idx}: "
                f"rank {r} passed split_ranks={gathered[r]}, but rank 0 "
                f"passed split_ranks={gathered[0]}. PyTorch requires every "
                "parent rank to enter split_group with the same split_ranks."
            )
@pytest.mark.skipif(
    torch.accelerator.device_count() < 4,
    reason="Need at least 4 GPUs to run the test.",
)
def test_split_group_contract_same_split_ranks_on_all_ranks():
    """Every parent rank must pass the same ``split_ranks`` to
    ``torch.distributed.split_group``.

    Guards against each rank passing only its own subgroup
    (``split_ranks=[ranks]``): NCCL forgives that for disjoint partitions,
    but it is a documented contract violation.
    """
    distributed_run(split_group_contract_worker, world_size=4)
+15 -2
View File
@@ -5,13 +5,26 @@
import os import os
import random import random
import torch
import torch.distributed as dist import torch.distributed as dist
import vllm.envs as envs
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import get_world_group from vllm.distributed.parallel_state import get_world_group
# Let PyTorch choose the WORLD backend for the current device type. # By default, let PyTorch choose the WORLD backend for the current device
dist.init_process_group() # type (legacy lazy-init path). When VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1,
# use the explicit eager-init pattern required by `split_group` (mixed
# cpu:gloo,cuda:nccl backend + device_id binding).
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
local_rank = int(os.environ["LOCAL_RANK"])
torch.accelerator.set_device_index(local_rank)
dist.init_process_group(
backend="cpu:gloo,cuda:nccl",
device_id=torch.device(f"cuda:{local_rank}"),
)
else:
dist.init_process_group()
# Create prompts # Create prompts
prompts = [ prompts = [
+15 -2
View File
@@ -5,13 +5,26 @@
import os import os
import random import random
import torch
import torch.distributed as dist import torch.distributed as dist
import vllm.envs as envs
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import get_tp_group, get_world_group from vllm.distributed.parallel_state import get_tp_group, get_world_group
# Let PyTorch choose the WORLD backend for the current device type. # By default, let PyTorch choose the WORLD backend for the current device
dist.init_process_group() # type (legacy lazy-init path). When VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1,
# use the explicit eager-init pattern required by `split_group` (mixed
# cpu:gloo,cuda:nccl backend + device_id binding).
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
local_rank = int(os.environ["LOCAL_RANK"])
torch.accelerator.set_device_index(local_rank)
dist.init_process_group(
backend="cpu:gloo,cuda:nccl",
device_id=torch.device(f"cuda:{local_rank}"),
)
else:
dist.init_process_group()
# Create prompts # Create prompts
prompts = [ prompts = [
@@ -10,6 +10,7 @@ import torch
from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
import vllm.envs as envs
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed import ( from vllm.distributed import (
cleanup_dist_env_and_memory, cleanup_dist_env_and_memory,
@@ -60,7 +61,15 @@ def _set_vllm_config(
tensor_model_parallel_size=vllm_config.parallel_config.tensor_parallel_size, tensor_model_parallel_size=vllm_config.parallel_config.tensor_parallel_size,
pipeline_model_parallel_size=vllm_config.parallel_config.pipeline_parallel_size, pipeline_model_parallel_size=vllm_config.parallel_config.pipeline_parallel_size,
) )
cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo") if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
cpu_group = torch.distributed.split_group(
split_ranks=[list(range(world_size))],
group_desc="moe_test_cpu",
)
else:
cpu_group = torch.distributed.new_group(
list(range(world_size)), backend="gloo"
)
return cpu_group return cpu_group
@@ -14,6 +14,7 @@ import torch.distributed
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
import vllm.envs as envs
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.activation import MoEActivation
@@ -375,7 +376,13 @@ def _test_deepep_deepgemm_moe(
w1_scale = w1_scale.to(device=device) w1_scale = w1_scale.to(device=device)
w2_scale = w2_scale.to(device=device) w2_scale = w2_scale.to(device=device)
pg = torch.distributed.new_group(list(range(pgi.world_size))) if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
pg = torch.distributed.split_group(
split_ranks=[list(range(pgi.world_size))],
group_desc="deepep_deepgemm_test",
)
else:
pg = torch.distributed.new_group(list(range(pgi.world_size)))
test_tensors = TestTensors.make(config, pgi.rank) test_tensors = TestTensors.make(config, pgi.rank)
block_shape = [w1.size(1) // w1_scale.size(1), w1.size(2) // w1_scale.size(2)] block_shape = [w1.size(1) // w1_scale.size(1), w1.size(2) // w1_scale.size(2)]
+8 -1
View File
@@ -10,6 +10,7 @@ import pytest
import torch.distributed import torch.distributed
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
import vllm.envs as envs
from tests.kernels.moe.utils import make_dummy_moe_config from tests.kernels.moe.utils import make_dummy_moe_config
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
@@ -375,7 +376,13 @@ def _deep_ep_moe(
w1_scale = w1_scale.to(device=device_idx) w1_scale = w1_scale.to(device=device_idx)
w2_scale = w2_scale.to(device=device_idx) w2_scale = w2_scale.to(device=device_idx)
pg = torch.distributed.new_group(list(range(pgi.world_size))) if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
pg = torch.distributed.split_group(
split_ranks=[list(range(pgi.world_size))],
group_desc="deepep_test",
)
else:
pg = torch.distributed.new_group(list(range(pgi.world_size)))
test_tensors = TestTensors.make(config, low_latency_mode) test_tensors = TestTensors.make(config, low_latency_mode)
with set_current_vllm_config(VllmConfig()): with set_current_vllm_config(VllmConfig()):
+178 -26
View File
@@ -227,6 +227,67 @@ def patched_fused_scaled_matmul_reduce_scatter_fake(
return res return res
def _platform_device_type() -> str:
    """Map the current vLLM platform to a torch device-type string.

    Returns ``"cuda"`` for CUDA-alike platforms, ``"xpu"`` for XPU, the
    platform's own ``device_name`` for out-of-tree platforms, and ``"cpu"``
    otherwise. The result is used as the ``<device>`` half of a
    ``"<device>:<backend>"`` process-group backend string.
    """
    from vllm.platforms import current_platform

    if current_platform.is_cuda_alike():
        return "cuda"
    if current_platform.is_xpu():
        return "xpu"
    if current_platform.is_out_of_tree():
        return current_platform.device_name
    return "cpu"
def _device_backend_str(torch_distributed_backend: str | Backend) -> str:
"""Normalize ``torch_distributed_backend`` to the ``"<device>:<backend>"``
format required by ``split_group``'s ``backend`` argument.
Accepts either a bare backend name (e.g. ``"nccl"``) or an already-prefixed
string (e.g. ``"cuda:nccl"``).
"""
backend_str = str(torch_distributed_backend)
if ":" in backend_str:
return backend_str
return f"{_platform_device_type()}:{backend_str}"
def _create_subgroups_split_group(
    group_ranks: list[list[int]],
    group_name: str,
    torch_distributed_backend: str | Backend,
) -> tuple[ProcessGroup, ProcessGroup]:
    """Build the ``(device_group, cpu_group)`` pair for ``GroupCoordinator``
    using ``torch.distributed.split_group``.

    Both calls receive the full ``group_ranks`` list: ``split_group`` is
    collective on the parent group, so every parent rank must enter with the
    same ``split_ranks`` definition, and each rank gets back the subgroup it
    belongs to. The device split is issued first, then the CPU split.

    NOTE(review): the legacy ``new_group`` path passes an explicit timeout
    when creating the gloo CPU group; no timeout is forwarded here. Confirm
    that this is intended.
    """
    device_filter = _device_backend_str(torch_distributed_backend)
    # CPU subgroup: split_group requires the requested backend filter to
    # include the parent's default device type (= the device the parent PG
    # was bound to via ``device_id``), so a cpu-only filter is rejected.
    # Include the device backend in the filter; only the gloo backend is
    # actually used for CPU collectives on this group.
    cpu_filter = f"cpu:gloo,{device_filter}"

    device_pg = torch.distributed.split_group(
        split_ranks=group_ranks,
        group_desc=f"{group_name}:device",
        backend=device_filter,
    )
    cpu_pg = torch.distributed.split_group(
        split_ranks=group_ranks,
        group_desc=f"{group_name}:cpu",
        backend=cpu_filter,
    )
    return device_pg, cpu_pg
def patched_fused_scaled_matmul_reduce_scatter( def patched_fused_scaled_matmul_reduce_scatter(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
@@ -335,26 +396,39 @@ class GroupCoordinator:
self_device_group = None self_device_group = None
self_cpu_group = None self_cpu_group = None
from vllm.distributed.utils import get_cpu_distributed_timeout_or_none # VLLM_DISTRIBUTED_USE_SPLIT_GROUP gates the new ``split_group``
# codepath. Default (False) preserves the legacy ``new_group`` path.
timeout = get_cpu_distributed_timeout_or_none() if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
self_device_group, self_cpu_group = _create_subgroups_split_group(
for ranks in group_ranks: group_ranks, group_name, torch_distributed_backend
device_group = torch.distributed.new_group(
ranks, backend=torch_distributed_backend
) )
# a group with `gloo` backend, to allow direct coordination between for ranks in group_ranks:
# processes through the CPU. if self.rank in ranks:
with suppress_stdout(): self.ranks = ranks
cpu_group = torch.distributed.new_group( self.world_size = len(ranks)
ranks, backend="gloo", timeout=timeout self.rank_in_group = ranks.index(self.rank)
break
else:
from vllm.distributed.utils import get_cpu_distributed_timeout_or_none
timeout = get_cpu_distributed_timeout_or_none()
for ranks in group_ranks:
device_group = torch.distributed.new_group(
ranks, backend=torch_distributed_backend
) )
if self.rank in ranks: # a group with `gloo` backend, to allow direct coordination between
self.ranks = ranks # processes through the CPU.
self.world_size = len(ranks) with suppress_stdout():
self.rank_in_group = ranks.index(self.rank) cpu_group = torch.distributed.new_group(
self_device_group = device_group ranks, backend="gloo", timeout=timeout
self_cpu_group = cpu_group )
if self.rank in ranks:
self.ranks = ranks
self.world_size = len(ranks)
self.rank_in_group = ranks.index(self.rank)
self_device_group = device_group
self_cpu_group = cpu_group
assert self_cpu_group is not None assert self_cpu_group is not None
assert self_device_group is not None assert self_device_group is not None
@@ -1348,6 +1422,62 @@ def set_custom_all_reduce(enable: bool):
_ENABLE_CUSTOM_ALL_REDUCE = enable _ENABLE_CUSTOM_ALL_REDUCE = enable
def _init_process_group_for_split_group(
    *,
    backend: str,
    distributed_init_method: str,
    world_size: int,
    rank: int,
    local_rank: int,
    timeout: timedelta | None,
) -> None:
    """Initialize the default PG so that subgroups can later be created via
    ``split_group``.

    When an accelerator is available and ``backend`` is not ``"gloo"``, the
    default PG is created with a CPU (gloo) backend plus the requested device
    backend, and is bound to this rank's device through ``device_id``
    (``split_group`` requires the parent communicator to be eagerly
    initialized). Otherwise the PG is created with plain ``gloo`` and no
    device binding.

    The device type and the ``"<device>:<backend>"`` string are derived with
    the same helpers that later build the ``split_group`` backend filter, so
    the parent PG and its subgroups agree on non-CUDA platforms as well. For a
    CUDA-alike platform with ``backend="nccl"`` this yields
    ``"cpu:gloo,cuda:nccl"`` and ``cuda:{local_rank}``.

    Args:
        backend: Device backend requested by the caller (e.g. ``"nccl"``),
            either bare or already in ``"<device>:<backend>"`` form.
        distributed_init_method: Rendezvous URL passed as ``init_method``.
        world_size: Total number of ranks.
        rank: Global rank of this process.
        local_rank: Device index to bind this process to.
        timeout: Collective timeout, or ``None`` for the PyTorch default.
    """
    device_type = _platform_device_type()
    use_device_backend = (
        torch.accelerator.is_available()
        and backend != "gloo"
        # A "cpu" platform has no device backend to pair with gloo.
        and device_type != "cpu"
    )
    if use_device_backend:
        init_backend = f"cpu:gloo,{_device_backend_str(backend)}"
        device_id: torch.device | None = torch.device(f"{device_type}:{local_rank}")
    else:
        init_backend = "gloo"
        device_id = None

    torch.distributed.init_process_group(
        backend=init_backend,
        init_method=distributed_init_method,
        world_size=world_size,
        rank=rank,
        timeout=timeout,
        device_id=device_id,
    )
def _validate_default_pg_for_split_group() -> None:
"""When an external launcher (e.g. ``torchrun``) initialized the default
PG, ``GroupCoordinator`` cannot patch in additional backends or change
the eager-init behavior ``split_group`` only selects subsets of an
existing parent. Validate that the parent has both ``device_id`` and a
CPU (gloo) backend, and emit a descriptive error pointing at the exact
init call to update otherwise.
"""
default_pg = torch.distributed.distributed_c10d._get_default_group()
assert default_pg.bound_device_id is not None, (
"External launcher initialized the default process group "
"without device_id. vLLM requires the default PG to be device-"
"bound for split_group. Pass device_id=torch.device(f'cuda:"
"{local_rank}') to torch.distributed.init_process_group()."
)
try:
default_pg._get_backend(torch.device("cpu"))
except RuntimeError as e:
raise RuntimeError(
"External launcher initialized the default process group "
"without a CPU (gloo) backend. vLLM requires both CPU and "
"device backends. Pass backend='cpu:gloo,cuda:nccl' to "
"torch.distributed.init_process_group()."
) from e
def _init_elastic_ep_world( def _init_elastic_ep_world(
config, local_rank: int, backend: str, rank: int, world_size: int config, local_rank: int, backend: str, rank: int, world_size: int
) -> None: ) -> None:
@@ -1456,14 +1586,33 @@ def init_distributed_environment(
"Fallback Gloo backend is not available." "Fallback Gloo backend is not available."
) )
backend = "gloo" backend = "gloo"
# this backend is used for WORLD if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
torch.distributed.init_process_group( # split_group needs local_rank early to compute device_id for
backend=backend, # the eager init. local_rank is not available in torch
init_method=distributed_init_method, # ProcessGroup, see https://github.com/pytorch/pytorch/issues/122816
world_size=world_size, if local_rank == -1:
rank=rank, local_rank = (
timeout=timeout, int(envs.LOCAL_RANK)
) if distributed_init_method == "env://"
else rank
)
_init_process_group_for_split_group(
backend=backend,
distributed_init_method=distributed_init_method,
world_size=world_size,
rank=rank,
local_rank=local_rank,
timeout=timeout,
)
else:
# this backend is used for WORLD
torch.distributed.init_process_group(
backend=backend,
init_method=distributed_init_method,
world_size=world_size,
rank=rank,
timeout=timeout,
)
if enable_elastic_ep: if enable_elastic_ep:
tp_pp_cpu_group = torch.distributed.new_group( tp_pp_cpu_group = torch.distributed.new_group(
backend="gloo", timeout=timeout backend="gloo", timeout=timeout
@@ -1476,6 +1625,9 @@ def init_distributed_environment(
"Elastic EP is not yet supported with multi-node TP/PP" "Elastic EP is not yet supported with multi-node TP/PP"
) )
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP and torch.accelerator.is_available():
_validate_default_pg_for_split_group()
# set the local rank # set the local rank
# local_rank is not available in torch ProcessGroup, # local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816 # see https://github.com/pytorch/pytorch/issues/122816
+8
View File
@@ -63,6 +63,7 @@ if TYPE_CHECKING:
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True
VLLM_USE_RAY_V2_EXECUTOR_BACKEND: bool = False VLLM_USE_RAY_V2_EXECUTOR_BACKEND: bool = False
VLLM_DISTRIBUTED_USE_SPLIT_GROUP: bool = False
VLLM_XLA_USE_SPMD: bool = False VLLM_XLA_USE_SPMD: bool = False
VLLM_WORKER_MULTIPROC_METHOD: Literal["fork", "spawn"] = "fork" VLLM_WORKER_MULTIPROC_METHOD: Literal["fork", "spawn"] = "fork"
VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
@@ -877,6 +878,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_RAY_V2_EXECUTOR_BACKEND": lambda: bool( "VLLM_USE_RAY_V2_EXECUTOR_BACKEND": lambda: bool(
int(os.getenv("VLLM_USE_RAY_V2_EXECUTOR_BACKEND", "1")) int(os.getenv("VLLM_USE_RAY_V2_EXECUTOR_BACKEND", "1"))
), ),
# When True, GroupCoordinator constructs its CPU/device subgroups via
# ``torch.distributed.split_group(backend=...)``
# and ``init_distributed_environment`` initializes the default PG with
# mixed ``cpu:gloo,cuda:nccl`` backend + eager ``device_id`` binding.
"VLLM_DISTRIBUTED_USE_SPLIT_GROUP": lambda: bool(
int(os.getenv("VLLM_DISTRIBUTED_USE_SPLIT_GROUP", "0"))
),
# Use dedicated multiprocess context for workers. # Use dedicated multiprocess context for workers.
# Both spawn and fork work # Both spawn and fork work
"VLLM_WORKER_MULTIPROC_METHOD": env_with_choices( "VLLM_WORKER_MULTIPROC_METHOD": env_with_choices(