[1/N] Elastic EP Milestone 2 (#34861)

Signed-off-by: Yongji Wu <wuyongji317@gmail.com>
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: Ron Tourgeman <rtourgeman@nvidia.com>
Co-authored-by: Yongji Wu <wuyongji317@gmail.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>
This commit is contained in:
Itay Alroy
2026-02-28 06:46:42 +02:00
committed by GitHub
parent 90805ff464
commit dea268336f
53 changed files with 3613 additions and 1016 deletions
+16 -1
View File
@@ -20,4 +20,19 @@ steps:
- tests/distributed/test_eplb_execute.py
commands:
- pytest -v -s distributed/test_eplb_execute.py
- pytest -v -s distributed/test_eplb_spec_decode.py
- pytest -v -s distributed/test_eplb_spec_decode.py
- label: Elastic EP Scaling Test
timeout_in_minutes: 20
device: b200
optional: true
working_dir: "/vllm-workspace/tests"
num_devices: 4
source_file_dependencies:
- vllm/distributed/
- vllm/engine/
- vllm/executor/
- vllm/compilation/
- tests/distributed/
commands:
- pytest -v -s distributed/test_elastic_ep.py
@@ -316,7 +316,6 @@ def async_tp_pass_on_test_model(
# initialize distributed
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# configure vllm config for AsyncTPPass
vllm_config = VllmConfig()
@@ -334,11 +333,10 @@ def async_tp_pass_on_test_model(
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
)
async_tp_pass = AsyncTPPass(vllm_config)
# Set the global vllm_config for TestBackend which calls
# get_current_vllm_config()
with set_current_vllm_config(vllm_config):
initialize_model_parallel(tensor_model_parallel_size=world_size)
async_tp_pass = AsyncTPPass(vllm_config)
backend = TestBackend(async_tp_pass)
assert (
@@ -278,7 +278,6 @@ def all_reduce_fusion_pass_on_test_model(
)
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
custom_ops = []
if enable_rms_norm_custom_op:
@@ -304,6 +303,7 @@ def all_reduce_fusion_pass_on_test_model(
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
)
with set_current_vllm_config(vllm_config):
initialize_model_parallel(tensor_model_parallel_size=world_size)
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
@@ -242,7 +242,6 @@ def sequence_parallelism_pass_on_test_model(
# initialize distributed
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# configure vllm config for SequenceParallelismPass
custom_ops_list = custom_ops.split(",") if custom_ops else []
@@ -272,6 +271,7 @@ def sequence_parallelism_pass_on_test_model(
)
with set_current_vllm_config(vllm_config):
initialize_model_parallel(tensor_model_parallel_size=world_size)
noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
+13 -9
View File
@@ -176,16 +176,20 @@ def init_test_http_connection():
@pytest.fixture
def dist_init():
from tests.utils import ensure_current_vllm_config
temp_file = tempfile.mkstemp()[1]
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method=f"file://{temp_file}",
local_rank=0,
backend="nccl",
)
initialize_model_parallel(1, 1)
yield
with ensure_current_vllm_config():
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method=f"file://{temp_file}",
local_rank=0,
backend="nccl",
)
initialize_model_parallel(1, 1)
yield
cleanup_dist_env_and_memory()
+6 -1
View File
@@ -7,6 +7,7 @@ import random
import torch
import torch.multiprocessing as mp
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import (
init_distributed_environment,
)
@@ -42,7 +43,11 @@ def set_env_vars_and_device(env: dict[str, str]) -> None:
local_rank = os.environ["LOCAL_RANK"]
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
init_distributed_environment()
# Create a minimal vllm config for init_distributed_environment
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
init_distributed_environment()
# Ensure each worker process has the same random seed
random.seed(42)
+202
View File
@@ -0,0 +1,202 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import subprocess
import time
import pytest
import requests
from ..evals.gsm8k.gsm8k_eval import evaluate_gsm8k
from ..utils import RemoteOpenAIServer, multi_gpu_test
@pytest.fixture(autouse=True)
def cleanup_ray_between_tests():
    """Force-stop leftover Ray processes before every test in this module.

    The fixture is ``autouse``, so it applies to each test without being
    requested. All of its work happens before the ``yield``; nothing runs
    after the test body.
    """
    stop_command = ["ray", "stop", "--force"]
    # Output is captured (and discarded) and the return code is not checked.
    subprocess.run(stop_command, capture_output=True, timeout=30)
    # NOTE(review): presumably gives Ray time to finish shutting down before
    # the test launches a new server -- confirm the 5 s value is sufficient.
    time.sleep(5)
    yield
# Model served by RemoteOpenAIServer in both elastic EP tests.
MODEL_NAME = "deepseek-ai/DeepSeek-V2-Lite-Chat"
# Passed to evaluate_gsm8k as ``num_questions`` for every evaluation run.
NUM_GSM8K_QUESTIONS = 256
# Minimum GSM8K accuracy asserted by _run_gsm8k_eval at every stage.
EXPECTED_ACCURACY = 0.58
# Largest accuracy drop, relative to the initial run, tolerated after scaling.
ACCURACY_TOL = 0.08
# Forwarded to the server as ``--max-num-seqs``.
MAX_NUM_SEQS = 32
def _send_scale_command(server: RemoteOpenAIServer, new_dp_size: int) -> bool:
    """POST a scaling request to the server's ``scale_elastic_ep`` endpoint.

    Args:
        server: Server under test; only ``url_for`` is used here.
        new_dp_size: Value sent as ``new_data_parallel_size`` in the JSON body.

    Returns:
        True when the server answers with HTTP 200. False for any other
        status code or when the request itself raises a
        ``requests.exceptions.RequestException`` (connection error,
        timeout, ...). The failure reason is printed in both cases so that
        a failing ``assert`` in the caller is diagnosable from the test log.
    """
    url = server.url_for("scale_elastic_ep")
    payload = {"new_data_parallel_size": new_dp_size}
    headers = {"Content-Type": "application/json"}
    # Keep the try body limited to the call that can actually raise.
    try:
        response = requests.post(url, json=payload, headers=headers, timeout=300)
    except requests.exceptions.RequestException as exc:
        print(f"Scale request (new_dp_size={new_dp_size}) failed: {exc!r}")
        return False

    if response.status_code != 200:
        print(
            f"Scale request (new_dp_size={new_dp_size}) returned HTTP "
            f"{response.status_code}: {response.text}"
        )
        return False
    return True
def _run_gsm8k_eval(server: RemoteOpenAIServer, stage: str) -> float:
    """Evaluate GSM8K against ``server`` and enforce the accuracy floor.

    Args:
        server: Running server; its ``host`` and ``port`` are passed to
            ``evaluate_gsm8k``.
        stage: Label used as a prefix in the printed result and in the
            assertion message.

    Returns:
        The ``"accuracy"`` entry of the evaluation result.

    Raises:
        AssertionError: If the server has no port, or if the accuracy is
            below ``EXPECTED_ACCURACY``.
    """
    assert server.port is not None

    base_url = f"http://{server.host}"
    result = evaluate_gsm8k(
        num_questions=NUM_GSM8K_QUESTIONS,
        host=base_url,
        port=server.port,
    )
    accuracy = result["accuracy"]

    report = f"[{stage}] GSM8K accuracy: {accuracy:.3f} "
    report += f"({result['num_questions']} questions)"
    print(report)

    # The floor is absolute; ACCURACY_TOL is not applied here.
    assert accuracy >= EXPECTED_ACCURACY, (
        f"[{stage}] GSM8K accuracy {accuracy:.3f} is below "
        f"expected threshold {EXPECTED_ACCURACY}"
    )
    return accuracy
def _elastic_ep_serve_args() -> list[str]:
    """Build the server arguments shared by the elastic EP scaling tests.

    The server starts with a data-parallel size of 2 on the Ray backend.
    When ``LEADER_ADDRESS`` is set in the environment, it is appended as
    ``--data-parallel-address``.
    """
    serve_args = [
        "--trust-remote-code",
        "--tensor-parallel-size",
        "1",
        "--gpu-memory-utilization",
        "0.8",
        "--max-model-len",
        "4096",
        "--max-num-seqs",
        str(MAX_NUM_SEQS),
        "--enable-expert-parallel",
        "--all2all-backend",
        "allgather_reducescatter",
        "--enable-elastic-ep",
        "--enable-eplb",
        "--eplb-config.num_redundant_experts",
        "0",
        "--data-parallel-backend",
        "ray",
        "--data-parallel-size",
        "2",
        "--api-server-count",
        "1",
    ]
    leader_address = os.environ.get("LEADER_ADDRESS")
    if leader_address:
        serve_args.extend(["--data-parallel-address", leader_address])
    return serve_args


def _run_scale_up_down_cycle(scaled_dp_size: int, summary_title: str) -> None:
    """Scale 2 -> ``scaled_dp_size`` -> 2 and check accuracy at each stage.

    Args:
        scaled_dp_size: Data-parallel size requested for the scale-up step.
        summary_title: Heading printed above the final accuracy summary.

    Raises:
        AssertionError: If a scale request is not acknowledged, or if the
            accuracy after either scaling step falls more than
            ``ACCURACY_TOL`` below the initial accuracy (in addition to
            the floor enforced inside ``_run_gsm8k_eval``).
    """
    with RemoteOpenAIServer(
        MODEL_NAME, _elastic_ep_serve_args(), env_dict={}, max_wait_seconds=1200
    ) as server:
        initial_accuracy = _run_gsm8k_eval(server, "Initial (2 GPUs)")

        assert _send_scale_command(server, scaled_dp_size), (
            f"Scale up request to data parallel size {scaled_dp_size} failed"
        )
        time.sleep(10)
        scale_up_accuracy = _run_gsm8k_eval(
            server, f"After scale up ({scaled_dp_size} GPUs)"
        )
        assert scale_up_accuracy >= initial_accuracy - ACCURACY_TOL, (
            f"Scale up accuracy {scale_up_accuracy:.3f} dropped more than "
            f"{ACCURACY_TOL} below initial accuracy {initial_accuracy:.3f}"
        )

        # Scale back down to the starting size of 2.
        assert _send_scale_command(server, 2), (
            "Scale down request to data parallel size 2 failed"
        )
        time.sleep(5)
        scale_down_accuracy = _run_gsm8k_eval(server, "After scale down (2 GPUs)")
        assert scale_down_accuracy >= initial_accuracy - ACCURACY_TOL, (
            f"Scale down accuracy {scale_down_accuracy:.3f} dropped more than "
            f"{ACCURACY_TOL} below initial accuracy {initial_accuracy:.3f}"
        )

        print(f"\n{summary_title}:")
        print(f" Initial: {initial_accuracy:.3f}")
        print(
            f" Scale up: {scale_up_accuracy:.3f} "
            f"(diff: {scale_up_accuracy - initial_accuracy:+.3f})"
        )
        print(
            f" Scale down: {scale_down_accuracy:.3f} "
            f"(diff: {scale_down_accuracy - initial_accuracy:+.3f})"
        )
        print(f" Tolerance: {ACCURACY_TOL:.3f}")


@multi_gpu_test(num_gpus=4)
def test_elastic_ep_scaling():
    """Test scaling 2 -> 4 -> 2 while accuracy stays within tolerance."""
    _run_scale_up_down_cycle(scaled_dp_size=4, summary_title="Accuracy Summary")


@multi_gpu_test(num_gpus=4)
def test_elastic_ep_scaling_uneven():
    """Test scale up with uneven worker distribution.

    This tests the case where num_new_workers % old_dp_size != 0,
    specifically 2 -> 3 where remainder = 1 % 2 = 1.
    This exercises the remainder handling in sender-receiver pairing.
    """
    _run_scale_up_down_cycle(
        scaled_dp_size=3, summary_title="Accuracy Summary (Uneven Scaling)"
    )
+245 -224
View File
@@ -8,6 +8,7 @@ import pytest
import torch
import torch.distributed
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.eplb.rebalance_execute import (
move_from_buffer,
rearrange_expert_weights_inplace,
@@ -244,90 +245,95 @@ def _test_async_transfer_layer_without_mtp_worker(
num_logical_experts: int,
) -> None:
set_env_vars_and_device(env)
ensure_model_parallel_initialized(
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
)
tp_group = get_tp_group()
ep_group = tp_group.device_group
ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}")
vllm_config = VllmConfig()
vllm_config.parallel_config.tensor_parallel_size = world_size
total_physical_experts = world_size * num_local_experts
hidden_sizes = [16, 32]
with set_current_vllm_config(vllm_config):
ensure_model_parallel_initialized(
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
)
redundancy_config = create_redundancy_config(
num_logical_experts,
total_physical_experts,
)
old_indices = create_expert_indices_with_redundancy(
num_layers,
num_logical_experts,
total_physical_experts,
redundancy_config,
)
tp_group = get_tp_group()
ep_group = tp_group.device_group
ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}")
new_redundancy_config = create_redundancy_config(
num_logical_experts,
total_physical_experts,
)
new_indices = create_expert_indices_with_redundancy(
num_layers,
num_logical_experts,
total_physical_experts,
new_redundancy_config,
)
total_physical_experts = world_size * num_local_experts
hidden_sizes = [16, 32]
expert_weights = create_expert_weights(
num_layers,
num_local_experts,
hidden_sizes,
ep_rank,
device,
old_indices,
)
old_indices_cpu = old_indices.cpu()
new_indices_cpu = new_indices.cpu()
redundancy_config = create_redundancy_config(
num_logical_experts,
total_physical_experts,
)
old_indices = create_expert_indices_with_redundancy(
num_layers,
num_logical_experts,
total_physical_experts,
redundancy_config,
)
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
cuda_stream = torch.cuda.Stream(device=device)
new_redundancy_config = create_redundancy_config(
num_logical_experts,
total_physical_experts,
)
new_indices = create_expert_indices_with_redundancy(
num_layers,
num_logical_experts,
total_physical_experts,
new_redundancy_config,
)
for layer_idx in range(num_layers):
is_unchanged, is_received_locally, recv_metadata = asyncio.run(
transfer_layer(
old_layer_indices=old_indices_cpu[layer_idx],
new_layer_indices=new_indices_cpu[layer_idx],
expert_weights=expert_weights[layer_idx],
expert_weights_buffer=expert_buffer,
ep_group=ep_group,
cuda_stream=cuda_stream,
expert_weights = create_expert_weights(
num_layers,
num_local_experts,
hidden_sizes,
ep_rank,
device,
old_indices,
)
old_indices_cpu = old_indices.cpu()
new_indices_cpu = new_indices.cpu()
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
cuda_stream = torch.cuda.Stream(device=device)
for layer_idx in range(num_layers):
is_unchanged, is_received_locally, recv_metadata = asyncio.run(
transfer_layer(
old_layer_indices=old_indices_cpu[layer_idx],
new_layer_indices=new_indices_cpu[layer_idx],
expert_weights=expert_weights[layer_idx],
expert_weights_buffer=expert_buffer,
ep_group=ep_group,
cuda_stream=cuda_stream,
)
)
cuda_stream.synchronize()
move_from_buffer(
expert_weights=expert_weights[layer_idx],
expert_weights_buffers=expert_buffer,
is_unchanged=is_unchanged,
is_received_locally=is_received_locally,
recv_metadata=recv_metadata,
new_indices=new_indices_cpu[layer_idx].numpy(),
ep_rank=ep_rank,
)
)
cuda_stream.synchronize()
move_from_buffer(
expert_weights=expert_weights[layer_idx],
expert_weights_buffers=expert_buffer,
is_unchanged=is_unchanged,
is_received_locally=is_received_locally,
recv_metadata=recv_metadata,
new_indices=new_indices_cpu[layer_idx].numpy(),
ep_rank=ep_rank,
)
verify_expert_weights_after_shuffle(
expert_weights,
new_indices,
hidden_sizes,
ep_rank,
num_local_experts,
)
verify_redundant_experts_have_same_weights(
expert_weights,
new_indices,
hidden_sizes,
world_size,
num_local_experts,
)
verify_expert_weights_after_shuffle(
expert_weights,
new_indices,
hidden_sizes,
ep_rank,
num_local_experts,
)
verify_redundant_experts_have_same_weights(
expert_weights,
new_indices,
hidden_sizes,
world_size,
num_local_experts,
)
def _test_rearrange_expert_weights_with_redundancy(
@@ -336,71 +342,76 @@ def _test_rearrange_expert_weights_with_redundancy(
# Initialize model parallel (using tensor parallel as an entrypoint
# to expert parallel)
set_env_vars_and_device(env)
ensure_model_parallel_initialized(
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
)
ep_group = get_tp_group().cpu_group
ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}")
vllm_config = VllmConfig()
vllm_config.parallel_config.tensor_parallel_size = world_size
# Test parameters
total_physical_experts = world_size * num_local_experts
hidden_sizes = [32, 64] # Two different weight matrices
with set_current_vllm_config(vllm_config):
ensure_model_parallel_initialized(
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
)
# Create old expert indices (with redundancy)
redundancy_config = create_redundancy_config(
num_logical_experts, total_physical_experts
)
ep_group = get_tp_group().cpu_group
ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}")
old_indices = create_expert_indices_with_redundancy(
num_layers,
num_logical_experts,
total_physical_experts,
redundancy_config,
)
# Test parameters
total_physical_experts = world_size * num_local_experts
hidden_sizes = [32, 64] # Two different weight matrices
# Create new expert indices (with redundancy)
new_redundancy_config = create_redundancy_config(
num_logical_experts, total_physical_experts
)
new_indices = create_expert_indices_with_redundancy(
num_layers,
num_logical_experts,
total_physical_experts,
new_redundancy_config,
)
# Create old expert indices (with redundancy)
redundancy_config = create_redundancy_config(
num_logical_experts, total_physical_experts
)
# Create expert weights
expert_weights = create_expert_weights(
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
)
old_indices = create_expert_indices_with_redundancy(
num_layers,
num_logical_experts,
total_physical_experts,
redundancy_config,
)
# Execute weight rearrangement
rearrange_expert_weights_inplace(
old_indices,
new_indices,
expert_weights,
ep_group,
is_profile=False,
)
# Create new expert indices (with redundancy)
new_redundancy_config = create_redundancy_config(
num_logical_experts, total_physical_experts
)
new_indices = create_expert_indices_with_redundancy(
num_layers,
num_logical_experts,
total_physical_experts,
new_redundancy_config,
)
# Verify the rearrangement result
verify_expert_weights_after_shuffle(
expert_weights,
new_indices,
hidden_sizes,
ep_rank,
num_local_experts,
)
# Create expert weights
expert_weights = create_expert_weights(
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
)
verify_redundant_experts_have_same_weights(
expert_weights,
new_indices,
hidden_sizes,
world_size,
num_local_experts,
)
# Execute weight rearrangement
rearrange_expert_weights_inplace(
old_indices,
new_indices,
expert_weights,
ep_group,
is_profile=False,
)
# Verify the rearrangement result
verify_expert_weights_after_shuffle(
expert_weights,
new_indices,
hidden_sizes,
ep_rank,
num_local_experts,
)
verify_redundant_experts_have_same_weights(
expert_weights,
new_indices,
hidden_sizes,
world_size,
num_local_experts,
)
@pytest.mark.parametrize(
@@ -444,58 +455,63 @@ def test_rearrange_expert_weights_with_redundancy(
def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
set_env_vars_and_device(env)
ensure_model_parallel_initialized(
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
)
ep_group = get_tp_group().cpu_group
ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}")
vllm_config = VllmConfig()
vllm_config.parallel_config.tensor_parallel_size = world_size
num_layers = 2
num_local_experts = 2
total_physical_experts = world_size * num_local_experts
num_logical_experts = total_physical_experts // 2 # Some redundancy
hidden_sizes = [32, 64]
with set_current_vllm_config(vllm_config):
ensure_model_parallel_initialized(
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
)
# Create redundancy configuration
redundancy_config = [2] * num_logical_experts
ep_group = get_tp_group().cpu_group
ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}")
# Same indices - no change
indices = create_expert_indices_with_redundancy(
num_layers, num_logical_experts, total_physical_experts, redundancy_config
)
num_layers = 2
num_local_experts = 2
total_physical_experts = world_size * num_local_experts
num_logical_experts = total_physical_experts // 2 # Some redundancy
hidden_sizes = [32, 64]
expert_weights = create_expert_weights(
num_layers, num_local_experts, hidden_sizes, ep_rank, device, indices
)
# Create redundancy configuration
redundancy_config = [2] * num_logical_experts
# Save original weights
original_weights = []
for layer_weights in expert_weights:
layer_copy = []
for weight in layer_weights:
layer_copy.append(weight.clone())
original_weights.append(layer_copy)
# Same indices - no change
indices = create_expert_indices_with_redundancy(
num_layers, num_logical_experts, total_physical_experts, redundancy_config
)
# Execute rearrangement (should be no change)
rearrange_expert_weights_inplace(
indices,
indices, # Same indices
expert_weights,
ep_group,
is_profile=False,
)
expert_weights = create_expert_weights(
num_layers, num_local_experts, hidden_sizes, ep_rank, device, indices
)
# Verify that the weights have not changed
for layer in range(num_layers):
for weight_idx in range(len(hidden_sizes)):
torch.testing.assert_close(
expert_weights[layer][weight_idx],
original_weights[layer][weight_idx],
msg=f"""Layer {layer}, weight {weight_idx}
# Save original weights
original_weights = []
for layer_weights in expert_weights:
layer_copy = []
for weight in layer_weights:
layer_copy.append(weight.clone())
original_weights.append(layer_copy)
# Execute rearrangement (should be no change)
rearrange_expert_weights_inplace(
indices,
indices, # Same indices
expert_weights,
ep_group,
is_profile=False,
)
# Verify that the weights have not changed
for layer in range(num_layers):
for weight_idx in range(len(hidden_sizes)):
torch.testing.assert_close(
expert_weights[layer][weight_idx],
original_weights[layer][weight_idx],
msg=f"""Layer {layer}, weight {weight_idx}
should remain unchanged""",
)
)
@pytest.mark.parametrize(
@@ -538,64 +554,69 @@ def test_rearrange_expert_weights_no_change(world_size):
def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
set_env_vars_and_device(env)
ensure_model_parallel_initialized(
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
)
ep_group = get_tp_group().cpu_group
ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}")
vllm_config = VllmConfig()
vllm_config.parallel_config.tensor_parallel_size = world_size
num_layers = 1
num_local_experts = 2
total_physical_experts = world_size * num_local_experts
num_logical_experts = total_physical_experts // 2
hidden_sizes = [32]
with set_current_vllm_config(vllm_config):
ensure_model_parallel_initialized(
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
)
# Create different index distributions
old_redundancy = create_redundancy_config(
num_logical_experts, total_physical_experts
)
new_redundancy = create_redundancy_config(
num_logical_experts, total_physical_experts
)
ep_group = get_tp_group().cpu_group
ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}")
old_indices = create_expert_indices_with_redundancy(
num_layers, num_logical_experts, total_physical_experts, old_redundancy
)
new_indices = create_expert_indices_with_redundancy(
num_layers, num_logical_experts, total_physical_experts, new_redundancy
)
num_layers = 1
num_local_experts = 2
total_physical_experts = world_size * num_local_experts
num_logical_experts = total_physical_experts // 2
hidden_sizes = [32]
expert_weights = create_expert_weights(
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
)
# Create different index distributions
old_redundancy = create_redundancy_config(
num_logical_experts, total_physical_experts
)
new_redundancy = create_redundancy_config(
num_logical_experts, total_physical_experts
)
# Save original weights
original_weights = []
for layer_weights in expert_weights:
layer_copy = []
for weight in layer_weights:
layer_copy.append(weight.clone())
original_weights.append(layer_copy)
old_indices = create_expert_indices_with_redundancy(
num_layers, num_logical_experts, total_physical_experts, old_redundancy
)
new_indices = create_expert_indices_with_redundancy(
num_layers, num_logical_experts, total_physical_experts, new_redundancy
)
# Execute profile mode rearrangement
rearrange_expert_weights_inplace(
old_indices,
new_indices,
expert_weights,
ep_group,
is_profile=True, # Profile mode
)
expert_weights = create_expert_weights(
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
)
# In profile mode, the weights should remain unchanged
for layer in range(num_layers):
for weight_idx in range(len(hidden_sizes)):
torch.testing.assert_close(
expert_weights[layer][weight_idx],
original_weights[layer][weight_idx],
msg="In profile mode, the weights should remain unchanged",
)
# Save original weights
original_weights = []
for layer_weights in expert_weights:
layer_copy = []
for weight in layer_weights:
layer_copy.append(weight.clone())
original_weights.append(layer_copy)
# Execute profile mode rearrangement
rearrange_expert_weights_inplace(
old_indices,
new_indices,
expert_weights,
ep_group,
is_profile=True, # Profile mode
)
# In profile mode, the weights should remain unchanged
for layer in range(num_layers):
for weight_idx in range(len(hidden_sizes)):
torch.testing.assert_close(
expert_weights[layer][weight_idx],
original_weights[layer][weight_idx],
msg="In profile mode, the weights should remain unchanged",
)
@pytest.mark.parametrize("world_size", [2, 4])
@@ -10,6 +10,7 @@ import torch.distributed as dist
import torch.multiprocessing as mp
import vllm.envs as envs
from tests.utils import ensure_current_vllm_config
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
from vllm.distributed.device_communicators.pynccl import register_nccl_symmetric_ops
@@ -51,7 +52,8 @@ def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int):
)
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
with ensure_current_vllm_config():
initialize_model_parallel(tensor_model_parallel_size=world_size)
cuda_communicator = typing.cast(
CudaCommunicator, get_tp_group().device_communicator
+3 -1
View File
@@ -9,6 +9,7 @@ import pytest
import torch
import torch.distributed
from tests.utils import ensure_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
@@ -112,7 +113,8 @@ def test_pynccl_multiple_allreduce():
@worker_fn_wrapper
def multiple_allreduce_with_vllm_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
ensure_model_parallel_initialized(2, 2)
with ensure_current_vllm_config():
ensure_model_parallel_initialized(2, 2)
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
with graph_capture(device=device):
# two tp groups can communicate independently
+3 -2
View File
@@ -6,7 +6,7 @@ import unittest
import pytest
import torch
from tests.utils import multi_gpu_test
from tests.utils import ensure_current_vllm_config, multi_gpu_test
from vllm.distributed.parallel_state import (
init_distributed_environment,
initialize_model_parallel,
@@ -87,7 +87,8 @@ def mixer2_gated_norm_tensor_parallel(
# initialize distributed
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
with ensure_current_vllm_config():
initialize_model_parallel(tensor_model_parallel_size=world_size)
# create random weights and inputs
weight = torch.rand((hidden_size,), dtype=dtype, device=device)
+12 -9
View File
@@ -45,21 +45,24 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
@pytest.fixture
def dist_init():
from tests.utils import ensure_current_vllm_config
temp_file = tempfile.mkstemp()[1]
backend = "nccl"
if current_platform.is_cpu() or current_platform.is_tpu():
backend = "gloo"
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method=f"file://{temp_file}",
local_rank=0,
backend=backend,
)
initialize_model_parallel(1, 1)
yield
with ensure_current_vllm_config():
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method=f"file://{temp_file}",
local_rank=0,
backend=backend,
)
initialize_model_parallel(1, 1)
yield
cleanup_dist_env_and_memory(shutdown_ray=True)
+3 -2
View File
@@ -6,7 +6,7 @@ import random
import pytest
import torch
from tests.utils import multi_gpu_test
from tests.utils import ensure_current_vllm_config, multi_gpu_test
from vllm import _custom_ops as ops
from vllm.distributed import (
init_distributed_environment,
@@ -631,7 +631,8 @@ def use_fused_moe_lora_kernel_tensor_parallel(
local_rank=local_rank,
distributed_init_method=init_method,
)
initialize_model_parallel(world_size, 1)
with ensure_current_vllm_config():
initialize_model_parallel(world_size, 1)
tp_size = get_tensor_model_parallel_world_size()
input_dim = K if column_parallel else N
+4 -2
View File
@@ -13,6 +13,7 @@ from vllm.config import (
ParallelConfig,
SchedulerConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.config.load import LoadConfig
from vllm.config.lora import LoRAConfig
@@ -77,8 +78,9 @@ def test_worker_apply_lora(qwen3_lora_files):
distributed_init_method=f"file://{tempfile.mkstemp()[1]}",
)
worker.init_device()
worker.load_model()
with set_current_vllm_config(vllm_config):
worker.init_device()
worker.load_model()
set_active_loras(worker, [])
assert worker.list_loras() == set()
+9 -5
View File
@@ -6,7 +6,7 @@ import pytest
import torch
import torch.multiprocessing as mp
from tests.utils import multi_gpu_test
from tests.utils import ensure_current_vllm_config, multi_gpu_test
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (
init_distributed_environment,
@@ -117,7 +117,8 @@ def run_dp_sharded_vision_model_vs_direct(
# initialize distributed
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
with ensure_current_vllm_config():
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create a test input tensor
image_input = torch.randn(batch_size, 3, 224, 224)
@@ -302,7 +303,8 @@ def run_dp_sharded_mrope_vision_model_vs_direct(
# initialize distributed
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
with ensure_current_vllm_config():
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create test data
grid_thw_list = []
@@ -377,7 +379,8 @@ def run_dp_sharded_mrope_vision_model_empty_input_worker(
)
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
with ensure_current_vllm_config():
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create empty inputs
pixel_values = torch.empty((0, 768))
@@ -425,7 +428,8 @@ def run_dp_sharded_mrope_vision_model_uneven_load_worker(
)
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
with ensure_current_vllm_config():
initialize_model_parallel(tensor_model_parallel_size=world_size)
# Create images with very different sizes
grid_thw_list = [
+32 -1
View File
@@ -895,6 +895,36 @@ def compare_all_settings(
)
@contextmanager
def ensure_current_vllm_config():
    """Guarantee that a current vllm config exists inside the context.

    When a config is already set, the context does nothing and the existing
    config stays in effect. When none is set, a default ``VllmConfig`` is
    created and installed for the duration of the context.

    Intended for tests that call functions which require a vllm config but
    do not care which one.

    Example:
        with ensure_current_vllm_config():
            init_distributed_environment(...)
            ensure_model_parallel_initialized(...)
    """
    from vllm.config import (
        VllmConfig,
        get_current_vllm_config_or_none,
        set_current_vllm_config,
    )

    if get_current_vllm_config_or_none() is None:
        # Nothing set yet: install a default config for the duration.
        with set_current_vllm_config(VllmConfig()):
            yield
    else:
        # A config is already active; leave it untouched.
        yield
def init_test_distributed_environment(
tp_size: int,
pp_size: int,
@@ -921,6 +951,7 @@ def init_test_distributed_environment(
distributed_init_method=distributed_init_method,
local_rank=local_rank,
)
ensure_model_parallel_initialized(tp_size, pp_size)
else:
# No config set, create a default one for the test
with set_current_vllm_config(VllmConfig()):
@@ -930,7 +961,7 @@ def init_test_distributed_environment(
distributed_init_method=distributed_init_method,
local_rank=local_rank,
)
ensure_model_parallel_initialized(tp_size, pp_size)
ensure_model_parallel_initialized(tp_size, pp_size)
def multi_process_parallel(
+5 -2
View File
@@ -789,8 +789,11 @@ def test_hybrid_attention_mamba_tensor_shapes():
"MASTER_PORT": "12345",
}
)
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=1)
from tests.utils import ensure_current_vllm_config
with ensure_current_vllm_config():
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=1)
torch.set_default_dtype(torch.float16)
model_config = ModelConfig(
@@ -10,6 +10,7 @@ from unittest.mock import patch
import pytest
import torch
from vllm.config import set_current_vllm_config
from vllm.engine.arg_utils import EngineArgs
from vllm.utils.mem_utils import MemorySnapshot
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
@@ -95,7 +96,12 @@ def worker_process(
side_effect=make_operation_tracker("nccl_all_reduce", original_all_reduce),
)
with init_patch, memory_patch, all_reduce_patch:
with (
init_patch,
memory_patch,
all_reduce_patch,
set_current_vllm_config(vllm_config),
):
# Initialize device (this is where we test the order)
worker.init_device()
+49
View File
@@ -319,3 +319,52 @@ class TorchCompileWithNoGuardsWrapper:
yield
finally:
self.__class__.forward.__code__ = original
def reset_compile_wrapper(model: torch.nn.Module) -> None:
    """
    Clean up compiled model and captured CUDA graphs for elastic EP.

    The function acts on ``model`` itself when it is a
    ``TorchCompileWithNoGuardsWrapper``, otherwise on ``model.model`` when
    that attribute is one. Any other input, and any wrapper whose
    ``do_not_compile`` flag is truthy, is left untouched.
    """
    if not isinstance(model, TorchCompileWithNoGuardsWrapper):
        inner = getattr(model, "model", None)
        if not isinstance(inner, TorchCompileWithNoGuardsWrapper):
            return
        model = inner

    # model.do_not_compile is set by the @support_torch_compile decorator
    if getattr(model, "do_not_compile", False):
        return

    from vllm.compilation.counter import compilation_counter

    # reset the compilation counter
    for counter_name in (
        "num_models_seen",
        "num_graphs_seen",
        "num_piecewise_graphs_seen",
        "num_piecewise_capturable_graphs_seen",
        "num_backend_compilations",
        "num_gpu_runner_capture_triggers",
        "num_cudagraph_captured",
        "num_inductor_compiles",
        "num_eager_compiles",
        "num_cache_entries_updated",
        "num_compiled_artifacts_saved",
        "stock_torch_compile_count",
    ):
        setattr(compilation_counter, counter_name, 0)

    # Clear the AOT compiled function so the model is forced to
    # recompile on the next call. Without this, decorators.py
    # __call__ uses the stale aot_compiled_fn whose torchinductor
    # kernels have old parameters (expert_map size for example)
    # baked in as compile-time constants.
    if hasattr(model, "aot_compiled_fn"):
        model.aot_compiled_fn = None
    if hasattr(model, "was_aot_compile_fn_loaded_from_disk"):
        model.was_aot_compile_fn_loaded_from_disk = False

    # Reset the cache_dir so VllmBackend recomputes the hash
    # (data_parallel_size changed, so the config hash differs).
    compile_cfg = model.vllm_config.compilation_config
    compile_cfg.cache_dir = ""
    compile_cfg.local_cache_dir = ""

    # Restore the original forward bytecode, then re-run the wrapper's
    # initializer on the same object.
    model.__class__.forward.__code__ = model.original_code_object()
    TorchCompileWithNoGuardsWrapper.__init__(model)
+116 -6
View File
@@ -165,6 +165,9 @@ class ParallelConfig:
disable_custom_all_reduce: bool = False
"""Disable the custom all-reduce kernel and fall back to NCCL."""
enable_elastic_ep: bool = False
"""Enable elastic expert parallelism with stateless NCCL groups for DP/EP."""
enable_dbo: bool = False
"""Enable dual batch overlap for the model executor."""
ubatch_size: int = 0
@@ -244,6 +247,34 @@ class ParallelConfig:
Set to be private as it's not intended to be configured by users.
"""
_stateless_dp_group_port_list: list[list[int]] = Field(default_factory=list)
"""List of open ports for stateless DP groups when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
It is a list of list[int], with each inner list containing a set of 3 ports
to be used for setting up the stateless CPU/device/TCPStore groups
in StatelessGroupCoordinator. The number of inner lists is equal to
the number of DP groups,
i.e., len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size,
and len(self._stateless_dp_group_port_list[i]) == 3 for all i.
"""
_stateless_ep_group_port_list: list[list[int]] = Field(default_factory=list)
"""List of open ports for stateless EP groups when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
len(self._stateless_ep_group_port_list) == world_size_across_dp // ep_size.
"""
_stateless_eplb_group_port_list: list[list[int]] = Field(default_factory=list)
"""List of open ports for stateless EPLB groups when enable_elastic_ep is True.
Same topology as EP but separate NCCL communicator to avoid deadlocks.
"""
_stateless_world_group_port_list: list[list[int]] = Field(default_factory=list)
"""List of open ports for stateless world group when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
len(self._stateless_world_group_port_list) == 1.
"""
decode_context_parallel_size: int = 1
"""Number of decode context parallel groups, because the world size does
not change by dcp, it simply reuse the GPUs of TP group, and tp_size
@@ -402,7 +433,67 @@ class ParallelConfig:
return answer
def stateless_init_dp_group(self) -> ProcessGroup:
def allocate_elastic_ep_ports(self) -> None:
    """Allocate all ports for elastic EP (stateless groups + DP master).

    Must be called AFTER ray.init() so that ports claimed by Ray's
    idle worker pool are already in use and won't be returned by
    get_open_ports_list().

    Does nothing when elastic EP is disabled, or when the world-group port
    list is already populated (ports were allocated by an earlier call).

    One batch of ports is requested and split as follows: the last 5 become
    the DP master port list (one of them is popped immediately into
    ``data_parallel_master_port``); the remaining ports are handed out in
    sets of 3, in the order world, DP, EP, EPLB groups.
    """
    if not self.enable_elastic_ep:
        return
    if self._stateless_world_group_port_list:
        return

    ports_per_group = 3
    num_dp_master_ports = 5

    num_world_groups = 1
    dp_size = self.data_parallel_size
    # The documented invariant is one EP group per
    # `world_size_across_dp // ep_size`. Elastic EP rejects pipeline
    # parallelism, so the EP group is assumed to span every rank across DP,
    # i.e. ep_size == world_size_across_dp. The previous
    # `dp_size * world_size_across_dp` product was dp_size times too large
    # and only produced the right group count because of the max(1, ...)
    # clamp below. NOTE(review): confirm against the EP group construction.
    ep_size = self.world_size_across_dp
    num_dp_groups = max(1, self.world_size_across_dp // dp_size)
    num_ep_groups = max(1, self.world_size_across_dp // ep_size)
    num_eplb_groups = num_ep_groups

    total_stateless_ports = (
        num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups
    ) * ports_per_group

    all_ports = get_open_ports_list(total_stateless_ports + num_dp_master_ports)

    # Tail of the batch: DP master ports.
    self._data_parallel_master_port_list = all_ports[-num_dp_master_ports:]
    self.data_parallel_master_port = self._data_parallel_master_port_list.pop()
    stateless_ports = all_ports[:-num_dp_master_ports]

    # Head of the batch: consecutive 3-port sets, dealt out group by group.
    port_sets = [
        stateless_ports[i : i + ports_per_group]
        for i in range(0, total_stateless_ports, ports_per_group)
    ]
    world_end = num_world_groups
    dp_end = world_end + num_dp_groups
    ep_end = dp_end + num_ep_groups
    eplb_end = ep_end + num_eplb_groups

    self._stateless_world_group_port_list = port_sets[:world_end]
    self._stateless_dp_group_port_list = port_sets[world_end:dp_end]
    self._stateless_ep_group_port_list = port_sets[dp_end:ep_end]
    self._stateless_eplb_group_port_list = port_sets[ep_end:eplb_end]
def get_next_stateless_world_group_port(self) -> list[int]:
    """Remove and return the last remaining port set for the stateless world group.

    Raises:
        IndexError: if no port set is left.
    """
    remaining = self._stateless_world_group_port_list
    return remaining.pop()
def get_next_stateless_dp_group_port(self) -> list[int]:
    """Remove and return the last remaining port set for a stateless DP group.

    Raises:
        IndexError: if no port set is left.
    """
    remaining = self._stateless_dp_group_port_list
    return remaining.pop()
def get_next_stateless_ep_group_port(self) -> list[int]:
    """Remove and return the last remaining port set for a stateless EP group.

    Raises:
        IndexError: if no port set is left.
    """
    remaining = self._stateless_ep_group_port_list
    return remaining.pop()
def get_next_stateless_eplb_group_port(self) -> list[int]:
    """Remove and return the last remaining port set for a stateless EPLB group.

    Raises:
        IndexError: if no port set is left.
    """
    remaining = self._stateless_eplb_group_port_list
    return remaining.pop()
def stateless_init_dp_group(self, return_store: bool = False) -> ProcessGroup:
# NOTE: In high-concurrency scenarios multiple processes
# can pick the same (currently free) port through a race
# condition when calling `get_open_port()`. When the first
@@ -426,7 +517,8 @@ class ParallelConfig:
self.get_next_dp_init_port(),
self.data_parallel_rank,
self.data_parallel_size,
backend=current_platform.dist_backend,
backend="gloo",
return_store=return_store,
)
except DistNetworkError as e:
# We only want to retry when the root cause is EADDRINUSE.
@@ -561,6 +653,21 @@ class ParallelConfig:
logger.info("Using external launcher for distributed inference.")
self.world_size *= self.data_parallel_size
if self.enable_elastic_ep:
if not self.enable_eplb:
raise ValueError("Elastic EP is only supported with enable_eplb=True.")
if self.pipeline_parallel_size > 1:
raise ValueError(
"Elastic EP is not supported with pipeline parallelism "
f"(pipeline_parallel_size={self.pipeline_parallel_size})."
)
if self.data_parallel_external_lb or self.data_parallel_hybrid_lb:
raise NotImplementedError(
"Elastic EP is not compatible with data_parallel_external_lb "
"or data_parallel_hybrid_lb. Elastic EP relies on a single API "
"server and core client to coordinate scale up/down."
)
if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
# Data parallel was specified in the engine args.
if self.distributed_executor_backend == "external_launcher":
@@ -573,9 +680,12 @@ class ParallelConfig:
"Set data_parallel_rank to %d automatically.",
self.data_parallel_rank,
)
if not self._data_parallel_master_port_list:
self._data_parallel_master_port_list = get_open_ports_list(5)
self.data_parallel_master_port = self._data_parallel_master_port_list.pop()
if not self.enable_elastic_ep:
if not self._data_parallel_master_port_list:
self._data_parallel_master_port_list = get_open_ports_list(5)
self.data_parallel_master_port = (
self._data_parallel_master_port_list.pop()
)
if not (0 <= self.data_parallel_rank < self.data_parallel_size):
raise ValueError(
@@ -602,7 +712,7 @@ class ParallelConfig:
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
logger.info("Disabling V1 multiprocessing for external launcher.")
if self.distributed_executor_backend is None and self.world_size > 1:
if self.distributed_executor_backend is None and self.world_size_across_dp > 1:
# We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group.
@@ -31,8 +31,8 @@ class NaiveAll2AllManager(All2AllManagerBase):
debugging.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def naive_multicast(
self,
@@ -138,8 +138,8 @@ class AgRsAll2AllManager(All2AllManagerBase):
all-gather (dispatch) and reduce-scatter (combine).
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def dispatch_router_logits(
self,
@@ -239,12 +239,12 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
All2All communication based on DeepEP High-Throughput kernels.
"""
def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
assert has_deep_ep(), (
"DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
" to install DeepEP kernels."
) # noqa
super().__init__(cpu_group)
super().__init__(cpu_group, tcp_store_group)
self.handle_cache = Cache()
# This is the DeepEP default. Stick to it till we can establish
@@ -282,7 +282,10 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
raise NotImplementedError
def destroy(self):
pass
with self.handle_cache._lock:
for _, handle in self.handle_cache._cache.items():
handle.destroy()
self.handle_cache._cache.clear()
class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
@@ -290,8 +293,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
All2All communication based on DeepEP High-Throughput kernels.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def _make_all2all_kwargs(self) -> dict[Any, Any]:
# Defaults for internode and intranode are taken from DeepEP tests.
@@ -314,6 +317,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank,
explicitly_destroy=True,
)
def get_handle(self, kwargs):
@@ -347,8 +351,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
All2All communication based on DeepEP Low-Latency kernels.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def _make_all2all_kwargs(
self,
@@ -387,6 +391,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
num_qps_per_rank=num_qps_per_rank,
allow_nvlink_for_low_latency_mode=True,
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
explicitly_destroy=True,
)
def get_handle(self, kwargs):
@@ -418,11 +423,11 @@ class FlashInferAllToAllManager(All2AllManagerBase):
rank: int
world_size: int
def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
assert has_flashinfer_all2all(), (
"flashinfer all2all module not found. Please install/check flashinfer"
) # noqa
super().__init__(cpu_group)
super().__init__(cpu_group, tcp_store_group)
logger.debug(
"Initialize for flashinfer All2All rank=%d, world size=%d",
self.rank,
@@ -29,8 +29,9 @@ class All2AllManagerBase:
rank: int
world_size: int
def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
self.cpu_group = cpu_group
self.tcp_store_group = tcp_store_group
# compute some common properties
from vllm.distributed.parallel_state import (
@@ -47,12 +48,17 @@ class All2AllManagerBase:
# when we create this object
self.dp_rank = self.dp_group.rank_in_group
self.dp_world_size = self.dp_group.world_size
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.rank = cpu_group.rank()
self.world_size = cpu_group.size()
# all2all communication often has separate implementations for
# intra-node and inter-node communication
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
if tcp_store_group is None:
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
else:
self.internode = not all(
in_the_same_node_as(tcp_store_group, source_rank=0)
)
def get_handle(self, kwargs):
# get a handle for the all2all communication,
@@ -121,17 +127,36 @@ class DeviceCommunicatorBase:
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = "",
global_ranks: list[int] | None = None,
global_world_size: int | None = None,
):
self.device = device or torch.device("cpu")
self.cpu_group = cpu_group
self.device_group = device_group
self.unique_name = unique_name
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.ranks = dist.get_process_group_ranks(cpu_group)
self.global_rank = dist.get_rank()
self.global_world_size = dist.get_world_size()
self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
# Check if this is a stateless process group
from torch.distributed.distributed_c10d import _world
is_stateless = _world.pg_map.get(cpu_group, None) is None
if is_stateless:
# For stateless groups, we can't use torch.distributed methods
self.rank = cpu_group.rank()
self.world_size = cpu_group.size()
assert global_ranks is not None
assert global_world_size is not None
self.ranks = global_ranks
self.global_rank = self.ranks[self.rank]
self.global_world_size = global_world_size
self.rank_in_group = self.rank
else:
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.ranks = dist.get_process_group_ranks(cpu_group)
self.global_rank = dist.get_rank()
self.global_world_size = dist.get_world_size()
self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
use_ep = False
all2all_backend = None
@@ -145,7 +170,7 @@ class DeviceCommunicatorBase:
use_ep = config.parallel_config.data_parallel_size > 1
all2all_backend = config.parallel_config.all2all_backend
self.is_ep_communicator = "ep" in unique_name
self.is_ep_communicator = unique_name.split(":")[0] == "ep"
self.use_all2all = self.is_ep_communicator and use_ep
self.all2all_backend = all2all_backend
self.all2all_manager: All2AllManagerBase | None = None
@@ -275,6 +300,13 @@ class DeviceCommunicatorBase:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """Broadcast ``tensor`` in place from group rank ``src`` and return it.

    ``src`` is a rank within this communicator's group; it is translated
    through ``self.ranks`` before being handed to ``torch.distributed``.
    With a single-rank group the tensor is returned untouched.
    """
    if self.world_size != 1:
        source_rank = self.ranks[src]
        torch.distributed.broadcast(tensor, source_rank, self.device_group)
    return tensor
def destroy(self):
pass
@@ -343,3 +375,6 @@ class DeviceCommunicatorBase:
This is a no-op in the base class.
"""
return hidden_states
def batch_isend_irecv(self, p2p_ops: list):
    """Run a batch of point-to-point operations.

    The base communicator provides no implementation and always raises.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
@@ -16,6 +16,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import (
from vllm.logger import init_logger
from vllm.platforms import current_platform
from ..utils import StatelessProcessGroup
from .base_device_communicator import DeviceCommunicatorBase
logger = init_logger(__name__)
@@ -28,8 +29,18 @@ class CudaCommunicator(DeviceCommunicatorBase):
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = "",
global_ranks: list[int] | None = None,
global_world_size: int | None = None,
tcp_store_group: StatelessProcessGroup | None = None,
):
super().__init__(cpu_group, device, device_group, unique_name)
super().__init__(
cpu_group,
device,
device_group,
unique_name,
global_ranks,
global_world_size,
)
if "tp" not in unique_name:
# custom allreduce or torch symm mem can be used only by tp
use_custom_allreduce = False
@@ -62,7 +73,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.pynccl_comm: PyNcclCommunicator | None = None
if self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
group=self.cpu_group if tcp_store_group is None else tcp_store_group,
device=self.device,
)
if is_symmetric_memory_enabled():
@@ -107,19 +118,27 @@ class CudaCommunicator(DeviceCommunicatorBase):
if self.all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
self.all2all_manager = NaiveAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "allgather_reducescatter":
from .all2all import AgRsAll2AllManager
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
self.all2all_manager = AgRsAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "deepep_high_throughput":
from .all2all import DeepEPHTAll2AllManager
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
self.all2all_manager = DeepEPHTAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "deepep_low_latency":
from .all2all import DeepEPLLAll2AllManager
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
self.all2all_manager = DeepEPLLAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "mori":
from .all2all import MoriAll2AllManager
@@ -127,7 +146,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
elif self.all2all_backend == "flashinfer_all2allv":
from .all2all import FlashInferAllToAllManager
self.all2all_manager = FlashInferAllToAllManager(self.cpu_group)
self.all2all_manager = FlashInferAllToAllManager(
self.cpu_group, tcp_store_group
)
else:
raise ValueError(f"Unknown all2all backend: {self.all2all_backend}")
@@ -284,6 +305,18 @@ class CudaCommunicator(DeviceCommunicatorBase):
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """Broadcast ``tensor`` from rank ``src`` through PyNCCL and return it.

    With a single-rank group the tensor is returned untouched.

    Raises:
        ValueError: if the PyNCCL communicator is missing or disabled.
    """
    if self.world_size == 1:
        return tensor

    comm = self.pynccl_comm
    if comm is None or comm.disabled:
        raise ValueError("No PyNCCL communicator found")

    comm.broadcast(tensor, src)
    return tensor
def destroy(self):
if self.pynccl_comm is not None:
self.pynccl_comm = None
@@ -403,3 +436,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
hidden_states,
is_sequence_parallel,
)
def batch_isend_irecv(self, p2p_ops: list):
    """Forward a batch of point-to-point operations to the PyNCCL communicator.

    Raises:
        ValueError: if the PyNCCL communicator is missing or disabled.
    """
    comm = self.pynccl_comm
    if comm is None or comm.disabled:
        raise ValueError("No PyNCCL communicator found")
    comm.batch_isend_irecv(p2p_ops)
@@ -312,10 +312,19 @@ class PyNcclCommunicator:
)
if stream is None:
stream = current_stream()
if tensor.dtype in [
torch.float8_e5m2,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
]:
nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8)
else:
nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
self.nccl.ncclSend(
buffer_type(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
nccl_dtype,
dst,
self.comm,
cudaStream_t(stream.cuda_stream),
@@ -330,10 +339,19 @@ class PyNcclCommunicator:
)
if stream is None:
stream = current_stream()
if tensor.dtype in [
torch.float8_e5m2,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
]:
nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8)
else:
nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
self.nccl.ncclRecv(
buffer_type(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
nccl_dtype,
src,
self.comm,
cudaStream_t(stream.cuda_stream),
@@ -384,3 +402,17 @@ class PyNcclCommunicator:
def deregister_comm_window(self, window):
return self.nccl.ncclCommWindowDeregister(self.comm, window)
def batch_isend_irecv(self, p2p_ops: list, stream=None):
    """Issue a batch of point-to-point sends/receives inside one NCCL group.

    Args:
        p2p_ops: objects exposing ``op`` (``torch.distributed.isend`` or
            ``torch.distributed.irecv``), ``tensor`` and ``group_peer``.
        stream: stream to enqueue on; defaults to the current stream.

    Raises:
        ValueError: if an entry's ``op`` is neither ``isend`` nor ``irecv``.

    Does nothing when the communicator is disabled.
    """
    if self.disabled:
        return

    # Validate before opening the group: an unrecognized op used to be
    # skipped silently, dropping that transfer without any error. Checking
    # up front also avoids raising between group_start() and group_end().
    supported_ops = (torch.distributed.isend, torch.distributed.irecv)
    for op in p2p_ops:
        if not any(op.op is supported for supported in supported_ops):
            raise ValueError(f"Unsupported P2P operation: {op.op!r}")

    if stream is None:
        stream = current_stream()

    self.group_start()
    for op in p2p_ops:
        if op.op is torch.distributed.isend:
            self.send(op.tensor, op.group_peer, stream)
        else:
            self.recv(op.tensor, op.group_peer, stream)
    self.group_end()
@@ -0,0 +1,529 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import gc
import weakref
from collections.abc import Iterable, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import P2POp
from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.wrapper import reset_compile_wrapper
from vllm.config import (
CompilationMode,
set_current_vllm_config,
)
from vllm.distributed import (
get_dp_group,
get_ep_group,
get_pcp_group,
get_tp_group,
)
from vllm.distributed.elastic_ep.standby_state import (
create_standby_groups,
get_standby_dp_group,
get_standby_ep_group,
pop_standby_groups,
)
from vllm.distributed.parallel_state import (
_replace_active_groups,
prepare_communication_buffer_for_model,
)
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.workspace import lock_workspace, unlock_workspace
logger = init_logger(__name__)
def batch_transfer_weights(
    model: nn.Module,
    is_sender: bool,
    peer_rank: int,
    dp_group: StatelessGroupCoordinator,
    expert_weights: Sequence[Iterable[torch.Tensor]],
) -> None:
    """Send or receive every non-expert tensor of ``model`` in one batch.

    All ``state_dict()`` entries are transferred except those whose name
    ends in ``expert_map`` and those that share a data pointer with a
    tensor in ``expert_weights``.

    Args:
        model: module whose state dict is transferred.
        is_sender: True to send the tensors, False to receive into them.
        peer_rank: rank of the other side within ``dp_group``.
        dp_group: coordinator whose device communicator runs the batch.
        expert_weights: groups of expert tensors to exclude.

    Raises:
        ValueError: if ``dp_group`` has no device communicator.
    """
    communicator = dp_group.device_communicator
    if communicator is None:
        raise ValueError("No device communicator found")

    # Expert tensors are identified by storage address, not by name.
    expert_ptrs = {
        tensor.data_ptr() for group in expert_weights for tensor in group
    }

    tensors_to_move = [
        tensor.data
        for key, tensor in model.state_dict().items()
        if not key.endswith("expert_map") and tensor.data_ptr() not in expert_ptrs
    ]
    assert len(tensors_to_move) > 0

    direction = torch.distributed.isend if is_sender else torch.distributed.irecv
    batch = []
    for tensor in tensors_to_move:
        # Built via object.__new__, so P2POp's own constructor is not run
        # (presumably to skip its process-group validation -- TODO confirm).
        op = object.__new__(P2POp)
        op.op = direction
        op.tensor = tensor
        op.group_peer = peer_rank
        batch.append(op)

    communicator.batch_isend_irecv(batch)
def broadcast_expert_mapping(
    physical_to_logical: torch.Tensor | None,
    num_local_physical_experts: int | None,
    num_logical_experts: int | None,
    dp_group: StatelessGroupCoordinator,
    device: torch.device,
    src_rank: int = 0,
) -> tuple[torch.Tensor, int, int]:
    """Broadcast the physical-to-logical expert map and its metadata.

    The source rank must supply all three values (the map as int64). Other
    ranks may pass ``None`` for them: they first receive the map's shape
    and the two counts as CPU tensors through ``dp_group.tcp_store_group``,
    allocate a matching int64 tensor on ``device``, and then receive the map
    itself through ``dp_group.broadcast``.

    Returns:
        ``(physical_to_logical, num_local_physical_experts,
        num_logical_experts)`` as seen by the source rank.
    """
    is_source = dp_group.rank_in_group == src_rank

    if is_source:
        assert physical_to_logical is not None
        assert num_local_physical_experts is not None
        assert num_logical_experts is not None
        assert physical_to_logical.dtype == torch.int64
        shape_tensor = torch.tensor(
            list(physical_to_logical.shape), dtype=torch.int64, device="cpu"
        )
        metadata_tensor = torch.tensor(
            [num_local_physical_experts, num_logical_experts],
            dtype=torch.int64,
            device="cpu",
        )
    else:
        # Placeholders; the map is treated as 2-D, hence two shape entries.
        shape_tensor = torch.empty(2, dtype=torch.int64, device="cpu")
        metadata_tensor = torch.empty(2, dtype=torch.int64, device="cpu")

    store_group = dp_group.tcp_store_group
    shape_tensor = store_group.broadcast(shape_tensor, src_rank)
    metadata_tensor = store_group.broadcast(metadata_tensor, src_rank)

    if not is_source:
        assert device is not None
        received_shape = tuple(shape_tensor.tolist())
        physical_to_logical = torch.empty(
            received_shape,
            dtype=torch.int64,
            device=device,
        )

    assert physical_to_logical is not None
    physical_to_logical = dp_group.broadcast(physical_to_logical, src_rank)

    local_physical = int(metadata_tensor[0].item())
    logical = int(metadata_tensor[1].item())
    return physical_to_logical, local_physical, logical
class ElasticEPScalingExecutor:
def __init__(self, worker):
    """Keep a weak reference to ``worker``; no reconfiguration is pending yet."""
    self.reconfig_request = None
    self.worker_ref = weakref.ref(worker)
@property
def worker(self):
    """Return the worker behind the weak reference.

    Raises:
        RuntimeError: if the worker has already been garbage collected.
    """
    live_worker = self.worker_ref()
    if live_worker is not None:
        return live_worker
    raise RuntimeError("Worker has been garbage collected")
def execute(self, execute_method: str, *args, **kwargs):
    """Call the method named ``execute_method`` on this executor.

    Positional and keyword arguments are passed through unchanged and the
    method's return value is returned.

    Raises:
        ValueError: if no attribute with that name exists.
    """
    handler = getattr(self, execute_method, None)
    if handler is not None:
        return handler(*args, **kwargs)
    raise ValueError(f"Unknown execute method: {execute_method}")
def create_standby_groups(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
# Build the standby communication groups for the DP size requested in
# `reconfig_request`, and suppress EPLB until the transition finishes.
# The request is stored on the executor; switch_and_prepare reads it back.
self.reconfig_request = reconfig_request
new_dp_size = reconfig_request.new_data_parallel_size
world_size = self.worker.vllm_config.parallel_config.world_size
new_world_size_across_dp = world_size * new_dp_size
# Shallow-copy the vllm config and deep-copy only its parallel_config, so
# the new DP size can be set on the copy without touching the worker's
# live parallel_config.
updated_config = copy.copy(self.worker.vllm_config)
updated_config.parallel_config = copy.deepcopy(
self.worker.vllm_config.parallel_config
)
updated_config.parallel_config.data_parallel_size = new_dp_size
# Group creation runs with the updated copy as the current vllm config.
# Note: the call below is the module-level create_standby_groups imported
# from standby_state, not this method.
with set_current_vllm_config(updated_config):
create_standby_groups(
new_dp_size=new_dp_size,
new_world_size_across_dp=new_world_size_across_dp,
master_ip=reconfig_request.new_data_parallel_master_ip,
world_group_ports=reconfig_request.new_stateless_world_group_port_list,
dp_group_ports=reconfig_request.new_stateless_dp_group_port_list,
ep_group_ports=reconfig_request.new_stateless_ep_group_port_list,
eplb_group_ports=reconfig_request.new_stateless_eplb_group_port_list,
)
# perform_eplb_reshuffle sets this flag back to False.
self.worker.model_runner.eep_eplb_suppressed = True
standby_ep_group = get_standby_ep_group()
assert standby_ep_group is not None
# Log once, from rank 0 of the standby EP group.
if standby_ep_group.rank == 0:
logger.info("[Elastic EP] EPLB disabled during elastic scaling transition")
def transfer_weights(self, old_dp_size: int, new_dp_size: int) -> None:
# Sender side of a scale-up: push this worker's non-expert weights to the
# new DP ranks it is paired with, over the standby DP group.
standby_dp_group = get_standby_dp_group()
assert standby_dp_group is not None
# Broadcast old_dp_size to all workers in standby group
# (receive_weights issues the matching broadcast to learn old_dp_size).
# Ranks below old_dp_size hold the real value; the others pass an
# uninitialized placeholder that the broadcast from rank 0 overwrites.
if standby_dp_group.rank_in_group < old_dp_size:
old_dp_size_tensor = torch.tensor(
[old_dp_size], dtype=torch.int64, device="cpu"
)
else:
old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu")
# The broadcast result is not read again in this method.
old_dp_size_tensor = standby_dp_group.tcp_store_group.broadcast(
old_dp_size_tensor, 0
)
num_new_workers = new_dp_size - old_dp_size
dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank
# Sender-receiver pairing: the first new_workers % old_dp_size
# senders get (k+1) contiguous receivers, the rest get k
# receivers.
# Example: old_dp_size=2, new_dp_size=5 -> k=1, remainder=1, so
# rank 0 sends to ranks [2, 3] and rank 1 sends to rank [4].
# NOTE(review): old_dp_size == 0 would divide by zero here.
num_dst_per_sender = num_new_workers // old_dp_size
remainder = num_new_workers % old_dp_size
# recv_begin/recv_end index into the list of NEW workers (0-based).
if dp_rank < remainder:
recv_begin = dp_rank * (num_dst_per_sender + 1)
recv_end = recv_begin + num_dst_per_sender + 1
else:
recv_begin = (
remainder * (num_dst_per_sender + 1)
+ (dp_rank - remainder) * num_dst_per_sender
)
recv_end = recv_begin + num_dst_per_sender
# New workers occupy DP ranks starting at old_dp_size.
ranks_to_send = list(range(old_dp_size + recv_begin, old_dp_size + recv_end))
model = self.worker.model_runner.get_model()
for new_worker_rank in sorted(ranks_to_send):
batch_transfer_weights(
model=model,
is_sender=True,
peer_rank=new_worker_rank,
dp_group=standby_dp_group,
expert_weights=model.expert_weights,
)
torch.cuda.synchronize()
def broadcast_expert_mapping(self) -> None:
    """Broadcast this worker's EPLB expert mapping over the standby DP group.

    Reads the physical-to-logical map of the current model from the EPLB
    state, derives the per-rank and logical expert counts from tensor
    shapes and the current EP world size, and passes them to the
    module-level ``broadcast_expert_mapping`` with source rank 0. The
    returned values are discarded.
    """
    standby_dp_group = get_standby_dp_group()
    assert standby_dp_group is not None

    runner = self.worker.model_runner
    model_config = runner.model_config
    eplb_state = runner.eplb_state
    assert eplb_state is not None

    model_state = eplb_state.model_states[model_config.compute_hash()]
    mapping = model_state.physical_to_logical_map

    total_physical = mapping.shape[1]
    local_physical = total_physical // get_ep_group().world_size
    logical = model_state.logical_replica_count.shape[1]

    broadcast_expert_mapping(
        physical_to_logical=mapping,
        num_local_physical_experts=local_physical,
        num_logical_experts=logical,
        dp_group=standby_dp_group,
        src_rank=0,
        device=self.worker.device,
    )
def switch_and_remove(self) -> None:
    """Replace every active group (world, DP, EP, EPLB, node count) with ``None``."""
    _replace_active_groups(
        world=None,
        dp=None,
        ep=None,
        eplb=None,
        node_count=None,
    )
def switch_and_prepare(self) -> None:
# Activate the standby groups and rebuild the worker state that depends on
# the DP/EP size: parallel config, MoE modules, EPLB tensors, compiled
# model and captured CUDA graphs. Requires create_standby_groups to have
# stored a reconfigure request beforehand.
# Capture the old sizes before the active groups are replaced.
old_dp_size = get_dp_group().world_size
old_ep_size = get_ep_group().world_size
_replace_active_groups(**pop_standby_groups())
parallel_config = self.worker.vllm_config.parallel_config
reconfig_request = self.reconfig_request
assert reconfig_request is not None
new_dp_size = reconfig_request.new_data_parallel_size
# Read after the switch, so this is the EP size of the new groups.
new_ep_size = get_ep_group().world_size
# Mirror the request into the live parallel config. The rank fields are
# only overwritten when the request does not ask to keep the current rank.
parallel_config.data_parallel_size = new_dp_size
if (
reconfig_request.new_data_parallel_rank
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
if (
reconfig_request.new_data_parallel_rank_local
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank_local = (
reconfig_request.new_data_parallel_rank_local
)
parallel_config.data_parallel_master_ip = (
reconfig_request.new_data_parallel_master_ip
)
parallel_config.data_parallel_master_port = (
reconfig_request.new_data_parallel_master_port
)
# Reconfigure MoE modules with new EP size
# (modules are selected by class name rather than isinstance).
moe_modules = [
module
for module in self.worker.model_runner.model.modules()
if (
module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE"
)
]
# NOTE(review): indexing [0] assumes the model has at least one MoE module.
num_local_experts = moe_modules[0].moe_config.num_local_experts
assert all(
module.moe_config.num_local_experts == num_local_experts
for module in moe_modules
), "All MoE modules must have the same number of experts"
for module in moe_modules:
# The per-rank expert count stays fixed; the global count scales with EP size.
module.moe_config.num_experts = num_local_experts * new_ep_size
module.global_num_experts = module.moe_config.num_experts
tp_size = get_tp_group().world_size
is_sequence_parallel = parallel_config.use_sequence_parallel_moe
sp_size = tp_size if is_sequence_parallel else 1
module.moe_parallel_config = FusedMoEParallelConfig.make(
tp_size_=tp_size,
pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size,
sp_size_=sp_size,
vllm_parallel_config=parallel_config,
)
module.moe_config.moe_parallel_config = module.moe_parallel_config
# Update EPLB state
eplb_state = self.worker.model_runner.eplb_state
assert eplb_state is not None
model_config = self.worker.model_runner.model_config
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
num_physical_experts = num_local_experts * new_ep_size
num_logical_experts = eplb_model_state.logical_replica_count.shape[1]
# Physical slots beyond the logical expert count are counted as redundant.
parallel_config.eplb_config.num_redundant_experts = (
num_physical_experts - num_logical_experts
)
old_physical_to_logical = eplb_model_state.physical_to_logical_map
num_moe_layers = old_physical_to_logical.shape[0]
# Recomputed from the EPLB load tensor: physical experts per rank under
# the old EP size.
num_local_experts = eplb_model_state.expert_load_pass.shape[1] // old_ep_size
# Scale-up: widen the physical->logical map, filling the new slots with -1
# and copying the old map into the leading columns.
if new_dp_size > old_dp_size:
expanded_physical_to_logical = torch.full(
(num_moe_layers, num_local_experts * new_ep_size),
-1,
dtype=old_physical_to_logical.dtype,
device=old_physical_to_logical.device,
)
expanded_physical_to_logical[:, : num_local_experts * old_ep_size] = (
old_physical_to_logical
)
eplb_model_state.physical_to_logical_map = expanded_physical_to_logical
old_num_physical_experts = eplb_model_state.expert_load_pass.shape[1]
pad_size = num_physical_experts - old_num_physical_experts
# Resize the load statistics along their last dimension: zero-pad on
# scale-up, truncate otherwise.
# NOTE(review): an unchanged DP size reaches the else branch and fails its
# assertion (pad_size would be 0).
if new_dp_size > old_dp_size:
assert pad_size > 0
expanded_expert_load_pass = F.pad(
eplb_model_state.expert_load_pass, (0, pad_size), value=0
)
expanded_expert_load_window = F.pad(
eplb_model_state.expert_load_window, (0, pad_size), value=0
)
eplb_model_state.expert_load_pass = expanded_expert_load_pass
eplb_model_state.expert_load_window = expanded_expert_load_window
# Only the pre-existing slots count as valid here; perform_eplb_reshuffle
# later resets this to the full width of the map.
eplb_state.num_valid_physical_experts = old_num_physical_experts
else:
assert pad_size < 0
eplb_model_state.expert_load_pass = eplb_model_state.expert_load_pass[
:, :num_physical_experts
]
eplb_model_state.expert_load_window = eplb_model_state.expert_load_window[
:, :, :num_physical_experts
]
eplb_state.num_valid_physical_experts = num_physical_experts
model = self.worker.model_runner.get_model()
# Clear the expert weight list before re-registering EPLB state
# (presumably repopulated by set_eplb_state -- TODO confirm).
model.expert_weights = []
with set_current_vllm_config(self.worker.vllm_config):
model.set_eplb_state(
eplb_model_state.expert_load_pass,
eplb_model_state.logical_to_physical_map,
eplb_model_state.logical_replica_count,
)
model.update_physical_experts_metadata(
num_physical_experts=num_physical_experts,
num_local_physical_experts=num_local_experts,
)
# Force re-creation of the modular kernel (and all2all manager)
# for the new EP size by resetting quant_method to base
for module in moe_modules:
if hasattr(module.quant_method, "old_quant_method"):
module.quant_method = module.quant_method.old_quant_method
module.runner = module._init_runner()
prepare_communication_buffer_for_model(self.worker.model_runner.model)
if (
self.worker.vllm_config.compilation_config.mode
== CompilationMode.STOCK_TORCH_COMPILE
):
# NOTE(yongji): when using stock torch.compile,
# torch.compile is triggered during GPUModelRunner's load_model()
# TODO(yongji): check whether we need to re-trigger torch.compile here.
# Any changes to the tensor shapes in execution should already
# be handled internally by torch.compile.
backend = self.worker.vllm_config.compilation_config.init_backend(
self.worker.vllm_config
)
compilation_counter.stock_torch_compile_count += 1
self.worker.model_runner.model.compile(fullgraph=True, backend=backend)
# release all previously captured CUDA graphs
if isinstance(self.worker.model_runner.model, CUDAGraphWrapper):
wrapper = self.worker.model_runner.model
wrapper.concrete_cudagraph_entries = {}
elif isinstance(self.worker.model_runner.model, UBatchWrapper):
raise RuntimeError("DBO is not yet supported in elastic EP")
# Snapshot the GPU and CPU block tables, then clear them; the snapshots
# are copied back after the warm-up at the end of this method.
multi_block_table = self.worker.model_runner.input_batch.block_table
saved_block_tables: list[tuple[torch.Tensor, torch.Tensor]] = []
for bt in multi_block_table.block_tables:
saved_block_tables.append(
(bt.block_table.gpu.clone(), bt.block_table.cpu.clone())
)
multi_block_table.clear()
# reset the compile wrapper
torch.compiler.reset()
with set_current_vllm_config(self.worker.vllm_config):
reset_compile_wrapper(self.worker.model_runner.get_model())
# Free what the old compiled state and graphs were holding before warm-up.
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
# The workspace is unlocked only for the duration of the warm-up.
unlock_workspace()
self.worker.compile_or_warm_up_model()
lock_workspace()
# Restore the block tables saved above.
for bt, (saved_gpu, saved_cpu) in zip(
multi_block_table.block_tables, saved_block_tables
):
bt.block_table.gpu.copy_(saved_gpu)
bt.block_table.cpu.copy_(saved_cpu)
def perform_eplb_reshuffle(self, new_dp_size: int | None = None) -> None:
    """Run one EPLB expert rearrangement with the async path disabled.

    Args:
        new_dp_size: ``None`` to rearrange over the current EP ranks.
            Otherwise the data-parallel size after a scale-down; every EP
            rank at or beyond ``new_dp_size * tensor_parallel_size`` is
            mapped to ``-1`` in the rank mapping passed to
            ``eplb_state.rearrange``, the remaining ranks map to themselves.
    """
    if get_ep_group().rank == 0:
        logger.info("[Elastic EP] Starting expert resharding...")
    eplb_state = self.worker.model_runner.eplb_state
    assert eplb_state is not None
    model_config = self.worker.model_runner.model_config
    eplb_model_state = eplb_state.model_states[model_config.compute_hash()]

    # Temporarily force is_async to False for this rearrangement. The
    # previous value is restored in ``finally`` so that a failure inside
    # rearrange() cannot leave async EPLB permanently switched off.
    is_async_enabled = eplb_state.is_async
    eplb_state.is_async = False
    try:
        if new_dp_size is None:
            eplb_state.rearrange()
        else:
            # scale down
            parallel_config = self.worker.vllm_config.parallel_config
            tp_size = parallel_config.tensor_parallel_size
            old_ep_size = parallel_config.data_parallel_size * tp_size
            new_ep_size = new_dp_size * tp_size
            rank_mapping = {
                old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
                for old_ep_rank in range(old_ep_size)
            }
            eplb_state.rearrange(rank_mapping=rank_mapping)
        # NOTE(yongji): check whether we need to synchronize here
        torch.cuda.synchronize()
        # reset expert_rearrangement_step to ensure all ranks are synchronized
        eplb_state.expert_rearrangement_step = 0
        eplb_state.num_valid_physical_experts = (
            eplb_model_state.physical_to_logical_map.shape[1]
        )
    finally:
        eplb_state.is_async = is_async_enabled
    self.worker.model_runner.eep_eplb_suppressed = False
    if get_ep_group().rank == 0:
        logger.info("[Elastic EP] Expert resharding completed")
def receive_weights(self) -> None:
    """Receive model weights on a newly added worker from its sender rank.

    The old DP size is obtained through a broadcast from rank 0 on the DP
    group's TCP-store group. New workers are split over the old ranks as
    evenly as possible: the first ``extra`` old ranks serve ``base + 1``
    new workers each, the remaining old ranks serve ``base`` each.
    """
    group = get_dp_group()
    assert isinstance(group, StatelessGroupCoordinator)
    total_dp = group.world_size
    my_dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank

    # Receive old_dp_size broadcasted during transfer_weights
    size_buf = torch.empty(1, dtype=torch.int64, device="cpu")
    size_buf = group.tcp_store_group.broadcast(size_buf, 0)
    prev_dp = int(size_buf[0].item())

    # Work out which existing worker sends to this new worker.
    added = total_dp - prev_dp
    idx = my_dp_rank - prev_dp
    base, extra = divmod(added, prev_dp)
    if idx < extra * (base + 1):
        sender_rank = idx // (base + 1)
    else:
        sender_rank = extra + (idx - extra * (base + 1)) // base

    model = self.worker.model_runner.get_model()
    batch_transfer_weights(
        model=model,
        is_sender=False,
        peer_rank=sender_rank,
        dp_group=group,
        expert_weights=model.expert_weights,
    )
    torch.cuda.synchronize()
def receive_expert_mapping(self) -> tuple[torch.Tensor, int, int]:
    """Receive the expert mapping from DP rank 0 and widen it for the new EP size.

    Returns:
        A tuple ``(expanded_map, num_logical_experts, old_num_physical)``.
        ``expanded_map`` has ``num_local_physical_experts * new_ep_size``
        columns; the leading ``old_num_physical`` columns hold the received
        mapping and every other entry is ``-1``.
    """
    group = get_dp_group()
    assert isinstance(group, StatelessGroupCoordinator)
    received = broadcast_expert_mapping(
        physical_to_logical=None,
        num_local_physical_experts=None,
        num_logical_experts=None,
        dp_group=group,
        src_rank=0,
        device=self.worker.device,
    )
    mapping, local_physical, logical = received

    layers = mapping.shape[0]
    dp_size_now = get_dp_group().world_size
    tp_size = self.worker.vllm_config.parallel_config.tensor_parallel_size
    ep_size_now = dp_size_now * tp_size

    # Start with every slot unmapped (-1), then copy in the old placement.
    widened = torch.full(
        (layers, local_physical * ep_size_now),
        -1,
        dtype=mapping.dtype,
        device=mapping.device,
    )
    old_physical = mapping.shape[1]
    widened[:, :old_physical] = mapping
    return widened, logical, old_physical
def prepare_new_worker(self) -> None:
    """Prepare the model's communication buffers under this worker's vllm config."""
    worker = self.worker
    with set_current_vllm_config(worker.vllm_config):
        model = worker.model_runner.get_model()
        prepare_communication_buffer_for_model(model)
@@ -0,0 +1,563 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import time
import weakref
from datetime import timedelta
from typing import TYPE_CHECKING, Literal
import torch.distributed
from vllm.config import ParallelConfig
from vllm.distributed import (
sched_yield,
stateless_destroy_torch_distributed_process_group,
)
from vllm.logger import init_logger
from vllm.v1.engine import (
EEPNotificationType,
ReconfigureDistributedRequest,
ReconfigureRankType,
)
from vllm.v1.engine.core import DPEngineCoreProc
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.v1.executor.abstract import Executor
# Module-level logger.
logger = init_logger(__name__)
# Role of an engine within one scaling operation. The state machine below
# special-cases "new" (scale up) and "removing" (scale down); any other
# worker takes the existing/remaining path.
WorkerType = Literal["existing", "new", "removing"]
@enum.unique
class ScaleUpExistingEngineState(enum.IntEnum):
    """States of an already-running engine during scale-up, in order.

    ``@enum.unique`` makes a duplicated value fail at import time instead of
    silently turning one state into an alias of another.
    """

    # Idle until the NEW_CORE_ENGINES_INIT_READY notification arrives.
    WAIT_NEW_CORE_ENGINES_INIT = 0
    # Barrier over the old DP group, then create the standby groups.
    CREATE_STANDBY_GROUPS = 1
    # Broadcast the expert mapping to the new workers.
    TRANSFER_EXPERT_MAPPING = 2
    # Idle until NEW_CORE_ENGINES_WEIGHTS_INIT_READY arrives.
    WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT = 3
    # Barrier over the old DP group, then send weights to the new workers.
    TRANSFER_WEIGHTS = 4
    # Sync the KV cache memory size over the new DP group.
    SYNC_KV_CACHE_MEMORY_SIZE = 5
    # Switch the engine core to the new DP group.
    SWITCH_AND_PREPARE = 6
    # Barrier over the new DP group, then EPLB reshuffle.
    EPLB_RESHUFFLE = 7
    # Terminal state.
    COMPLETE = 8
@enum.unique
class ScaleUpNewEngineState(enum.IntEnum):
    """States of a newly started engine during scale-up, in order.

    ``@enum.unique`` makes a duplicated value fail at import time instead of
    silently turning one state into an alias of another.
    """

    # Pick up engines_running / current_wave / step_counter via all-reduce.
    PREPARE = 0
    # Barrier over the new DP group, then EPLB reshuffle.
    EPLB_RESHUFFLE = 1
    # Terminal state.
    COMPLETE = 2
@enum.unique
class ScaleDownRemainingEngineState(enum.IntEnum):
    """States of an engine that survives a scale-down, in order.

    ``@enum.unique`` makes a duplicated value fail at import time instead of
    silently turning one state into an alias of another.
    """

    # Register in the shared engine counter.
    PREPARE = 0
    # Barrier over the old DP group, then EPLB reshuffle onto remaining ranks.
    EPLB_RESHUFFLE = 1
    # Create standby groups and switch to them.
    SWITCH_AND_PREPARE = 2
    # Terminal state.
    COMPLETE = 3
@enum.unique
class ScaleDownRemovingEngineState(enum.IntEnum):
    """States of an engine that is being removed by a scale-down, in order.

    ``@enum.unique`` makes a duplicated value fail at import time instead of
    silently turning one state into an alias of another.
    """

    # Register in the shared engine counter.
    PREPARE = 0
    # Barrier over the old DP group, EPLB reshuffle, then shut down.
    EPLB_RESHUFFLE = 1
    # Terminal state.
    COMPLETE = 2
class _BarrierTimeoutError(RuntimeError):
    """Timeout in the first stage of the two-staged TCPStore barrier.

    The barrier synchronizes the execution of all engines in the DP group;
    this error signals that not every engine arrived before the timeout.
    """
class ElasticEPScalingState:
    """State machine that drives one engine through an elastic EP rescale.

    An instance is advanced by repeatedly calling :meth:`progress`. The enum
    that ``self.state`` is drawn from depends on the operation and the role:

    * scale up, ``"new"`` worker: ``ScaleUpNewEngineState``
    * scale up, any other worker: ``ScaleUpExistingEngineState``
    * scale down, ``"removing"`` worker: ``ScaleDownRemovingEngineState``
    * scale down, any other worker: ``ScaleDownRemainingEngineState``

    The executor and the engine core are held through weak references.
    """

    def __init__(
        self,
        model_executor: "Executor",
        engine_core: "DPEngineCoreProc",
        vllm_config: "VllmConfig",
        new_parallel_config: ParallelConfig,
        worker_type: WorkerType,
        scale_type: Literal["scale_up", "scale_down"],
        reconfig_request: ReconfigureDistributedRequest | None = None,
    ):
        """Capture the old/new DP groups and pick the initial state.

        Args:
            model_executor: Executor used for ``collective_rpc`` calls
                (stored as a weak reference).
            engine_core: Engine core whose DP group/store and step counters
                are read and updated (stored as a weak reference).
            vllm_config: Config whose ``parallel_config`` is rewritten by
                ``_update_parallel_config``.
            new_parallel_config: Parallel config used to create the new
                stateless DP group.
            worker_type: Role of this engine in the operation.
            scale_type: ``"scale_up"`` or ``"scale_down"``.
            reconfig_request: The reconfiguration request; required by the
                steps that assert it is not ``None``.
        """
        self.model_executor_ref = weakref.ref(model_executor)
        self.engine_core_ref = weakref.ref(engine_core)
        self.vllm_config = vllm_config
        # A "new" worker has no old group: its current DP group already is
        # the new one. Every other worker starts with only the old group.
        self.old_dp_group = self.engine_core.dp_group if worker_type != "new" else None
        self.old_dp_store = self.engine_core.dp_store if worker_type != "new" else None
        self.new_parallel_config: ParallelConfig = new_parallel_config
        self.new_dp_group: torch.distributed.ProcessGroup | None = (
            self.engine_core.dp_group if worker_type == "new" else None
        )
        self.new_dp_store = self.engine_core.dp_store if worker_type == "new" else None
        self.worker_type = worker_type
        self.scale_type = scale_type
        self.reconfig_request = reconfig_request

        if scale_type == "scale_up":
            self.state = (
                ScaleUpNewEngineState.PREPARE
                if worker_type == "new"
                else ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
            )
        else:
            self.state = (
                ScaleDownRemovingEngineState.PREPARE
                if worker_type == "removing"
                else ScaleDownRemainingEngineState.PREPARE
            )

    @property
    def model_executor(self) -> "Executor":
        """The live executor; raises ``RuntimeError`` if it was collected."""
        model_executor = self.model_executor_ref()
        if model_executor is None:
            raise RuntimeError("Model executor has been garbage collected")
        return model_executor

    @property
    def engine_core(self) -> "DPEngineCoreProc":
        """The live engine core; raises ``RuntimeError`` if it was collected."""
        engine_core = self.engine_core_ref()
        if engine_core is None:
            raise RuntimeError("Engine core has been garbage collected")
        return engine_core

    def progress(self) -> bool:
        """Try to advance the state machine by one step.

        Returns:
            ``False`` when the current state is still waiting (for a
            notification, for the engine counter, or because the staged
            barrier timed out); ``True`` when a step ran or the machine is
            already complete.
        """
        if self.scale_type == "scale_up":
            return (
                self._progress_new_engine()
                if self.worker_type == "new"
                else self._progress_existing_engine()
            )
        return (
            self._progress_removing_engine()
            if self.worker_type == "removing"
            else self._progress_remaining_engine()
        )

    def _execute_tcp_store_barrier(
        self,
        dp_store,
        group_rank: int,
        group_size: int,
        barrier_id: str,
        timeout: timedelta | None = None,
    ):
        """Publish this rank's arrival key and poll until all ranks arrived.

        Args:
            dp_store: Store offering ``set`` and ``check``.
            group_rank: Rank of this process within the group.
            group_size: Number of ranks expected to arrive.
            barrier_id: Identifier embedded in the arrival keys.
            timeout: Maximum time to wait; ``None`` waits indefinitely.

        Raises:
            _BarrierTimeoutError: If ``timeout`` elapses before every rank's
                arrival key is present.
        """
        arrival_key = f"arrival_{barrier_id}_{group_rank}"
        dp_store.set(arrival_key, b"1")

        # Use the monotonic clock: the wall clock can be adjusted while we
        # wait, which would make the timeout fire early or never.
        start_time = time.monotonic()
        processes_arrived: set[int] = set()
        while len(processes_arrived) < group_size:
            if (
                timeout is not None
                and time.monotonic() - start_time > timeout.total_seconds()
            ):
                raise _BarrierTimeoutError(
                    f"Barrier timed out after {timeout.total_seconds()} seconds"
                )
            for i in range(group_size):
                if i in processes_arrived:
                    continue
                key = f"arrival_{barrier_id}_{i}"
                present = dp_store.check([key])
                if present:
                    processes_arrived.add(i)
            if len(processes_arrived) < group_size:
                sched_yield()

    def _staged_barrier(self, use_new_group: bool, barrier_name: str) -> bool:
        """
        Execute a two-staged barrier to synchronize all engines in the DP group.

        Some DP EngineCores may receive the reconfiguration notifications
        later than others, and already proceed to engine step (model forward)
        in the busy loop.
        In this case, EngineCores that already proceed to reconfiguration
        should skip reconfiguration and execute model forward for one more
        step, so in the next step, all EngineCores will be synchronized.
        We use a two-staged barrier to achieve this. The first time each
        EngineCore executes the barrier, if a timeout is reached before the
        barrier completes, that means some EngineCores have already entered
        engine step. The EngineCores that timed out will then proceed to
        engine step, and will synchronize with the other EngineCores in the
        next step with a barrier without timeout.

        Args:
            use_new_group: Synchronize over the new DP group/store instead of
                the old one.
            barrier_name: Name used to build the barrier's store keys.

        Returns:
            ``True`` if the barrier completed, ``False`` if the first stage
            timed out (the sync key is then set for the next attempt).
        """
        dp_store = self.new_dp_store if use_new_group else self.old_dp_store
        dp_group = self.new_dp_group if use_new_group else self.old_dp_group
        assert dp_store is not None
        assert dp_group is not None

        group_rank = dp_group.rank()
        group_size = dp_group.size()
        barrier_id = f"eep_barrier_{barrier_name}"
        sync_key = f"{barrier_id}_sync"

        # The sync key is only present after some engine timed out in the
        # first stage; in that case wait without a timeout.
        # TODO(yongji): figure out appropriate timeout for the barrier
        timeout = None if dp_store.check([sync_key]) else timedelta(seconds=5)

        try:
            self._execute_tcp_store_barrier(
                dp_store, group_rank, group_size, barrier_id, timeout=timeout
            )
            torch.distributed.barrier(dp_group)
            if group_rank == 0:
                # Clean up the keys so the same barrier name can be reused.
                dp_store.delete_key(sync_key)
                for i in range(group_size):
                    dp_store.delete_key(f"arrival_{barrier_id}_{i}")
            return True
        except _BarrierTimeoutError as e:
            if timeout is None:
                raise RuntimeError("Unexpected timeout encountered") from e
            dp_store.compare_set(sync_key, "", b"1")
            return False

    def _all_engines_arrived(self, use_new_group: bool, barrier_name: str) -> bool:
        """Check the engine counter, then run the staged barrier.

        Returns ``False`` while the ``eep_barrier_engine_count`` counter in
        the selected store is below the selected group's size, or when the
        staged barrier did not complete. Deleting the counter key is left to
        the caller.
        """
        dp_store = self.new_dp_store if use_new_group else self.old_dp_store
        dp_group = self.new_dp_group if use_new_group else self.old_dp_group
        assert dp_store is not None
        assert dp_group is not None
        if int(dp_store.get("eep_barrier_engine_count")) < dp_group.size():
            return False
        return self._staged_barrier(
            use_new_group=use_new_group, barrier_name=barrier_name
        )

    def _all_reduce_engine_state(
        self,
        dp_group: torch.distributed.ProcessGroup,
        local_state: tuple[int, int, int],
    ) -> None:
        """MAX all-reduce the engine step state and store the result.

        ``local_state`` is ``(engines_running, current_wave, step_counter)``
        as contributed by this engine. Both sides of a scale-up go through
        this helper so that they agree on the tensor layout.
        """
        tensor = torch.tensor(list(local_state), dtype=torch.int32, device="cpu")
        torch.distributed.all_reduce(
            tensor, op=torch.distributed.ReduceOp.MAX, group=dp_group
        )
        data = tensor.tolist()
        self.engine_core.engines_running = bool(data[0])
        self.engine_core.current_wave = int(data[1])
        self.engine_core.step_counter = int(data[2])

    def _progress_existing_engine(self) -> bool:
        """Advance an already-running engine through scale-up."""
        state = self.state

        if state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT:
            # Left only through handle_notification().
            return False

        elif state == ScaleUpExistingEngineState.CREATE_STANDBY_GROUPS:
            # NOTE(yongji): wait for all existing workers to receive the request
            if not self._all_engines_arrived(
                use_new_group=False, barrier_name="create_standby_groups"
            ):
                return False
            if self.old_dp_group.rank() == 0:
                self.old_dp_store.delete_key("eep_barrier_engine_count")
            self._create_standby_groups()
            self.state = ScaleUpExistingEngineState.TRANSFER_EXPERT_MAPPING
            return True

        elif state == ScaleUpExistingEngineState.TRANSFER_EXPERT_MAPPING:
            self._transfer_expert_mapping()
            self.state = ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT
            return True

        elif state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT:
            # Left only through handle_notification().
            return False

        elif state == ScaleUpExistingEngineState.TRANSFER_WEIGHTS:
            if not self._all_engines_arrived(
                use_new_group=False, barrier_name="transfer_weights"
            ):
                return False
            if self.old_dp_group.rank() == 0:
                self.old_dp_store.delete_key("eep_barrier_engine_count")
            self._transfer_weights()
            self.state = ScaleUpExistingEngineState.SYNC_KV_CACHE_MEMORY_SIZE
            return True

        elif state == ScaleUpExistingEngineState.SYNC_KV_CACHE_MEMORY_SIZE:
            self._sync_kv_cache_memory_size()
            self.state = ScaleUpExistingEngineState.SWITCH_AND_PREPARE
            return True

        elif state == ScaleUpExistingEngineState.SWITCH_AND_PREPARE:
            self._switch_and_prepare()
            self.state = ScaleUpExistingEngineState.EPLB_RESHUFFLE
            self.new_dp_store.add("eep_barrier_engine_count", 1)
            return True

        elif state == ScaleUpExistingEngineState.EPLB_RESHUFFLE:
            assert self.new_dp_group is not None
            if not self._all_engines_arrived(
                use_new_group=True, barrier_name="eplb_reshuffle"
            ):
                return False
            if self.new_dp_group.rank() == 0:
                self.new_dp_store.delete_key("eep_barrier_engine_count")
            self._eplb_reshuffle()
            self.state = ScaleUpExistingEngineState.COMPLETE
            self._update_parallel_config()
            return True

        else:
            assert self.state == ScaleUpExistingEngineState.COMPLETE
            return True

    def _progress_new_engine(self) -> bool:
        """Advance a newly started engine through scale-up."""
        state = self.state
        assert self.new_dp_group is not None

        if state == ScaleUpNewEngineState.PREPARE:
            # Contribute zeros so the MAX all-reduce yields the values of the
            # engines that were already running.
            self._all_reduce_engine_state(self.new_dp_group, (0, 0, 0))
            self.state = ScaleUpNewEngineState.EPLB_RESHUFFLE
            self.new_dp_store.add("eep_barrier_engine_count", 1)
            return True

        elif state == ScaleUpNewEngineState.EPLB_RESHUFFLE:
            if not self._all_engines_arrived(
                use_new_group=True, barrier_name="eplb_reshuffle"
            ):
                return False
            assert self.new_dp_group.rank() > 0
            self._eplb_reshuffle()
            self.state = ScaleUpNewEngineState.COMPLETE
            return True

        else:
            assert self.state == ScaleUpNewEngineState.COMPLETE
            return True

    def _progress_remaining_engine(self) -> bool:
        """Advance an engine that survives a scale-down."""
        state = self.state

        if state == ScaleDownRemainingEngineState.PREPARE:
            self.state = ScaleDownRemainingEngineState.EPLB_RESHUFFLE
            self.old_dp_store.add("eep_barrier_engine_count", 1)
            return True

        elif state == ScaleDownRemainingEngineState.EPLB_RESHUFFLE:
            if not self._all_engines_arrived(
                use_new_group=False, barrier_name="eplb_reshuffle"
            ):
                return False
            if self.old_dp_group.rank() == 0:
                self.old_dp_store.delete_key("eep_barrier_engine_count")
            self._eplb_reshuffle_before_scale_down()
            self.state = ScaleDownRemainingEngineState.SWITCH_AND_PREPARE
            # NOTE(yongji): currently, after EPLB reshuffle
            # that redistributes experts to remaining workers, workers
            # to be removed will immediately initiate shutdown;
            # existing workers can no longer execute forward steps using
            # the old setup. In the future, we may keep
            # the removing workers alive a bit longer,
            # e.g., to drain in-batch requests.
            self._create_standby_groups()
            self._switch_and_prepare()
            self._update_parallel_config()
            self.state = ScaleDownRemainingEngineState.COMPLETE
            return True

        else:
            assert self.state == ScaleDownRemainingEngineState.COMPLETE
            return True

    def _progress_removing_engine(self) -> bool:
        """Advance an engine that is removed by a scale-down, then shut it down."""
        state = self.state

        if state == ScaleDownRemovingEngineState.PREPARE:
            self.state = ScaleDownRemovingEngineState.EPLB_RESHUFFLE
            self.old_dp_store.add("eep_barrier_engine_count", 1)
            return True

        elif state == ScaleDownRemovingEngineState.EPLB_RESHUFFLE:
            if not self._all_engines_arrived(
                use_new_group=False, barrier_name="eplb_reshuffle"
            ):
                return False
            assert self.old_dp_group.rank() > 0
            self._eplb_reshuffle_before_scale_down()
            self._switch_and_remove()
            self.state = ScaleDownRemovingEngineState.COMPLETE
            self.engine_core._eep_send_engine_core_notification(
                EEPNotificationType.SHUTDOWN_COMPLETE
            )
            self.engine_core.shutdown()
            return True

        else:
            assert self.state == ScaleDownRemovingEngineState.COMPLETE
            return True

    def handle_notification(self, notification_type: EEPNotificationType):
        """React to a notification about the new core engines (scale-up only).

        A matching notification bumps the engine counter in the old DP store
        and moves the state machine out of the corresponding WAIT state. Any
        other combination of notification and state is ignored.
        """
        assert self.worker_type != "new"
        if self.scale_type != "scale_up":
            # All state enums are IntEnums, which compare equal by value
            # across enum classes (e.g. ScaleDown*.PREPARE == 0 ==
            # WAIT_NEW_CORE_ENGINES_INIT). Without this guard a scale-down
            # state could match the scale-up checks below and be hijacked.
            return

        if (
            notification_type == EEPNotificationType.NEW_CORE_ENGINES_INIT_READY
            and self.state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
        ):
            self.old_dp_store.add("eep_barrier_engine_count", 1)
            self.state = ScaleUpExistingEngineState.CREATE_STANDBY_GROUPS
        elif (
            notification_type == EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY
            and self.state
            == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT
        ):
            self.old_dp_store.add("eep_barrier_engine_count", 1)
            self.state = ScaleUpExistingEngineState.TRANSFER_WEIGHTS

    def is_complete(self) -> bool:
        """Return whether the state machine reached its COMPLETE state."""
        if self.scale_type == "scale_up":
            return (
                self.state == ScaleUpNewEngineState.COMPLETE
                if self.worker_type == "new"
                else self.state == ScaleUpExistingEngineState.COMPLETE
            )
        return (
            self.state == ScaleDownRemovingEngineState.COMPLETE
            if self.worker_type == "removing"
            else self.state == ScaleDownRemainingEngineState.COMPLETE
        )

    def _create_standby_groups(self):
        """Create the new DP group/store and the workers' standby groups."""
        self.new_dp_group, self.new_dp_store = (
            self.new_parallel_config.stateless_init_dp_group(return_store=True)
        )
        self.model_executor.collective_rpc(
            "elastic_ep_execute", args=("create_standby_groups", self.reconfig_request)
        )
        if self.old_dp_group.rank() == 0:
            logger.info("[Elastic EP] Created standby communication groups")

    def _transfer_weights(self):
        """Ask the workers to send weights to the new workers."""
        assert self.reconfig_request is not None
        old_dp_size = self.old_dp_group.size()
        new_dp_size = self.reconfig_request.new_data_parallel_size

        self.model_executor.collective_rpc(
            "elastic_ep_execute", args=("transfer_weights", old_dp_size, new_dp_size)
        )
        if self.old_dp_group.rank() == 0:
            logger.info("[Elastic EP] Transferred weights to new workers")

    def _transfer_expert_mapping(self):
        """Ask the workers to broadcast the expert mapping to the new workers."""
        self.model_executor.collective_rpc(
            "elastic_ep_execute", args=("broadcast_expert_mapping",)
        )
        if self.old_dp_group.rank() == 0:
            logger.info("[Elastic EP] Broadcasted expert mapping to new workers")

    def _sync_kv_cache_memory_size(self):
        """Sync this engine's available KV cache memory over the new DP group."""
        assert self.engine_core.available_gpu_memory_for_kv_cache > 0
        assert self.new_dp_group is not None
        ParallelConfig.sync_kv_cache_memory_size(
            self.new_dp_group,
            self.engine_core.available_gpu_memory_for_kv_cache,
        )
        if self.old_dp_group.rank() == 0:
            logger.info("[Elastic EP] Synced KV cache memory size to new workers")

    def _switch_and_prepare(self):
        """Switch workers and the engine core from the old to the new DP group.

        After the workers switched, the old DP process group is destroyed,
        the engine core is pointed at the new group/store, and the engine
        step state is reconciled across the new group with a MAX all-reduce.
        Rank 0 of the new group sends ``RECONFIGURE_FINISHED``.
        """
        self.model_executor.collective_rpc(
            "elastic_ep_execute", args=("switch_and_prepare",)
        )
        old_dp_group = self.old_dp_group
        stateless_destroy_torch_distributed_process_group(old_dp_group)

        assert self.new_dp_group is not None
        new_dp_group = self.new_dp_group
        self.engine_core.dp_group = new_dp_group
        self.engine_core.dp_rank = new_dp_group.rank()
        self.engine_core.dp_store = self.new_dp_store

        self._all_reduce_engine_state(
            new_dp_group,
            (
                int(self.engine_core.engines_running),
                self.engine_core.current_wave,
                self.engine_core.step_counter,
            ),
        )

        if new_dp_group.rank() == 0:
            self.engine_core._eep_send_engine_core_notification(
                EEPNotificationType.RECONFIGURE_FINISHED
            )
            logger.info("[Elastic EP] Switched to new setup")

    def _eplb_reshuffle(self):
        """Ask the workers to run an EPLB reshuffle over the new setup."""
        self.model_executor.collective_rpc(
            "elastic_ep_execute", args=("perform_eplb_reshuffle",)
        )
        assert self.new_dp_group is not None
        if self.new_dp_group.rank() == 0:
            logger.info("[Elastic EP] EPLB reshuffle completed")

    def _eplb_reshuffle_before_scale_down(self):
        """Ask the workers to reshuffle experts for the post-scale-down DP size."""
        assert self.reconfig_request is not None
        self.model_executor.collective_rpc(
            "elastic_ep_execute",
            args=(
                "perform_eplb_reshuffle",
                self.reconfig_request.new_data_parallel_size,
            ),
        )
        if self.old_dp_group.rank() == 0:
            logger.info("[Elastic EP] EPLB reshuffle completed")

    def _switch_and_remove(self):
        """Ask the workers to run their ``switch_and_remove`` step."""
        self.model_executor.collective_rpc(
            "elastic_ep_execute", args=("switch_and_remove",)
        )

    def _update_parallel_config(self):
        """Copy the new DP settings from the reconfig request into the config.

        Rank fields are only overwritten when the request does not carry
        ``ReconfigureRankType.KEEP_CURRENT_RANK`` for them.
        """
        assert self.reconfig_request is not None
        reconfig_request = self.reconfig_request
        parallel_config = self.vllm_config.parallel_config

        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        if (
            reconfig_request.new_data_parallel_rank_local
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank_local = (
                reconfig_request.new_data_parallel_rank_local
            )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
        parallel_config._data_parallel_master_port_list = (
            reconfig_request.new_data_parallel_master_port_list
        )
        parallel_config._stateless_world_group_port_list = (
            reconfig_request.new_stateless_world_group_port_list
        )
        parallel_config._stateless_dp_group_port_list = (
            reconfig_request.new_stateless_dp_group_port_list
        )
        parallel_config._stateless_ep_group_port_list = (
            reconfig_request.new_stateless_ep_group_port_list
        )
        parallel_config._stateless_eplb_group_port_list = (
            reconfig_request.new_stateless_eplb_group_port_list
        )
@@ -0,0 +1,117 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.distributed.parallel_state import (
_init_stateless_group,
_node_count,
get_pp_group,
get_tp_group,
get_world_group,
)
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
# Standby communication groups. They are populated by
# create_standby_groups() and handed over and reset to None by
# pop_standby_groups().
_STANDBY_WORLD: StatelessGroupCoordinator | None = None
# Node count computed from the standby world group's TCP-store group.
_STANDBY_WORLD_NODE_COUNT: int | None = None
_STANDBY_DP: StatelessGroupCoordinator | None = None
_STANDBY_EP: StatelessGroupCoordinator | None = None
# Only created when EPLB group ports are passed to create_standby_groups().
_STANDBY_EPLB: StatelessGroupCoordinator | None = None
def get_standby_dp_group() -> StatelessGroupCoordinator | None:
    """Return the standby DP group, or ``None`` if none is currently held."""
    return _STANDBY_DP
def get_standby_ep_group() -> StatelessGroupCoordinator | None:
    """Return the standby EP group, or ``None`` if none is currently held."""
    return _STANDBY_EP
def get_standby_eplb_group() -> StatelessGroupCoordinator | None:
    """Return the standby EPLB group, or ``None`` if none is currently held."""
    return _STANDBY_EPLB
def get_standby_world_group() -> StatelessGroupCoordinator | None:
    """Return the standby world group, or ``None`` if none is currently held."""
    return _STANDBY_WORLD
def create_standby_groups(
    new_dp_size: int,
    new_world_size_across_dp: int,
    master_ip: str,
    world_group_ports: list[list[int]],
    dp_group_ports: list[list[int]],
    ep_group_ports: list[list[int]],
    eplb_group_ports: list[list[int]] | None = None,
    backend: str | None = None,
) -> None:
    """Create the standby world/DP/EP (and optionally EPLB) groups.

    Args:
        new_dp_size: Data-parallel size of the standby layout.
        new_world_size_across_dp: Total rank count of the standby layout;
            must equal ``torch.distributed.get_world_size() * new_dp_size``.
        master_ip: Address passed to every ``_init_stateless_group`` call.
        world_group_ports: Ports for the standby world group.
        dp_group_ports: Ports for the standby DP groups.
        ep_group_ports: Ports for the standby EP groups.
        eplb_group_ports: Ports for the standby EPLB groups; when ``None``
            no EPLB group is created.
        backend: Backend to use; falls back to the current world group's
            backend when not given.
    """
    global \
        _STANDBY_WORLD, \
        _STANDBY_WORLD_NODE_COUNT, \
        _STANDBY_DP, \
        _STANDBY_EP, \
        _STANDBY_EPLB

    assert new_world_size_across_dp == torch.distributed.get_world_size() * new_dp_size
    current_world = get_world_group()
    assert isinstance(current_world, StatelessGroupCoordinator)
    if not backend:
        backend = current_world.backend

    # World group: a single group containing every rank.
    _STANDBY_WORLD = _init_stateless_group(
        [list(range(new_world_size_across_dp))],
        "world",
        world_group_ports,
        master_ip,
        backend,
        use_device_communicator=False,
    )
    _STANDBY_WORLD_NODE_COUNT = _node_count(_STANDBY_WORLD.tcp_store_group)

    tp_size = get_tp_group().world_size
    pp_size = get_pp_group().world_size
    # Rank grid laid out as (-1, dp, pp, tp).
    rank_grid = torch.arange(new_world_size_across_dp).reshape(
        -1, new_dp_size, pp_size, tp_size
    )

    # DP groups: move the dp axis last, one row per group.
    dp_rank_lists = rank_grid.transpose(1, 3).reshape(-1, new_dp_size).tolist()
    _STANDBY_DP = _init_stateless_group(
        dp_rank_lists, "dp", dp_group_ports, master_ip, backend
    )

    # EP groups: each row spans the dp and tp axes together.
    ep_rank_lists = (
        rank_grid.transpose(1, 2).reshape(-1, new_dp_size * tp_size).tolist()
    )
    _STANDBY_EP = _init_stateless_group(
        ep_rank_lists, "ep", ep_group_ports, master_ip, backend
    )

    # The EPLB groups reuse the EP rank layout on their own ports.
    if eplb_group_ports is not None:
        _STANDBY_EPLB = _init_stateless_group(
            ep_rank_lists, "eplb", eplb_group_ports, master_ip, backend
        )
def pop_standby_groups() -> dict:
    """Hand over every standby group and reset the module's standby state.

    Returns:
        A dict with keys ``world``, ``dp``, ``ep``, ``eplb`` and
        ``node_count`` holding the values stored before the reset.
    """
    global \
        _STANDBY_WORLD, \
        _STANDBY_WORLD_NODE_COUNT, \
        _STANDBY_DP, \
        _STANDBY_EP, \
        _STANDBY_EPLB

    handed_over = {
        "world": _STANDBY_WORLD,
        "dp": _STANDBY_DP,
        "ep": _STANDBY_EP,
        "eplb": _STANDBY_EPLB,
        "node_count": _STANDBY_WORLD_NODE_COUNT,
    }
    _STANDBY_WORLD = _STANDBY_WORLD_NODE_COUNT = None
    _STANDBY_DP = _STANDBY_EP = _STANDBY_EPLB = None
    return handed_over
-4
View File
@@ -24,7 +24,6 @@ logger = init_logger(__name__)
def start_async_worker(
state: "EplbState",
rank_mapping: dict[int, int] | None = None,
is_profile: bool = False,
) -> threading.Thread:
eplb_group = get_eplb_group().device_group
@@ -45,7 +44,6 @@ def start_async_worker(
eplb_group=eplb_group,
cuda_stream=cuda_stream,
is_profile=is_profile,
rank_mapping=rank_mapping,
)
)
except Exception as exc: # pragma: no cover - diagnostic path
@@ -107,7 +105,6 @@ async def transfer_run_periodically(
eplb_group: ProcessGroup,
cuda_stream: torch.cuda.Stream,
is_profile: bool = False,
rank_mapping: dict[int, int] | None = None,
) -> None:
while True:
await asyncio.to_thread(state.rearrange_event.wait)
@@ -176,7 +173,6 @@ async def transfer_run_periodically(
ep_group=eplb_group,
is_profile=is_profile,
cuda_stream=cuda_stream,
rank_mapping=rank_mapping,
)
event = torch.cuda.Event(blocking=False)
cuda_stream.record_event(event)
+98 -214
View File
@@ -40,6 +40,7 @@ from vllm.distributed.parallel_state import (
get_node_count,
in_the_same_node_as,
)
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MixtureOfExperts
@@ -302,6 +303,14 @@ class EplbState:
"""
CUDA device index for the async EPLB worker thread.
"""
self.num_valid_physical_experts: int = 0
"""
Number of valid physical experts.
This is the number of physical experts that are
actually mapped to logical experts. In elastic EP,
newly started EP ranks may not have physical experts
mapped yet.
"""
if self.device.type == "cuda":
self.cuda_device_index = self.device.index
if self.cuda_device_index is None and torch.cuda.is_available():
@@ -367,9 +376,6 @@ class EplbState:
self,
model: MixtureOfExperts,
model_config: ModelConfig,
global_expert_load: torch.Tensor | None = None,
old_global_expert_indices: torch.Tensor | None = None,
rank_mapping: dict[int, int] | None = None,
):
"""
Build the initial EPLB state.
@@ -462,75 +468,15 @@ class EplbState:
)
self.expert_rearrangement_step_interval = eplb_step_interval
# Set the policy based on the selected eplb algorithm type.
policy_type = self.parallel_config.eplb_config.policy
self.policy = EPLB_POLICIES[policy_type]
logger.debug("Selected EPLB policy: %s", policy_type)
if global_expert_load is not None:
ep_group = get_ep_group().device_group
assert global_expert_load.shape == (
model.num_moe_layers,
model.num_logical_experts,
)
assert global_expert_load.dtype == torch.int64
num_replicas = model.num_physical_experts
num_groups = model.num_expert_groups
num_nodes = get_node_count()
num_gpus = ep_group.size()
if num_gpus % num_nodes != 0:
num_nodes = 1
logger.warning_once(
f"num_gpus % num_nodes != 0, "
"not using hierarchical rearrangement algorithm.\n"
f"{num_gpus=}, {num_nodes=}"
)
# Get new expert mappings
(
new_physical_to_logical_map,
new_logical_to_physical_map,
new_logical_replica_count,
) = self.policy.rebalance_experts(
global_expert_load,
num_replicas,
num_groups,
num_nodes,
num_gpus,
)
max_physical_slots = new_logical_to_physical_map.shape[-1]
assert max_physical_slots <= logical_to_physical_map.shape[-1]
new_logical_to_physical_map = torch.nn.functional.pad(
new_logical_to_physical_map,
(0, logical_to_physical_map.shape[-1] - max_physical_slots),
value=-1,
)
physical_to_logical_map = new_physical_to_logical_map.to(self.device)
logical_to_physical_map.copy_(new_logical_to_physical_map)
logical_replica_count.copy_(new_logical_replica_count)
else:
new_physical_to_logical_map = None
new_logical_to_physical_map = None
new_logical_replica_count = None
model.set_eplb_state(
expert_load_pass,
logical_to_physical_map,
logical_replica_count,
)
if global_expert_load is not None:
rearrange_expert_weights_inplace(
old_global_expert_indices,
new_physical_to_logical_map,
model.expert_weights,
ep_group,
False,
rank_mapping,
)
self.expert_rearrangement_step = 0
expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]]
@@ -561,11 +507,12 @@ class EplbState:
recv_dst_rows=np.array([]),
),
cuda_device_index=self.cuda_device_index,
new_physical_to_logical_map=new_physical_to_logical_map,
new_logical_to_physical_map=new_logical_to_physical_map,
new_logical_replica_count=new_logical_replica_count,
new_physical_to_logical_map=None,
new_logical_to_physical_map=None,
new_logical_replica_count=None,
)
self.model_states[model_config.compute_hash()] = model_state
self.num_valid_physical_experts = model.num_physical_experts
def step(
self,
@@ -696,8 +643,6 @@ class EplbState:
def rearrange(
self,
is_profile: bool = False,
execute_shuffle: bool = True,
global_expert_loads: list[torch.Tensor] | None = None,
rank_mapping: dict[int, int] | None = None,
) -> torch.Tensor | None:
"""
@@ -707,12 +652,6 @@ class EplbState:
is_profile (bool): If `True`, perform a dummy rearrangement.
This is used in `profile_run` to reserve enough memory,
no memory movement will be performed. Default is False.
execute_shuffle (bool): If `True`, execute the shuffle
in elastic expert parallel (EEP). Default is True.
global_expert_loads (list[torch.Tensor] | None): The global expert
loads when scaling is done in EEP.
List of expert loads for the main and drafter
(when spec decode is used) models.
rank_mapping (dict[int, int] | None): The rank mapping
when scaling is done in EEP.
"""
@@ -734,67 +673,34 @@ class EplbState:
"(profile)" if is_profile else "",
)
if global_expert_loads is None:
# Map the physical expert load to global logical experts
global_expert_load_windows = []
if not execute_shuffle:
num_models = torch.tensor(
[len(self.model_states)], dtype=torch.int32, device="cpu"
)
torch.distributed.broadcast(
num_models, group=get_ep_group().cpu_group, group_src=0
)
for eplb_model_state in self.model_states.values():
logical_expert_load_window = torch.zeros(
self.expert_load_window_size,
eplb_model_state.model.num_moe_layers,
eplb_model_state.model.num_logical_experts,
dtype=eplb_model_state.expert_load_window.dtype,
device=eplb_model_state.expert_load_window.device,
)
logical_expert_load_window.scatter_add_(
dim=-1,
index=eplb_model_state.physical_to_logical_map.unsqueeze(0)
.expand_as(eplb_model_state.expert_load_window)
.long(),
src=eplb_model_state.expert_load_window,
)
if not execute_shuffle:
metadata = torch.tensor(
[
eplb_model_state.model.num_moe_layers,
eplb_model_state.model.num_logical_experts,
eplb_model_state.physical_to_logical_map.shape[1],
],
dtype=torch.int32,
device="cpu",
)
torch.distributed.broadcast(
metadata, group=get_ep_group().cpu_group, group_src=0
)
global_expert_load_window = logical_expert_load_window.sum(dim=0)
global_expert_load_windows.append(global_expert_load_window)
# Perform all-reduce to get the expert load across all ranks for each model
global_expert_load_windows = self._allreduce_list(
global_expert_load_windows
# Map the physical expert load to global logical experts
global_expert_load_windows = []
for eplb_model_state in self.model_states.values():
expert_load_window = eplb_model_state.expert_load_window[
:, :, : self.num_valid_physical_experts
]
logical_expert_load_window = torch.zeros(
self.expert_load_window_size,
eplb_model_state.model.num_moe_layers,
eplb_model_state.model.num_logical_experts,
dtype=eplb_model_state.expert_load_window.dtype,
device=eplb_model_state.expert_load_window.device,
)
if not execute_shuffle:
for eplb_model_state, global_expert_load_window in zip(
self.model_states.values(), global_expert_load_windows
):
# (num_moe_layers, old_num_physical_experts)
old_global_expert_indices = eplb_model_state.physical_to_logical_map
torch.distributed.broadcast(
old_global_expert_indices, group=ep_group, group_src=0
)
if not execute_shuffle:
return global_expert_load_windows
else:
assert execute_shuffle
global_expert_load_windows = global_expert_loads
logical_expert_load_window.scatter_add_(
dim=-1,
index=eplb_model_state.physical_to_logical_map[
:, : self.num_valid_physical_experts
]
.unsqueeze(0)
.expand_as(expert_load_window)
.long(),
src=expert_load_window,
)
global_expert_load_window = logical_expert_load_window.sum(dim=0)
global_expert_load_windows.append(global_expert_load_window)
# Perform all-reduce to get the expert load across all ranks for each model
global_expert_load_windows = self._allreduce_list(global_expert_load_windows)
# TODO(bowen): Treat differently for prefill and decode nodes
eplb_model_state = next(iter(self.model_states.values()))
@@ -806,8 +712,10 @@ class EplbState:
# NOTE(yongji): scale down, we need to rebalance the experts on
# remaining GPUs, transfer the experts while we haven't shutdown
# the GPUs to be released.
cpu_group = get_ep_group().cpu_group
num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping)
coordinator = get_ep_group()
assert isinstance(coordinator, StatelessGroupCoordinator)
tcp_store_group = coordinator.tcp_store_group
num_nodes = _node_count_with_rank_mapping(tcp_store_group, rank_mapping)
num_gpus = sum(new_rank != -1 for new_rank in rank_mapping.values())
num_replicas = (
num_replicas // ep_group.size() * num_gpus
@@ -933,7 +841,6 @@ class EplbState:
if self.async_worker is None:
self.async_worker = start_async_worker(
self,
rank_mapping=rank_mapping,
is_profile=is_profile,
)
@@ -1089,83 +996,6 @@ class EplbState:
model_state.new_logical_to_physical_map = None
model_state.new_logical_replica_count = None
@staticmethod
def recv_state() -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    """
    Receive the expert load and old placement from the master rank.

    Every collective below uses group source rank 0 of the EP group and must
    be issued in exactly this order.

    Returns:
        A pair ``(global_expert_loads, old_global_expert_indices_per_model)``
        holding one tensor per model in each list.
    """
    ep_group = get_ep_group()
    # The number of models arrives first, as a one-element int32 CPU broadcast.
    num_models = torch.empty(1, dtype=torch.int32, device="cpu")
    torch.distributed.broadcast(num_models, group=ep_group.cpu_group, group_src=0)
    num_models = num_models.item()
    global_expert_loads = []
    old_global_expert_indices_per_model = []
    for _ in range(num_models):
        # Per-model metadata: three int32 values that size the tensors below.
        metadata = torch.empty(3, dtype=torch.int32, device="cpu")
        torch.distributed.broadcast(metadata, group=ep_group.cpu_group, group_src=0)
        num_moe_layers, num_logical_experts, num_old_physical_experts = (
            metadata.tolist()
        )
        # This rank feeds zeros into the all-reduce, so (assuming the reduction
        # is a sum) the result reflects only the other ranks' contributions.
        global_expert_load = torch.zeros(
            (num_moe_layers, num_logical_experts),
            dtype=torch.int64,
            device=ep_group.device,
        )
        all_reduce(global_expert_load, group=ep_group.device_group)
        # Old placement, broadcast on the device group from group rank 0.
        old_global_expert_indices = torch.empty(
            (num_moe_layers, num_old_physical_experts),
            dtype=torch.int64,
            device=ep_group.device,
        )
        torch.distributed.broadcast(
            old_global_expert_indices,
            group=ep_group.device_group,
            group_src=0,
        )
        global_expert_loads.append(global_expert_load)
        old_global_expert_indices_per_model.append(old_global_expert_indices)
    return global_expert_loads, old_global_expert_indices_per_model
@classmethod
def get_eep_state(
    cls, parallel_config: ParallelConfig
) -> tuple[
    list[torch.Tensor] | None,
    list[torch.Tensor] | None,
    dict[int, int] | None,
]:
    """
    Receive the EPLB state over the EP group and derive the rank mapping.

    Side effect: overwrites
    ``parallel_config.eplb_config.num_redundant_experts`` with a value
    computed from the received sizes.

    Returns:
        ``(global_expert_loads, old_global_expert_indices_per_model,
        rank_mapping)`` where ``rank_mapping`` maps each old EP rank to itself.
    """
    # Number of physical experts per rank, broadcast on the CPU group from
    # group rank 0.
    num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu")
    torch.distributed.broadcast(
        num_local_physical_experts,
        group=get_ep_group().cpu_group,
        group_src=0,
    )
    num_local_physical_experts = int(num_local_physical_experts.item())
    new_ep_size = get_ep_group().world_size
    global_expert_loads, old_global_expert_indices_per_model = (
        EplbState.recv_state()
    )
    # EP configuration for all models has to be the same so as eplb config
    # (hence only the first model's tensors are inspected below).
    num_logical_experts = global_expert_loads[0].shape[1]
    parallel_config.eplb_config.num_redundant_experts = (
        num_local_physical_experts * new_ep_size - num_logical_experts
    )
    # The old physical expert count must divide evenly across the old ranks.
    assert (
        old_global_expert_indices_per_model[0].shape[1] % num_local_physical_experts
        == 0
    )
    old_ep_size = (
        old_global_expert_indices_per_model[0].shape[1]
        // num_local_physical_experts
    )
    # Identity mapping over the old EP ranks.
    rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
    return (
        global_expert_loads,
        old_global_expert_indices_per_model,
        rank_mapping,
    )
def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]:
"""
All-reduce a list of tensors.
@@ -1203,6 +1033,60 @@ class EplbState:
load_pass_list.append(eplb_model_state.expert_load_pass.clone())
return self._allreduce_list(load_pass_list)
@classmethod
def from_mapping(
    cls,
    model: MixtureOfExperts,
    model_config: ModelConfig,
    device: torch.device,
    parallel_config: ParallelConfig,
    expanded_physical_to_logical: torch.Tensor,
    num_valid_physical_experts: int,
) -> "EplbState":
    """Build an ``EplbState`` for one model from a physical->logical map.

    The given map is copied into the model state, and the inverse
    logical->physical map plus the per-logical-expert replica counts are
    rebuilt from it. Negative entries in the map are skipped.

    Args:
        model: Model registered with the new state.
        model_config: Config of ``model``; its hash selects the model state.
        device: Device the rebuilt maps are moved to before being copied in.
        parallel_config: Passed through to the ``EplbState`` constructor.
        expanded_physical_to_logical: Indexed as ``[layer, physical_expert]``.
        num_valid_physical_experts: Stored on the new state unchanged.

    Returns:
        The newly constructed state.
    """
    state = cls(parallel_config=parallel_config, device=device)
    state.add_model(model=model, model_config=model_config)
    state.num_valid_physical_experts = num_valid_physical_experts

    layer_count = expanded_physical_to_logical.shape[0]
    model_state = state.model_states[model_config.compute_hash()]
    model_state.physical_to_logical_map.copy_(expanded_physical_to_logical)

    # Scratch tensors are built on the default device and moved afterwards.
    max_replicas = model_state.logical_to_physical_map.shape[2]
    inverse_map = torch.full(
        (layer_count, model.num_logical_experts, max_replicas),
        -1,
        dtype=torch.int64,
    )
    replica_counts = torch.zeros(
        (layer_count, model.num_logical_experts),
        dtype=torch.int64,
    )

    mapping_np = expanded_physical_to_logical.cpu().numpy()
    for layer_idx, layer_row in enumerate(mapping_np):
        for phys_idx, logical_idx in enumerate(layer_row):
            if logical_idx < 0:
                # Negative entries do not map to any logical expert.
                continue
            # The next free replica slot equals the replicas seen so far.
            slot = replica_counts[layer_idx, logical_idx]
            inverse_map[layer_idx, logical_idx, slot] = phys_idx
            replica_counts[layer_idx, logical_idx] += 1

    model_state.logical_to_physical_map.copy_(inverse_map.to(device))
    model_state.logical_replica_count.copy_(replica_counts.to(device))
    return state
@dataclass
class EplbLayerState:
+56 -24
View File
@@ -19,6 +19,8 @@ from torch.distributed import (
get_global_rank,
)
from vllm.distributed.parallel_state import get_ep_group
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.logger import init_logger
logger = init_logger(__name__)
@@ -249,10 +251,18 @@ def move_to_buffer(
b[dst].copy_(w[src_local], non_blocking=True)
p2p_ops: list[P2POp] = []
if isinstance(get_ep_group(), StatelessGroupCoordinator):
ep_group = get_ep_group()
is_stateless = True
else:
is_stateless = False
# Pre-compute global ranks mapping
# Pre-compute global ranks mapping (only needed for non-stateless groups)
ep_size = ep_group.size()
rank_to_global = {rank: get_global_rank(ep_group, rank) for rank in range(ep_size)}
if not is_stateless:
rank_to_global = {
rank: get_global_rank(ep_group, rank) for rank in range(ep_size)
}
# 2. Post sends
if send_count > 0:
@@ -284,15 +294,23 @@ def move_to_buffer(
if recver_pos < len(ranks_to_recv):
recv_ranks.append(ranks_to_recv[recver_pos])
for dst in recv_ranks:
dst_global = rank_to_global[dst]
p2p_ops += [
P2POp(
torch.distributed.isend,
w[src],
dst_global,
)
for w in expert_weights
]
if is_stateless:
for w in expert_weights:
op = object.__new__(P2POp)
op.op = torch.distributed.isend
op.tensor = w[src]
op.group_peer = dst
p2p_ops.append(op)
else:
dst_global = rank_to_global[dst]
p2p_ops += [
P2POp(
torch.distributed.isend,
w[src],
dst_global,
)
for w in expert_weights
]
# 3. Post recvs
if recv_count > 0:
@@ -321,26 +339,40 @@ def move_to_buffer(
src = ranks_to_send[recver_pos // num_dst_per_sender]
else:
src = ranks_to_send[recver_pos - remainder_start]
src_global = rank_to_global[src]
p2p_ops += [
P2POp(
torch.distributed.irecv,
b[dst],
src_global,
)
for b in expert_weights_buffers
]
if is_stateless:
for b in expert_weights_buffers:
op = object.__new__(P2POp)
op.op = torch.distributed.irecv
op.tensor = b[dst]
op.group_peer = src
p2p_ops.append(op)
else:
src_global = rank_to_global[src]
p2p_ops += [
P2POp(
torch.distributed.irecv,
b[dst],
src_global,
)
for b in expert_weights_buffers
]
# 4. Execute the P2P operations. The real communication happens here.
if p2p_ops and cuda_stream is not None:
with torch.cuda.stream(cuda_stream):
if is_stateless:
ep_group.device_communicator.batch_isend_irecv(p2p_ops)
else:
reqs = batch_isend_irecv(p2p_ops)
for req in reqs:
req.wait()
elif p2p_ops:
if is_stateless:
ep_group.device_communicator.batch_isend_irecv(p2p_ops)
else:
reqs = batch_isend_irecv(p2p_ops)
for req in reqs:
req.wait()
elif p2p_ops:
reqs = batch_isend_irecv(p2p_ops)
for req in reqs:
req.wait()
# wait for the communication to finish
return (
is_unchanged,
+207 -22
View File
@@ -33,7 +33,7 @@ from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from datetime import timedelta
from multiprocessing import shared_memory
from typing import Any, Protocol
from typing import TYPE_CHECKING, Any, Protocol
from unittest.mock import patch
import torch
@@ -55,6 +55,9 @@ from vllm.utils.torch_utils import (
direct_register_custom_op,
)
if TYPE_CHECKING:
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
@dataclass
class GraphCaptureContext:
@@ -1157,6 +1160,55 @@ def init_model_parallel_group(
)
def _init_stateless_group(
    group_ranks: list[list[int]],
    group_name: str,
    group_ports: list[list[int]],
    host: str,
    backend: str,
    use_device_communicator: bool = True,
) -> "StatelessGroupCoordinator":
    """Build a ``StatelessGroupCoordinator`` for the given rank partition.

    The local rank, global rank and global world size are taken from the
    current world group; everything else is forwarded from the arguments.
    """
    # Imported lazily, matching the TYPE_CHECKING-only import at module level.
    from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator

    world_group = get_world_group()
    coordinator_kwargs = dict(
        group_ranks=group_ranks,
        local_rank=world_group.local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=use_device_communicator,
        group_name=group_name,
        host=host,
        group_ports=group_ports,
        global_rank=world_group.rank,
        global_world_size=world_group.world_size,
    )
    return StatelessGroupCoordinator(**coordinator_kwargs)
def _replace_active_groups(
    *,
    world: GroupCoordinator | None,
    dp: GroupCoordinator | None,
    ep: GroupCoordinator | None,
    eplb: GroupCoordinator | None,
    node_count: int | None,
) -> None:
    """Destroy the current DP/EP/WORLD/EPLB groups and install replacements.

    Destruction is collective: all ranks in the old groups must call this
    function together. Pass ``None`` for every argument to tear the groups
    down without installing replacements.
    """
    global _WORLD, _DP, _EP, _EPLB, _NODE_COUNT
    # Fixed teardown order: DP, EP, WORLD, EPLB. Unset groups are skipped.
    stale_groups = [g for g in (_DP, _EP, _WORLD, _EPLB) if g is not None]
    for stale in stale_groups:
        stale.destroy()
    _WORLD, _DP, _EP, _EPLB, _NODE_COUNT = world, dp, ep, eplb, node_count
_TP: GroupCoordinator | None = None
@@ -1254,6 +1306,39 @@ def set_custom_all_reduce(enable: bool):
_ENABLE_CUSTOM_ALL_REDUCE = enable
def _init_elastic_ep_world(
    config, local_rank: int, backend: str, rank: int, world_size: int
) -> None:
    """Create the stateless world group used when elastic EP is enabled.

    Sets the module globals ``_WORLD`` and ``_NODE_COUNT``.

    Args:
        config: Object exposing ``parallel_config``.
        local_rank: Forwarded to the coordinator as its local rank.
        backend: Torch distributed backend name for the coordinator.
        rank: Rank within one data-parallel replica; combined with the DP
            rank to form the global rank.
        world_size: Number of ranks per data-parallel replica.
    """
    from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator

    global _WORLD, _NODE_COUNT
    assert _WORLD is None, "world group already initialized"
    parallel_config = config.parallel_config
    # Global rank is laid out DP-major: all ranks of DP replica 0 come first.
    global_rank = parallel_config.data_parallel_rank * world_size + rank
    global_world_size = parallel_config.world_size_across_dp
    all_ranks = list(range(global_world_size))
    # Default partition is one singleton group per rank; it is replaced by a
    # single all-ranks group when this rank is within the global range.
    group_ranks = [all_ranks[i : i + 1] for i in range(global_world_size)]
    if global_rank in all_ranks:
        group_ranks = [all_ranks]
    group_ports = [parallel_config.get_next_stateless_world_group_port()]
    # No device communicator is created for the world group.
    world = StatelessGroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=False,
        group_name="world",
        host=parallel_config.data_parallel_master_ip,
        group_ports=group_ports,
        global_rank=global_rank,
        global_world_size=global_world_size,
    )
    assert parallel_config.nnodes_within_dp == 1, (
        "Elastic EP is not supported with multi-node TP/PP"
    )
    _NODE_COUNT = _node_count(world.tcp_store_group)
    _WORLD = world
def init_distributed_environment(
world_size: int = -1,
rank: int = -1,
@@ -1273,6 +1358,7 @@ def init_distributed_environment(
from vllm.config import get_current_vllm_config_or_none
config = get_current_vllm_config_or_none()
enable_elastic_ep = config is not None and config.parallel_config.enable_elastic_ep
if (
config is not None
and config.parallel_config.distributed_executor_backend != "external_launcher"
@@ -1280,6 +1366,7 @@ def init_distributed_environment(
config.parallel_config.nnodes > 1
or config.parallel_config.data_parallel_size > 1
)
and not enable_elastic_ep
):
parallel_config = config.parallel_config
# adjust to take into account data parallelism
@@ -1333,6 +1420,18 @@ def init_distributed_environment(
rank=rank,
timeout=timeout,
)
if enable_elastic_ep:
tp_pp_cpu_group = torch.distributed.new_group(
backend="gloo", timeout=timeout
)
if _node_count(tp_pp_cpu_group) > 1:
# NOTE(yongji): StatelessGroupCoordinator uses data_parallel_master_ip
# to initialize all DP/EP groups, hence all ranks within TP/PP group
# must reside on the same node
raise RuntimeError(
"Elastic EP is not yet supported with multi-node TP/PP"
)
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
@@ -1341,6 +1440,9 @@ def init_distributed_environment(
# setting, where we can use rank as local rank
local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
global _WORLD, _NODE_COUNT, _INNER_DP_WORLD
if enable_elastic_ep:
_init_elastic_ep_world(config, local_rank, backend, rank, world_size)
return
if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size()))
_WORLD = init_world_group(ranks, local_rank, backend)
@@ -1404,16 +1506,33 @@ def initialize_model_parallel(
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
data_parallel_size = 1
from vllm.config import get_current_vllm_config_or_none
from vllm.config import get_current_vllm_config
config = get_current_vllm_config_or_none()
if config is not None:
data_parallel_size = config.parallel_config.data_parallel_size
config = get_current_vllm_config()
data_parallel_size = config.parallel_config.data_parallel_size
enable_elastic_ep = config.parallel_config.enable_elastic_ep
if enable_elastic_ep:
# Use stateless world group for global information
world_size = get_world_group().world_size
rank = get_world_group().rank
backend = backend or "nccl"
tp_pp_pcp_size = (
tensor_model_parallel_size
* pipeline_model_parallel_size
* prefill_context_model_parallel_size
)
local_all_ranks = torch.arange(tp_pp_pcp_size).reshape(
pipeline_model_parallel_size,
prefill_context_model_parallel_size,
tensor_model_parallel_size,
)
else:
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
backend = backend or torch.distributed.get_backend(
get_world_group().device_group
)
# the layout order is: ExternalDP x DP x PP x TP
# ExternalDP is the data parallel group that is not part of the model,
@@ -1437,7 +1556,9 @@ def initialize_model_parallel(
assert _TP is None, "tensor model parallel group is already initialized"
group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
group_ranks = local_all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
# message queue broadcaster is only used in tensor model parallel group
_TP = init_model_parallel_group(
group_ranks,
@@ -1456,6 +1577,11 @@ def initialize_model_parallel(
# TP group into tp_size//dcp_size DCP groups.
group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
group_ranks = local_all_ranks.reshape(
-1, decode_context_model_parallel_size
).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_DCP = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
@@ -1472,6 +1598,13 @@ def initialize_model_parallel(
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
group_ranks = (
local_all_ranks.transpose(1, 2)
.reshape(-1, prefill_context_model_parallel_size)
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
_PCP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="pcp"
)
@@ -1483,6 +1616,13 @@ def initialize_model_parallel(
all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
group_ranks = (
local_all_ranks.transpose(0, 2)
.reshape(-1, pipeline_model_parallel_size)
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
_PP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="pp"
)
@@ -1491,14 +1631,27 @@ def initialize_model_parallel(
assert _DP is None, "data parallel group is already initialized"
group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_DP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="dp"
)
if enable_elastic_ep:
parallel_config = config.parallel_config
dp_ports = [
parallel_config.get_next_stateless_dp_group_port() for _ in group_ranks
]
_DP = _init_stateless_group(
group_ranks,
"dp",
dp_ports,
parallel_config.data_parallel_master_ip,
backend,
)
else:
_DP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="dp"
)
global _EP
assert _EP is None, "expert parallel group is already initialized"
# Don't create EP group for dense models.
if config is None or config.model_config is None or config.model_config.is_moe:
if config.model_config is None or config.model_config.is_moe:
group_ranks = (
all_ranks.transpose(1, 2)
.reshape(
@@ -1510,9 +1663,22 @@ def initialize_model_parallel(
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
_EP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="ep"
)
if enable_elastic_ep:
parallel_config = config.parallel_config
ep_ports = [
parallel_config.get_next_stateless_ep_group_port() for _ in group_ranks
]
_EP = _init_stateless_group(
group_ranks,
"ep",
ep_ports,
parallel_config.data_parallel_master_ip,
backend,
)
else:
_EP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="ep"
)
# Create EPLB group with the same ranks as EP if EPLB is enabled.
# This is a separate process group to isolate EPLB communications
@@ -1525,10 +1691,25 @@ def initialize_model_parallel(
and config.parallel_config is not None
and config.parallel_config.enable_eplb
):
# Reuse the same group_ranks from EP
_EPLB = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="eplb"
)
if enable_elastic_ep:
eplb_ports = [
parallel_config.get_next_stateless_eplb_group_port()
for _ in group_ranks
]
_EPLB = _init_stateless_group(
group_ranks,
"eplb",
eplb_ports,
parallel_config.data_parallel_master_ip,
backend,
)
else:
_EPLB = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
backend,
group_name="eplb",
)
# If no EP group needed, _EP remains None
# If no EPLB group needed, _EPLB remains None
@@ -1558,7 +1739,11 @@ def ensure_model_parallel_initialized(
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
values if the model parallel groups are initialized.
"""
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
world_group = get_world_group()
if hasattr(world_group, "backend"):
backend = backend or world_group.backend
else:
backend = backend or torch.distributed.get_backend(world_group.device_group)
if not model_parallel_is_initialized():
initialize_model_parallel(
tensor_model_parallel_size,
+322
View File
@@ -0,0 +1,322 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import torch
from torch.distributed import Backend, ProcessGroup
from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
from vllm.distributed.parallel_state import (
GroupCoordinator,
TensorMetadata,
_get_unique_name,
_register_group,
_split_tensor_dict,
)
from vllm.distributed.utils import (
StatelessProcessGroup,
stateless_destroy_torch_distributed_process_group,
stateless_init_torch_distributed_process_group,
)
from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__)
class StatelessGroupCoordinator(GroupCoordinator):
"""
A stateless version of the GroupCoordinator class in parallel_state.
It creates CPU, device, and TCPStore-based communication groups
that are independent of PyTorch's WORLD group. Hence,
communication groups with a different set of participating GPUs
can be created without destroying the existing ones.
"""
def __init__(
    self,
    group_ranks: list[list[int]],
    local_rank: int,
    torch_distributed_backend: str | Backend,
    use_device_communicator: bool,
    use_message_queue_broadcaster: bool = False,
    group_name: str | None = None,
    host: str = "127.0.0.1",
    group_ports: list[list[int]] | None = None,
    global_rank: int = 0,
    global_world_size: int = 1,
):
    """Create the device, CPU and TCPStore groups for this rank's subgroup.

    Args:
        group_ranks: Partition of global ranks into groups; the group that
            contains ``global_rank`` is the one this coordinator joins.
        local_rank: Index used to pick the local device.
        torch_distributed_backend: Backend for the device process group.
        use_device_communicator: Whether to build a device communicator
            (only done when the group has more than one rank).
        use_message_queue_broadcaster: Not read here; ``mq_broadcaster`` is
            always set to ``None``.
        group_name: Base name for the group; defaults to ``"anonymous"``.
        host: Address used for all three rendezvous endpoints.
        group_ports: One entry per group in ``group_ranks``; each entry is
            indexed as ``[device_port, cpu_port, tcp_store_port]``.
        global_rank: This process's rank in the global numbering.
        global_world_size: Forwarded to the device communicator.
    """
    group_name = group_name or "anonymous"
    self.unique_name = _get_unique_name(group_name)
    _register_group(self)

    self.rank = global_rank
    self.local_rank = local_rank

    self_device_group = None
    self_cpu_group = None
    self_tcp_store_group = None

    from vllm.platforms import current_platform

    backend = str(torch_distributed_backend)
    self.backend = backend
    assert group_ports is not None, "group_ports is not provided"
    for idx, ranks in enumerate(group_ranks):
        if self.rank in ranks:
            self.ranks = ranks
            self.world_size = len(ranks)
            self.rank_in_group = ranks.index(self.rank)

            # Each group uses three separate ports, one per communication
            # layer created below.
            ports = group_ports[idx]
            device_port = ports[0]
            cpu_port = ports[1]
            tcp_store_port = ports[2]

            device_group = stateless_init_torch_distributed_process_group(
                host=host,
                port=device_port,
                rank=self.rank_in_group,
                world_size=self.world_size,
                backend=backend,
                group_name=f"{self.unique_name}_device",
            )
            cpu_group = stateless_init_torch_distributed_process_group(
                host=host,
                port=cpu_port,
                rank=self.rank_in_group,
                world_size=self.world_size,
                backend="gloo",
                group_name=f"{self.unique_name}_cpu",
            )
            tcp_store_group = StatelessProcessGroup.create(
                host=host,
                port=tcp_store_port,
                rank=self.rank_in_group,
                world_size=self.world_size,
            )

            self_device_group = device_group
            self_cpu_group = cpu_group
            self_tcp_store_group = tcp_store_group

    # These fail when ``global_rank`` is not a member of any group.
    assert self_cpu_group is not None
    assert self_device_group is not None
    assert self_tcp_store_group is not None

    self.cpu_group = self_cpu_group
    self.device_group = self_device_group
    self.tcp_store_group = self_tcp_store_group

    if current_platform.is_cuda_alike():
        self.device = torch.device(f"cuda:{local_rank}")
    elif current_platform.is_xpu():
        self.device = torch.device(f"xpu:{local_rank}")
    elif current_platform.is_out_of_tree():
        self.device = torch.device(f"{current_platform.device_name}:{local_rank}")
    else:
        self.device = torch.device("cpu")

    self.use_device_communicator = use_device_communicator
    self.device_communicator = None
    if use_device_communicator and self.world_size > 1:
        device_comm_cls = resolve_obj_by_qualname(
            current_platform.get_device_communicator_cls()
        )
        # Only the CUDA communicator is accepted here.
        assert device_comm_cls == CudaCommunicator
        self.device_communicator = CudaCommunicator(
            cpu_group=self.cpu_group,
            device=self.device,
            device_group=self.device_group,
            unique_name=self.unique_name,
            global_ranks=self.ranks,
            global_world_size=global_world_size,
            tcp_store_group=self.tcp_store_group,
        )

    self.mq_broadcaster = None

    self.use_custom_op_call = (
        current_platform.is_cuda_alike() or current_platform.is_tpu()
    )
    self.use_cpu_custom_send_recv = False
def destroy(self):
    """Tear down the device communicator, then the device and CPU groups.

    Components that are unset (falsy) are skipped.
    """
    communicator = self.device_communicator
    if communicator:
        communicator.destroy()
    # Device group first, CPU group second.
    for process_group in (self.device_group, self.cpu_group):
        if process_group:
            stateless_destroy_torch_distributed_process_group(process_group)
def size(self) -> int:
    """Number of ranks that belong to this group."""
    group_size = self.world_size
    return group_size
def broadcast(self, input_: torch.Tensor, src: int = 0):
    """Broadcast ``input_`` from group rank ``src`` to the whole group.

    A single-rank group returns the input unchanged. Otherwise CUDA tensors
    go through the device communicator when one exists, and everything else
    goes through the TCPStore group.
    """
    if self.world_size == 1:
        return input_
    use_device_path = self.device_communicator and input_.is_cuda
    transport = self.device_communicator if use_device_path else self.tcp_store_group
    return transport.broadcast(input_, src)
def broadcast_object(self, obj=None, src: int = 0):
    """Broadcast a Python object from group rank ``src`` via the TCPStore group.

    A single-rank group returns ``obj`` as is.
    """
    single_rank = self.world_size == 1
    return obj if single_rank else self.tcp_store_group.broadcast_obj(obj, src)
def broadcast_object_list(
self, obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None
):
assert src < self.world_size
if self.world_size == 1:
return obj_list
if self.rank_in_group == src:
for obj in obj_list:
self.tcp_store_group.broadcast_obj(obj, src)
else:
for i in range(len(obj_list)):
obj_list[i] = self.tcp_store_group.broadcast_obj(None, src)
return obj_list
def broadcast_tensor_dict(
    self,
    tensor_dict: dict[str, torch.Tensor | Any] | None = None,
    src: int = 0,
    group: ProcessGroup | None = None,
    metadata_group: ProcessGroup | None = None,
) -> dict[str, torch.Tensor | Any] | None:
    """Broadcast a dict of tensors and plain values from group rank ``src``.

    Metadata is broadcast as one object through the TCPStore group, then
    each non-empty tensor is broadcast individually. ``group`` and
    ``metadata_group`` are not used.

    Returns:
        On the source, the dict that was passed in; on the other ranks, a
        newly built dict. A single-rank group returns the input unchanged.
    """
    if self.world_size == 1:
        return tensor_dict

    if self.rank_in_group == src:
        assert isinstance(tensor_dict, dict), (
            f"Expecting a dictionary, got {type(tensor_dict)}"
        )
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
    else:
        metadata_list = None
        tensor_list = []

    recv_metadata_list: list[tuple[str, Any]] = self.tcp_store_group.broadcast_obj(
        metadata_list, src
    )

    if self.rank_in_group != src:
        # Rebuild the dict from metadata, allocating a receive buffer for
        # every tensor entry.
        tensor_dict = {}
        for key, value in recv_metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(
                    value.size, dtype=value.dtype, device=value.device
                )
                tensor_list.append(tensor)
                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value

    # All ranks walk the tensors in the same order; empty tensors are
    # skipped everywhere, so no collective is issued for them.
    for tensor in tensor_list:
        if tensor.numel() == 0:
            continue
        if self.device_communicator and tensor.is_cuda:
            tensor.copy_(self.device_communicator.broadcast(tensor, src))
        else:
            tensor.copy_(self.tcp_store_group.broadcast(tensor, src))

    return tensor_dict
def send_object(self, obj, dst: int) -> None:
    """Send a Python object to group rank ``dst`` through the TCPStore group.

    ``dst`` must be a valid in-group rank other than this rank.
    """
    assert dst < self.world_size
    assert self.rank_in_group != dst
    store = self.tcp_store_group
    store.send_obj(obj, dst)
def recv_object(self, src: int):
    """Receive a Python object from group rank ``src`` through the TCPStore group.

    ``src`` must be a valid in-group rank other than this rank.
    """
    assert src < self.world_size
    assert self.rank_in_group != src
    store = self.tcp_store_group
    return store.recv_obj(src)
def send_tensor_dict(
    self,
    tensor_dict: dict[str, torch.Tensor | Any],
    dst: int | None = None,
    all_gather_group: Optional["GroupCoordinator"] = None,
    all_gather_tensors: dict[str, bool] | None = None,
) -> dict[str, torch.Tensor | Any] | None:
    """Send a dict of tensors and plain values to group rank ``dst``.

    Metadata goes first as one object through the TCPStore group, followed
    by each non-empty tensor. ``dst`` defaults to the next rank in the group
    (wrapping around). ``all_gather_group`` and ``all_gather_tensors`` are
    not used.

    Returns:
        ``tensor_dict`` for a single-rank group, otherwise ``None``.
    """
    if self.world_size == 1:
        return tensor_dict

    target = (self.rank_in_group + 1) % self.world_size if dst is None else dst
    assert target < self.world_size

    metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
    self.tcp_store_group.send_obj(metadata_list, target)

    for payload in tensor_list:
        if payload.numel() == 0:
            # Empty tensors are described by metadata only.
            continue
        if self.device_communicator and payload.is_cuda:
            self.device_communicator.send(payload, target)
        else:
            self.tcp_store_group.send(payload, target)

    return None
def recv_tensor_dict(
    self,
    src: int | None = None,
    all_gather_group: Optional["GroupCoordinator"] = None,
    all_gather_tensors: dict[str, bool] | None = None,
) -> dict[str, torch.Tensor | Any] | None:
    """Receive a dict of tensors and plain values from group rank ``src``.

    Metadata is received first through the TCPStore group; every non-empty
    tensor it describes is then received individually. ``src`` defaults to
    the previous rank in the group (wrapping around). ``all_gather_group``
    and ``all_gather_tensors`` are not used.

    Returns:
        ``None`` for a single-rank group, otherwise the rebuilt dict.
    """
    if self.world_size == 1:
        return None

    source = (self.rank_in_group - 1) % self.world_size if src is None else src
    assert source < self.world_size

    received_metadata = self.tcp_store_group.recv_obj(source)
    result = {}
    for key, value in received_metadata:
        if not isinstance(value, TensorMetadata):
            result[key] = value
            continue
        buffer = torch.empty(value.size, dtype=value.dtype, device=value.device)
        if buffer.numel() > 0:
            if self.device_communicator and buffer.is_cuda:
                buffer = self.device_communicator.recv(
                    buffer.size(), buffer.dtype, source
                )
            else:
                buffer = self.tcp_store_group.recv(buffer, source)
        result[key] = buffer
    return result
def barrier(self):
    """Synchronize the group using the TCPStore group's barrier."""
    store = self.tcp_store_group
    store.barrier()
def gather(
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> torch.Tensor | None:
if self.world_size == 1:
return input_
if self.device_communicator is None:
raise ValueError("No device communicator found")
if self.rank_in_group == dst:
gathered_list = [torch.empty_like(input_) for _ in range(self.world_size)]
gathered_list[self.rank_in_group] = input_
for src_rank in range(self.world_size):
if src_rank != self.rank_in_group:
gathered_list[src_rank] = self.device_communicator.recv(
input_.size(), input_.dtype, src_rank
)
return torch.cat(gathered_list, dim=dim)
else:
self.device_communicator.send(input_, dst)
return None
+79 -14
View File
@@ -18,7 +18,7 @@ from datetime import timedelta
from typing import Any
import torch
from torch.distributed import ProcessGroup, TCPStore
from torch.distributed import ProcessGroup, Store, TCPStore
from torch.distributed.distributed_c10d import (
Backend,
PrefixStore,
@@ -228,6 +228,55 @@ class StatelessProcessGroup:
gathered_objs.append(recv_obj)
return gathered_objs
def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
    """Broadcast a tensor from ``src`` to all other ranks through the store.

    The source pickles the tensor into the store under a per-source sequence
    number and returns its own tensor. Every other rank reads the entry
    matching its own receive counter for ``src`` and returns the unpickled
    tensor; the ``tensor`` argument is not read on receivers.
    """
    if self.rank != src:
        seq = self.broadcast_recv_src_counter[src]
        key = f"broadcast_tensor/{src}/{seq}"
        # NOTE: unpickling trusts whatever was written to the store.
        received = pickle.loads(self.store.get(key))
        self.broadcast_recv_src_counter[src] = seq + 1
        return received

    payload = pickle.dumps(tensor)
    self.expire_data()
    key = f"broadcast_tensor/{src}/{self.broadcast_send_counter}"
    self.store.set(key, payload)
    self.broadcast_send_counter += 1
    # Remember when the key was written, alongside the key itself.
    self.entries.append((key, time.time()))
    return tensor
def send(self, tensor: torch.Tensor, dst: int):
    """Send a tensor to rank ``dst`` by pickling it into the store.

    The key is built from ``dst`` and this rank's send counter for ``dst``;
    the counter is advanced after the write.
    """
    self.expire_data()
    seq = self.send_dst_counter[dst]
    key = f"send_tensor/{dst}/{seq}"
    payload = pickle.dumps(tensor)
    self.store.set(key, payload)
    self.send_dst_counter[dst] = seq + 1
    # Remember when the key was written, alongside the key itself.
    self.entries.append((key, time.time()))
def recv(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
    """Receive a tensor sent by rank ``src`` into ``tensor`` and return it.

    The key is built from this rank and the receive counter kept for
    ``src``; the source rank itself is not part of the key.
    """
    seq = self.recv_src_counter[src]
    key = f"send_tensor/{self.rank}/{seq}"
    # NOTE: unpickling trusts whatever was written to the store.
    incoming = pickle.loads(self.store.get(key))
    self.recv_src_counter[src] = seq + 1
    tensor.copy_(incoming)
    return tensor
def all_reduce(
    self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM
) -> torch.Tensor:
    """All-reduce a tensor across all ranks.

    Every rank gathers all contributions with ``all_gather_obj`` and folds
    them locally. The accumulator starts as a clone of the first gathered
    tensor, so the gathered tensors themselves are not modified.

    Args:
        tensor: This rank's contribution.
        op: Reduction to apply. SUM, PRODUCT, MAX and MIN are supported.

    Returns:
        The reduced tensor.

    Raises:
        ValueError: If ``op`` is not one of the supported reductions.
    """
    reduce_op = torch.distributed.ReduceOp
    # SUM and PRODUCT accumulate in place on the clone; MAX and MIN rebind.
    if op == reduce_op.SUM:
        combine = torch.Tensor.add_
    elif op == reduce_op.PRODUCT:
        combine = torch.Tensor.mul_
    elif op == reduce_op.MAX:
        combine = torch.maximum
    elif op == reduce_op.MIN:
        combine = torch.minimum
    else:
        # An unsupported op used to fall through every branch and silently
        # return the first rank's tensor; fail before communicating instead.
        raise ValueError(f"Unsupported reduce op for all_reduce: {op}")

    tensors = self.all_gather_obj(tensor)
    result = tensors[0].clone()
    for contribution in tensors[1:]:
        result = combine(result, contribution)
    return result
def barrier(self, timeout: float = 30.0):
"""A robust barrier to synchronize all ranks.
@@ -448,8 +497,14 @@ def init_gloo_process_group(
def stateless_init_torch_distributed_process_group(
host: str, port: int, rank: int, world_size: int, backend: str
) -> ProcessGroup:
host: str,
port: int,
rank: int,
world_size: int,
backend: str,
group_name: str | None = None,
return_store: bool = False,
) -> ProcessGroup | tuple[ProcessGroup, Store]:
"""
A replacement for `torch.distributed.init_process_group` that does not
pollute the global state. The created ProcessGroup object can be used for
@@ -496,25 +551,35 @@ def stateless_init_torch_distributed_process_group(
# Use a PrefixStore to avoid accidental overrides of keys used by
# different systems (e.g. RPC) in case the store is multi-tenant.
prefix_store = PrefixStore(init_method, store)
try:
if backend == "gloo":
pg = init_gloo_process_group(
prefix_store=prefix_store,
group_rank=group_rank,
group_size=group_size,
timeout=timeout,
)
else:
from vllm.platforms import current_platform
return current_platform.stateless_init_device_torch_dist_pg(
pg = current_platform.stateless_init_device_torch_dist_pg(
backend=backend,
prefix_store=prefix_store,
group_rank=group_rank,
group_size=group_size,
timeout=timeout,
)
except NotImplementedError:
# If platform doesn't implement stateless_init_device_torch_dist_pg, it
# will raise a NotImplementedError. In this case, we fall back to gloo.
return init_gloo_process_group(
prefix_store=prefix_store,
group_rank=group_rank,
group_size=group_size,
timeout=timeout,
)
if group_name is not None:
from torch._C._distributed_c10d import _register_process_group
pg._set_group_name(group_name)
_register_process_group(group_name, pg)
if return_store:
return pg, store
else:
return pg
def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:
+5
View File
@@ -419,6 +419,7 @@ class EngineArgs:
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
moe_backend: MoEBackend = KernelConfig.moe_backend
all2all_backend: All2AllBackend = ParallelConfig.all2all_backend
enable_elastic_ep: bool = ParallelConfig.enable_elastic_ep
enable_dbo: bool = ParallelConfig.enable_dbo
ubatch_size: int = ParallelConfig.ubatch_size
dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
@@ -896,6 +897,9 @@ class EngineArgs:
"--ubatch-size",
**parallel_kwargs["ubatch_size"],
)
parallel_group.add_argument(
"--enable-elastic-ep", **parallel_kwargs["enable_elastic_ep"]
)
parallel_group.add_argument(
"--dbo-decode-token-threshold",
**parallel_kwargs["dbo_decode_token_threshold"],
@@ -1698,6 +1702,7 @@ class EngineArgs:
is_moe_model=model_config.is_moe,
enable_expert_parallel=self.enable_expert_parallel,
all2all_backend=self.all2all_backend,
enable_elastic_ep=self.enable_elastic_ep,
enable_dbo=self.enable_dbo,
ubatch_size=self.ubatch_size,
dbo_decode_token_threshold=self.dbo_decode_token_threshold,
+5 -1
View File
@@ -246,8 +246,12 @@ def run_multi_api_server(args: argparse.Namespace):
api_server_manager: APIServerProcessManager | None = None
from vllm.v1.engine.utils import get_engine_zmq_addresses
addresses = get_engine_zmq_addresses(vllm_config, num_api_servers)
with launch_core_engines(
vllm_config, executor_class, log_stats, num_api_servers
vllm_config, executor_class, log_stats, addresses, num_api_servers
) as (local_engine_manager, coordinator, addresses):
# Construct common args for the APIServerProcessManager up-front.
api_server_manager_kwargs = dict(
+12
View File
@@ -243,6 +243,8 @@ if TYPE_CHECKING:
VLLM_LORA_DISABLE_PDL: bool = False
VLLM_ENABLE_CUDA_COMPATIBILITY: bool = False
VLLM_CUDA_COMPATIBILITY_PATH: str | None = None
VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: bool = False
VLLM_ELASTIC_EP_DRAIN_REQUESTS: bool = False
def get_default_cache_root():
@@ -1617,6 +1619,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_CUDA_COMPATIBILITY_PATH": lambda: os.environ.get(
"VLLM_CUDA_COMPATIBILITY_PATH", None
),
# Whether this engine is being launched as part of an elastic EP scale-up.
# Should only be set by EngineCoreClient.
"VLLM_ELASTIC_EP_SCALE_UP_LAUNCH": lambda: bool(
int(os.getenv("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH", "0"))
),
# Whether to wait for all requests to drain before sending the
# scaling command in elastic EP.
"VLLM_ELASTIC_EP_DRAIN_REQUESTS": lambda: bool(
int(os.getenv("VLLM_ELASTIC_EP_DRAIN_REQUESTS", "0"))
),
}
@@ -627,6 +627,7 @@ class FusedMoE(CustomOp):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params)
self.base_quant_method = self.quant_method
# Disable shared expert overlap if:
# - we are using eplb with non-default backend, because of correctness issues
@@ -683,7 +684,7 @@ class FusedMoE(CustomOp):
# routing_tables only needed for round-robin expert placement with
# DeepEP all2all backend.
routing_tables = self._maybe_init_expert_routing_tables()
prepare_finalize = self.quant_method.maybe_make_prepare_finalize(
prepare_finalize = self.base_quant_method.maybe_make_prepare_finalize(
routing_tables=routing_tables
)
if prepare_finalize is not None:
@@ -693,7 +694,7 @@ class FusedMoE(CustomOp):
self._replace_quant_method(
FusedMoEModularMethod.make(
self,
self.quant_method,
self.base_quant_method,
prepare_finalize,
self.shared_experts,
inplace=not self.moe_config.disable_inplace,
+34
View File
@@ -6,10 +6,13 @@ pynvml. However, it should not initialize cuda context.
import os
from collections.abc import Callable
from datetime import timedelta
from functools import cache, wraps
from typing import TYPE_CHECKING, TypeVar
import torch
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
from typing_extensions import ParamSpec
# import custom ops, trigger op registration
@@ -482,6 +485,37 @@ class CudaPlatformBase(Platform):
def get_static_graph_wrapper_cls(cls) -> str:
return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
@classmethod
def stateless_init_device_torch_dist_pg(
    cls,
    backend: str,
    prefix_store: PrefixStore,
    group_rank: int,
    group_size: int,
    timeout: timedelta,
) -> ProcessGroup:
    """Assemble a ProcessGroup whose default backend is NCCL.

    A bare ``ProcessGroup`` is created on ``prefix_store`` and a
    ``ProcessGroupNCCL`` backend is registered on it for the CUDA device
    type.

    Args:
        backend: Requested backend name. Not consulted by this
            implementation; NCCL is always used.
        prefix_store: Store handed to both the group and the NCCL backend.
        group_rank: Rank of this process inside the new group.
        group_size: Number of ranks in the new group.
        timeout: Timeout written into the NCCL backend options.

    Returns:
        The newly built ProcessGroup.
    """
    assert is_nccl_available()

    group: ProcessGroup = ProcessGroup(prefix_store, group_rank, group_size)

    from torch.distributed.distributed_c10d import ProcessGroupNCCL

    nccl_options = ProcessGroupNCCL.Options()
    nccl_options._timeout = timeout
    nccl_backend = ProcessGroupNCCL(
        prefix_store, group_rank, group_size, nccl_options
    )

    # Wire the NCCL backend into the group: mark it as the default, sync
    # the sequence number, then register it for CUDA devices.
    nccl_type = ProcessGroup.BackendType.NCCL
    group._set_default_backend(nccl_type)
    nccl_backend._set_sequence_number_for_group()
    group._register_backend(torch.device("cuda"), nccl_type, nccl_backend)
    return group
@classmethod
def device_count(cls) -> int:
return cuda_device_count_stateless()
+34
View File
@@ -2,10 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from datetime import timedelta
from functools import cache, lru_cache, wraps
from typing import TYPE_CHECKING
import torch
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
import vllm.envs as envs
from vllm.logger import init_logger
@@ -656,6 +659,37 @@ class RocmPlatform(Platform):
def get_static_graph_wrapper_cls(cls) -> str:
return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
@classmethod
def stateless_init_device_torch_dist_pg(
    cls,
    backend: str,
    prefix_store: PrefixStore,
    group_rank: int,
    group_size: int,
    timeout: timedelta,
) -> ProcessGroup:
    """Assemble a ProcessGroup whose default backend is NCCL.

    A bare ``ProcessGroup`` is created on ``prefix_store`` and a
    ``ProcessGroupNCCL`` backend is registered on it for the ``"cuda"``
    device type.

    Args:
        backend: Requested backend name. Not consulted by this
            implementation; NCCL is always used.
        prefix_store: Store handed to both the group and the NCCL backend.
        group_rank: Rank of this process inside the new group.
        group_size: Number of ranks in the new group.
        timeout: Timeout written into the NCCL backend options.

    Returns:
        The newly built ProcessGroup.
    """
    assert is_nccl_available()

    group: ProcessGroup = ProcessGroup(prefix_store, group_rank, group_size)

    from torch.distributed.distributed_c10d import ProcessGroupNCCL

    nccl_options = ProcessGroupNCCL.Options()
    nccl_options._timeout = timeout
    nccl_backend = ProcessGroupNCCL(
        prefix_store, group_rank, group_size, nccl_options
    )

    # Wire the NCCL backend into the group: mark it as the default, sync
    # the sequence number, then register it for "cuda" devices.
    nccl_type = ProcessGroup.BackendType.NCCL
    group._set_default_backend(nccl_type)
    nccl_backend._set_sequence_number_for_group()
    group._register_backend(torch.device("cuda"), nccl_type, nccl_backend)
    return group
@classmethod
def device_count(cls) -> int:
return cuda_device_count_stateless()
+14
View File
@@ -29,6 +29,15 @@ PauseMode = Literal["abort", "wait", "keep"]
# so form part of the external API.
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error")
# Reserved call_id that tags a UtilityOutput as an elastic-EP (EEP)
# notification rather than the reply to a regular utility call.
EEP_NOTIFICATION_CALL_ID = -1
# Kinds of elastic-EP notifications passed between engine cores and the
# engine core client. Each member's value is its own name as a string.
class EEPNotificationType(enum.Enum):
# Sent by a scale-up-launched engine core from its constructor, before
# data-parallel initialization.
NEW_CORE_ENGINES_INIT_READY = "NEW_CORE_ENGINES_INIT_READY"
# Sent by a new engine core after its "load_model" RPC, before it
# receives weights.
NEW_CORE_ENGINES_WEIGHTS_INIT_READY = "NEW_CORE_ENGINES_WEIGHTS_INIT_READY"
# On receipt, the client resolves the future awaited while waiting for
# engines to switch to the new setup.
RECONFIGURE_FINISHED = "RECONFIGURE_FINISHED"
# Once enough of these have arrived, the client asks its engine manager
# to scale down.
SHUTDOWN_COMPLETE = "SHUTDOWN_COMPLETE"
class FinishReason(enum.IntEnum):
"""
@@ -235,6 +244,11 @@ class ReconfigureDistributedRequest(msgspec.Struct):
new_data_parallel_rank_local: int
new_data_parallel_master_ip: str
new_data_parallel_master_port: int
new_data_parallel_master_port_list: list[int]
new_stateless_world_group_port_list: list[list[int]]
new_stateless_dp_group_port_list: list[list[int]]
new_stateless_ep_group_port_list: list[list[int]]
new_stateless_eplb_group_port_list: list[list[int]]
class ReconfigureRankType(enum.IntEnum):
+27 -14
View File
@@ -20,6 +20,7 @@ from vllm.distributed.weight_transfer.base import (
)
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient, StreamingInput
from vllm.entrypoints.serve.elastic_ep.middleware import set_scaling_elastic_ep
from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -647,7 +648,11 @@ class AsyncLLM(EngineClient):
engine_core = self.engine_core
output_processor = self.output_processor
log_stats = self.log_stats
logger_manager = self.logger_manager
# We use a mutable list for logger_manager so that it can be updated
# during elastic EP scaling (see scale_elastic_ep) without creating
# a circular reference via self.
self._logger_ref = [self.logger_manager]
logger_ref = self._logger_ref
renderer = self.renderer
chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
@@ -691,8 +696,8 @@ class AsyncLLM(EngineClient):
# 4) Logging.
# TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial.
if logger_manager:
logger_manager.record(
if logger_ref[0]:
logger_ref[0].record(
engine_idx=outputs.engine_index,
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
@@ -976,17 +981,13 @@ class AsyncLLM(EngineClient):
new_data_parallel_size,
)
return
logger.info(
"Waiting for requests to drain before scaling up to %s engines...",
new_data_parallel_size,
)
await self.wait_for_requests_to_drain(drain_timeout)
logger.info(
"Requests have been drained, proceeding with scale to %s engines",
new_data_parallel_size,
)
await self.engine_core.scale_elastic_ep(new_data_parallel_size)
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
if envs.VLLM_ELASTIC_EP_DRAIN_REQUESTS:
logger.info(
"VLLM_ELASTIC_EP_DRAIN_REQUESTS is set, "
"waiting for requests to drain before scaling"
)
await self.wait_for_requests_to_drain(drain_timeout)
# recreate stat loggers
if new_data_parallel_size > old_data_parallel_size and self.log_stats:
@@ -999,6 +1000,18 @@ class AsyncLLM(EngineClient):
engine_idxs=list(range(new_data_parallel_size)),
custom_stat_loggers=None,
)
# Update the mutable ref so output_handler picks up the
# new logger without creating a circular reference via self.
if hasattr(self, "_logger_ref"):
self._logger_ref[0] = self.logger_manager
self.logger_manager.log_engine_initialized()
set_scaling_elastic_ep(True)
try:
await self.engine_core.scale_elastic_ep(new_data_parallel_size)
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
finally:
set_scaling_elastic_ep(False)
@property
def is_running(self) -> bool:
+20 -1
View File
@@ -71,6 +71,9 @@ class DPCoordinator:
)
local_only_eng = dp_size == parallel_config.data_parallel_size_local
# NOTE(yongji): handling scaling from intra-node to inter-node
if parallel_config.enable_elastic_ep:
local_only_eng = False
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
back_output_address = get_engine_client_zmq_addr(local_only_eng, host)
@@ -201,6 +204,7 @@ class DPCoordinatorProc:
poller = zmq.Poller()
poller.register(publish_front, zmq.POLLIN)
poller.register(publish_back, zmq.POLLIN)
poller.register(output_back, zmq.POLLIN)
last_publish_time = 0
while True:
@@ -231,6 +235,22 @@ class DPCoordinatorProc:
events = dict(events)
wave_state_changed = False
if publish_back in events:
buffer = publish_back.recv()
if buffer == b"\x01":
# NOTE(yongji): newly started engine subscribed
# We need to send READY message here instead of receiving
# SCALE_ELASTIC_EP notification from engine core client
# as SCALE_ELASTIC_EP is only sent when
# new engines finished initialization.
# Subscription message, on the other hand, is sent
# by each engine during initialization
publish_back.send(b"READY")
else:
logger.error(
"DP Coordinator receives unexpected message from engines"
)
if publish_front in events:
buffer = publish_front.recv()
if buffer in (b"\x01", b"\x00"):
@@ -259,7 +279,6 @@ class DPCoordinatorProc:
# current_wave
# we note that 0 is the wave number for the new
# engine
engines_running = False
logger.info(
"DPCoordinator scaled up from %s to %s engines",
current_count,
+171 -52
View File
@@ -17,6 +17,7 @@ from typing import Any, TypeVar, cast
import msgspec
import zmq
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.envs import enable_envs_cache
@@ -44,6 +45,8 @@ from vllm.v1.core.kv_cache_utils import (
from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import (
EEP_NOTIFICATION_CALL_ID,
EEPNotificationType,
EngineCoreOutput,
EngineCoreOutputs,
EngineCoreRequest,
@@ -110,6 +113,9 @@ class EngineCore:
self.available_gpu_memory_for_kv_cache = -1
if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
self._eep_scale_up_before_kv_init()
# Setup KV Caches and update CacheConfig after profiling.
num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
vllm_config
@@ -233,12 +239,10 @@ class EngineCore:
has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
if has_kv_cache:
if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
dp_group = getattr(self, "dp_group", None)
assert dp_group is not None
self.available_gpu_memory_for_kv_cache = (
ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
)
if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
# NOTE(yongji): should already be set
# during _eep_scale_up_before_kv_init
assert self.available_gpu_memory_for_kv_cache > 0
available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
kv_cache_specs
)
@@ -752,11 +756,22 @@ class EngineCore:
self.structured_output_manager.grammar_init(req)
return req, request.current_wave
def _eep_scale_up_before_kv_init(self):
raise NotImplementedError
def _eep_send_engine_core_notification(
    self,
    notification_type: EEPNotificationType,
    vllm_config: VllmConfig | None = None,
):
    """Placeholder for sending an elastic-EP notification.

    Args:
        notification_type: The notification to send.
        vllm_config: Optional config accepted for signature compatibility;
            unused by this placeholder.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process."""
ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"
addresses: EngineZmqAddresses
@instrument(span_name="EngineCoreProc init")
def __init__(
@@ -807,6 +822,13 @@ class EngineCoreProc(EngineCore):
# and "hybrid" LB modes.
self.publish_dp_lb_stats = internal_dp_balancing
self.addresses = addresses
self.process_input_queue_block = True
if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
self._eep_send_engine_core_notification(
EEPNotificationType.NEW_CORE_ENGINES_INIT_READY,
vllm_config=vllm_config,
)
self._init_data_parallel(vllm_config)
super().__init__(
@@ -1119,8 +1141,14 @@ class EngineCoreProc(EngineCore):
if logger.isEnabledFor(DEBUG):
logger.debug("EngineCore waiting for work.")
waited = True
req = self.input_queue.get()
self._handle_client_request(*req)
block = self.process_input_queue_block
try:
req = self.input_queue.get(block=block)
self._handle_client_request(*req)
except queue.Empty:
break
if not block:
break
if waited:
logger.debug("EngineCore loop active.")
@@ -1290,6 +1318,11 @@ class EngineCoreProc(EngineCore):
for input_socket, _ in poller.poll():
# (RequestType, RequestData)
type_frame, *data_frames = input_socket.recv_multipart(copy=False)
# NOTE(yongji): ignore READY message sent by DP coordinator
# that is used to notify newly started engines
if type_frame.buffer == b"READY":
assert input_socket == coord_socket
continue
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
# Deserialize the request data.
@@ -1488,6 +1521,10 @@ class DPEngineCoreProc(EngineCoreProc):
self.current_wave = 0
self.last_counts = (0, 0)
from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState
self.eep_scaling_state: ElasticEPScalingState | None = None
# Initialize the engine.
dp_rank = vllm_config.parallel_config.data_parallel_rank
super().__init__(
@@ -1511,7 +1548,9 @@ class DPEngineCoreProc(EngineCoreProc):
assert 0 <= local_dp_rank <= dp_rank < dp_size
self.dp_rank = dp_rank
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
self.dp_group, self.dp_store = (
vllm_config.parallel_config.stateless_init_dp_group(return_store=True)
)
def shutdown(self):
super().shutdown()
@@ -1574,7 +1613,12 @@ class DPEngineCoreProc(EngineCoreProc):
# 1) Poll the input queue until there is work to do.
self._process_input_queue()
# 2) Step the engine core.
if self.eep_scaling_state is not None:
_ = self.eep_scaling_state.progress()
if self.eep_scaling_state.is_complete():
self.process_input_queue_block = True
self.eep_scaling_state = None
executed = self._process_engine_step()
self._maybe_publish_request_counts()
@@ -1624,54 +1668,129 @@ class DPEngineCoreProc(EngineCoreProc):
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
stateless_destroy_torch_distributed_process_group(self.dp_group)
self.shutdown()
from copy import deepcopy
parallel_config = self.vllm_config.parallel_config
old_dp_size = parallel_config.data_parallel_size
parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
if reconfig_request.new_data_parallel_rank != -1:
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
# local rank specifies device visibility, it should not be changed
assert (
reconfig_request.new_data_parallel_rank_local
== ReconfigureRankType.KEEP_CURRENT_RANK
)
parallel_config.data_parallel_master_ip = (
reconfig_request.new_data_parallel_master_ip
)
parallel_config.data_parallel_master_port = (
reconfig_request.new_data_parallel_master_port
)
if reconfig_request.new_data_parallel_rank != -2:
self.dp_rank = parallel_config.data_parallel_rank
self.dp_group = parallel_config.stateless_init_dp_group()
reconfig_request.new_data_parallel_master_port = (
parallel_config.data_parallel_master_port
)
from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState
self.model_executor.reinitialize_distributed(reconfig_request)
if reconfig_request.new_data_parallel_size > old_dp_size:
assert self.available_gpu_memory_for_kv_cache > 0
# pass available_gpu_memory_for_kv_cache from existing
# engine-cores to new engine-cores so they can directly
# use it in _initialize_kv_caches() rather than profiling.
ParallelConfig.sync_kv_cache_memory_size(
self.dp_group, self.available_gpu_memory_for_kv_cache
)
# NOTE(yongji): newly joined workers require dummy_run even
# CUDA graph is not used
self.model_executor.collective_rpc("compile_or_warm_up_model")
new_parallel_config = deepcopy(self.vllm_config.parallel_config)
old_dp_size = new_parallel_config.data_parallel_size
new_parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
if (
reconfig_request.new_data_parallel_rank
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
self.shutdown()
logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
else:
logger.info(
"Distributed environment reinitialized for DP rank %s", self.dp_rank
new_parallel_config.data_parallel_rank = (
reconfig_request.new_data_parallel_rank
)
new_parallel_config.data_parallel_master_ip = (
reconfig_request.new_data_parallel_master_ip
)
new_parallel_config.data_parallel_master_port = (
reconfig_request.new_data_parallel_master_port
)
new_parallel_config._data_parallel_master_port_list = (
reconfig_request.new_data_parallel_master_port_list
)
is_scale_down = reconfig_request.new_data_parallel_size < old_dp_size
is_shutdown = (
reconfig_request.new_data_parallel_rank
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
)
self.eep_scaling_state = ElasticEPScalingState(
model_executor=self.model_executor,
engine_core=self,
vllm_config=self.vllm_config,
new_parallel_config=new_parallel_config,
worker_type="removing" if is_shutdown else "existing",
scale_type="scale_down" if is_scale_down else "scale_up",
reconfig_request=reconfig_request,
)
self.process_input_queue_block = False
logger.info(
"[Elastic EP] Received reconfiguration request and starting scaling up/down"
)
def _eep_send_engine_core_notification(
    self,
    notification_type: EEPNotificationType,
    vllm_config: VllmConfig | None = None,
):
    """
    Send a notification to EngineCoreClient, which can then forward it
    to other engine core processes. It is used for:
    1) In scale up: new core engines notify existing core engines
       that they are ready;
    2) In scale down: removing core engines notify EngineCoreClient
       so EngineCoreClient can release their ray placement groups;
    3) Both scale up/down: notify EngineCoreClient that existing
       core engines have already switched to the new parallel setup.

    The notification travels as a UtilityOutput tagged with
    EEP_NOTIFICATION_CALL_ID whose payload is
    ``(notification_type.value, dp_rank)``.

    Args:
        notification_type: Which notification to send.
        vllm_config: Config to read the data-parallel rank from; when
            omitted, ``self.vllm_config`` is used instead.
    """
    source_config = self.vllm_config if vllm_config is None else vllm_config
    dp_rank = source_config.parallel_config.data_parallel_rank

    payload = (notification_type.value, dp_rank)
    utility_output = UtilityOutput(
        call_id=EEP_NOTIFICATION_CALL_ID,
        result=UtilityResult(payload),
    )
    outputs = EngineCoreOutputs(utility_output=utility_output)
    outputs.engine_index = self.engine_index

    # Preferred path: the output thread exists and is alive, so hand the
    # message to it through the regular output queue.
    if hasattr(self, "output_thread") and self.output_thread.is_alive():
        self.output_queue.put_nowait((0, outputs))
        return

    # Otherwise push the message directly over a short-lived socket.
    encoder = MsgpackEncoder()
    with (
        zmq.Context() as ctx,
        make_zmq_socket(
            ctx, self.addresses.outputs[0], zmq.PUSH, linger=4000
        ) as socket,
    ):
        socket.send_multipart(encoder.encode(outputs))
def eep_handle_engine_core_notification(
    self, notification_type: str | EEPNotificationType
):
    """
    Pass a notification received from EngineCoreClient (forwarded from
    new core engines) on to the active scaling state.

    Args:
        notification_type: The notification, either as an
            EEPNotificationType member or as its string value.
    """
    assert self.eep_scaling_state is not None
    # Accept the string form and normalize it to the enum member.
    resolved_type = (
        EEPNotificationType(notification_type)
        if isinstance(notification_type, str)
        else notification_type
    )
    self.eep_scaling_state.handle_notification(resolved_type)
def _eep_scale_up_before_kv_init(self):
    """Bring up a newly launched engine core before KV-cache setup.

    Steps, in order: create the scaling state for a "new" worker, run the
    "init_device" and "load_model" RPCs, announce
    NEW_CORE_ENGINES_WEIGHTS_INIT_READY, receive weights, obtain the
    KV-cache memory size through the DP group, prepare the new worker,
    and finally switch the input queue to non-blocking reads.
    """
    from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState

    scaling_state_kwargs = dict(
        model_executor=self.model_executor,
        engine_core=self,
        vllm_config=self.vllm_config,
        new_parallel_config=self.vllm_config.parallel_config,
        worker_type="new",
        scale_type="scale_up",
        reconfig_request=None,
    )
    self.eep_scaling_state = ElasticEPScalingState(**scaling_state_kwargs)

    for startup_rpc in ("init_device", "load_model"):
        self.model_executor.collective_rpc(startup_rpc)

    self._eep_send_engine_core_notification(
        EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY
    )
    self.model_executor.collective_rpc(
        "elastic_ep_execute", args=("receive_weights",)
    )

    # This rank contributes -1 to the KV-cache memory size sync and
    # stores whatever value the sync returns.
    kv_cache_memory = ParallelConfig.sync_kv_cache_memory_size(self.dp_group, -1)
    self.available_gpu_memory_for_kv_cache = kv_cache_memory

    self.model_executor.collective_rpc(
        "elastic_ep_execute", args=("prepare_new_worker",)
    )
    self.process_input_queue_block = False
class EngineCoreActorMixin:
+268 -51
View File
@@ -28,11 +28,12 @@ from vllm.tracing import instrument
from vllm.utils.async_utils import in_loop
from vllm.utils.network_utils import (
close_sockets,
get_open_port,
get_open_zmq_inproc_path,
make_zmq_socket,
)
from vllm.v1.engine import (
EEP_NOTIFICATION_CALL_ID,
EEPNotificationType,
EngineCoreOutputs,
EngineCoreRequest,
EngineCoreRequestType,
@@ -47,6 +48,7 @@ from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.engine.utils import (
CoreEngineActorManager,
CoreEngineProcManager,
get_engine_zmq_addresses,
launch_core_engines,
)
from vllm.v1.executor import Executor
@@ -445,6 +447,63 @@ class BackgroundResources:
raise EngineDeadError()
# Client-side bookkeeping for one in-flight elastic-EP scaling operation.
@dataclass
class ElasticScalingCache:
# Engines that were already present when the scaling operation began.
existing_core_engines: list[EngineIdentity]
# Change in engine count (new size minus current size on scale-up);
# the SHUTDOWN_COMPLETE handling asserts that it is negative.
num_new_core_engines: int
# For each notification type, the dp_ranks that have reported so far.
pending_notifications: dict[EEPNotificationType, set[int]]
def allocate_stateless_group_ports(parallel_config, new_data_parallel_size: int):
    """
    Allocate stateless group ports for elastic EP.

    Open ports are requested in a single batch and split into triples, one
    triple per world / DP / EP / EPLB group, plus five trailing ports for
    the data-parallel master. The results are written onto
    ``parallel_config`` in place.

    Args:
        parallel_config: Parallel config to update; must have
            ``enable_elastic_ep`` set.
        new_data_parallel_size: Data-parallel size after scaling.
    """
    from vllm.utils.network_utils import get_open_ports_list

    assert parallel_config.enable_elastic_ep, "Elastic EP must be enabled"

    total_ranks = parallel_config.world_size * new_data_parallel_size
    dp_group_count = max(1, total_ranks // new_data_parallel_size)
    ep_group_count = max(
        1,
        total_ranks
        // (new_data_parallel_size * parallel_config.tensor_parallel_size),
    )
    # Group counts in the order their ports are carved out of the pool:
    # world, DP, EP, EPLB (EPLB mirrors EP).
    group_counts = (1, dp_group_count, ep_group_count, ep_group_count)

    port_pool = get_open_ports_list(sum(group_counts) * 3 + 5)
    master_ports = port_pool[-5:]
    port_pool = port_pool[:-5]

    # Walk the pool once, handing each group kind its run of 3-port chunks.
    chunked_ports = []
    cursor = 0
    for count in group_counts:
        stop = cursor + count * 3
        chunked_ports.append(
            [port_pool[pos : pos + 3] for pos in range(cursor, stop, 3)]
        )
        cursor = stop
    world_ports, dp_ports, ep_ports, eplb_ports = chunked_ports

    parallel_config._stateless_world_group_port_list = world_ports
    parallel_config._stateless_dp_group_port_list = dp_ports
    parallel_config._stateless_ep_group_port_list = ep_ports
    parallel_config._stateless_eplb_group_port_list = eplb_ports
    # The last master port becomes the active one; the rest stay in the list.
    parallel_config.data_parallel_master_port = master_ports.pop()
    parallel_config._data_parallel_master_port_list = master_ports
class MPClient(EngineCoreClient):
"""
MPClient: base client for multi-proc EngineCore.
@@ -491,32 +550,37 @@ class MPClient(EngineCoreClient):
input_address = client_addresses["input_address"]
output_address = client_addresses["output_address"]
self.stats_update_address = client_addresses.get("stats_update_address")
self.input_socket = self.resources.input_socket = make_zmq_socket(
self.ctx, input_address, zmq.ROUTER, bind=True
)
self.resources.output_socket = make_zmq_socket(
self.ctx, output_address, zmq.PULL
)
else:
# Engines are managed by this client.
with launch_core_engines(vllm_config, executor_class, log_stats) as (
engine_manager,
coordinator,
addresses = get_engine_zmq_addresses(vllm_config)
self.input_socket = self.resources.input_socket = make_zmq_socket(
self.ctx, addresses.inputs[0], zmq.ROUTER, bind=True
)
self.resources.output_socket = make_zmq_socket(
self.ctx, addresses.outputs[0], zmq.PULL
)
with launch_core_engines(
vllm_config,
executor_class,
log_stats,
addresses,
):
) as (engine_manager, coordinator, addresses):
self.resources.coordinator = coordinator
self.resources.engine_manager = engine_manager
(input_address,) = addresses.inputs
(output_address,) = addresses.outputs
self.stats_update_address = addresses.frontend_stats_publish_address
if coordinator is not None:
assert self.stats_update_address == (
coordinator.get_stats_publish_address()
)
# Create input and output sockets.
self.input_socket = self.resources.input_socket = make_zmq_socket(
self.ctx, input_address, zmq.ROUTER, bind=True
)
self.resources.output_socket = make_zmq_socket(
self.ctx, output_address, zmq.PULL
)
parallel_config = vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_index
@@ -877,6 +941,10 @@ class AsyncMPClient(MPClient):
output_socket = resources.output_socket
assert output_socket is not None
notification_callback_handler: (
Callable[[AsyncMPClient, Sequence[Any]], Any] | None
) = getattr(self.__class__, "eep_process_engine_core_notification", None)
async def process_outputs_socket():
try:
while True:
@@ -884,7 +952,26 @@ class AsyncMPClient(MPClient):
resources.validate_alive(frames)
outputs: EngineCoreOutputs = decoder.decode(frames)
if outputs.utility_output:
_process_utility_output(outputs.utility_output, utility_results)
if (
outputs.utility_output.call_id == EEP_NOTIFICATION_CALL_ID
and notification_callback_handler is not None
):
assert _self_ref is not None
_self = _self_ref()
if not _self:
return
if outputs.utility_output.result is None:
continue
notification_data = outputs.utility_output.result.result
assert isinstance(notification_data, Sequence)
assert len(notification_data) == 2
asyncio.create_task(
notification_callback_handler(_self, notification_data)
)
else:
_process_utility_output(
outputs.utility_output, utility_results
)
continue
if output_handler is not None:
@@ -1081,6 +1168,8 @@ class DPAsyncMPClient(AsyncMPClient):
# Used only by DPLBAsyncMPClient subclass.
self.lb_engines: list[list[int]] = [[0, 0] for _ in self.core_engines]
self.eep_scaling_cache: ElasticScalingCache | None = None
self.first_req_sock_addr = get_open_zmq_inproc_path()
self.first_req_send_socket = self.resources.first_req_send_socket = (
make_zmq_socket(self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=True)
@@ -1101,12 +1190,6 @@ class DPAsyncMPClient(AsyncMPClient):
assert self.stats_update_address is not None
stats_addr: str = self.stats_update_address
assert len(self.engine_ranks_managed) > 0
# NOTE: running and waiting counts are all global from
# the Coordinator include all global EngineCores. This
# slice includes just the cores managed by this client.
count_slice = slice(
self.engine_ranks_managed[0], self.engine_ranks_managed[-1] + 1
)
async def run_engine_stats_update_task():
with (
@@ -1145,6 +1228,29 @@ class DPAsyncMPClient(AsyncMPClient):
):
# Extract new engine count from the decoded message
new_engine_count = decoded[1]
# Update engine_ranks_managed and count_slice
parallel_config = self.vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
assert dp_rank == 0
assert dp_size == new_engine_count
assert not (
parallel_config.data_parallel_hybrid_lb
or parallel_config.data_parallel_external_lb
)
num_ranks = dp_size
self.engine_ranks_managed = list(
range(dp_rank, dp_rank + num_ranks)
)
if len(self.lb_engines) < new_engine_count:
self.lb_engines = self.lb_engines + [
[0, 0]
for _ in range(
new_engine_count - len(self.lb_engines)
)
]
else:
self.lb_engines = self.lb_engines[:new_engine_count]
# Send scale up notification to coordinator
scale_msg = msgspec.msgpack.encode(
("SCALE_ELASTIC_EP", new_engine_count)
@@ -1178,6 +1284,11 @@ class DPAsyncMPClient(AsyncMPClient):
self.current_wave = wave
self.engines_running = running
if counts is not None:
# Running and waiting counts are global from the
# Coordinator including all EngineCores. Slice to get
# just the cores managed by this client.
ranks = self.engine_ranks_managed
count_slice = slice(ranks[0], ranks[-1] + 1)
sliced_counts = counts[count_slice]
self.lb_engines = sliced_counts
logger.debug(
@@ -1287,6 +1398,67 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
for req_id in outputs.finished_requests:
self.reqs_in_flight.pop(req_id, None)
# Declared static but takes the client explicitly as `self`; the output
# socket task looks it up on the class and calls it with the client
# instance and the notification payload.
@staticmethod
async def eep_process_engine_core_notification(
self: "DPLBAsyncMPClient", notification_data: tuple[str, int]
):
"""Process one elastic-EP notification received from an engine core.

`notification_data` is a (notification type string, dp_rank) pair.
RECONFIGURE_FINISHED resolves the pending EEP_NOTIFICATION_CALL_ID
utility future. Other types are counted per dp_rank in the scaling
cache and acted on once enough ranks have reported.

Raises ValueError for an unknown notification type string or for a
duplicate notification from the same dp_rank.
"""
cache = self.eep_scaling_cache
notification_type_str, dp_rank = notification_data
# Convert the string to the enum, re-raising with a clearer message.
try:
notification_type = EEPNotificationType(notification_type_str)
except ValueError as e:
raise ValueError(
f"Unknown EEP notification type: {notification_type_str}"
) from e
if notification_type == EEPNotificationType.RECONFIGURE_FINISHED:
from vllm.v1.engine import UtilityResult
# NOTE(yongji): process a dummy UtilityOutput to resolve the future
# awaited in _eep_wait_for_setup_switch_complete(), signaling that
# all engine cores have completed reconfiguration.
dummy_output = UtilityOutput(
call_id=EEP_NOTIFICATION_CALL_ID, result=UtilityResult(None)
)
_process_utility_output(dummy_output, self.utility_results)
return
# All remaining notification types need an active scaling cache.
assert cache is not None
# Record which dp_rank reported; the same rank reporting the same type
# twice is an error.
if notification_type not in cache.pending_notifications:
cache.pending_notifications[notification_type] = set()
if dp_rank in cache.pending_notifications[notification_type]:
raise ValueError(
f"Duplicate notification {notification_type} from dp_rank {dp_rank}"
)
cache.pending_notifications[notification_type].add(dp_rank)
# Act once the number of reporting ranks reaches the magnitude of the
# engine-count change (abs() because the change is negative on
# scale-down).
if len(cache.pending_notifications[notification_type]) >= abs(
cache.num_new_core_engines
):
if notification_type == EEPNotificationType.SHUTDOWN_COMPLETE:
# Scale-down: ask the actor-based engine manager to shrink from the
# old to the new data-parallel size.
assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
assert cache.num_new_core_engines < 0
old_dp_size = len(cache.existing_core_engines)
new_dp_size = old_dp_size + cache.num_new_core_engines
self.resources.engine_manager.scale_down_elastic_ep(
old_dp_size, new_dp_size
)
else:
# Forward the notification to every pre-existing engine concurrently.
await asyncio.gather(
*[
self._call_utility_async(
"eep_handle_engine_core_notification",
notification_type,
engine=engine,
)
for engine in cache.existing_core_engines
]
)
# NOTE(review): indentation was lost in this view, so whether the reset
# and the cache teardown below sit inside the threshold branch above
# cannot be confirmed from here — verify against the real file.
cache.pending_notifications[notification_type] = set()
# SHUTDOWN_COMPLETE and NEW_CORE_ENGINES_WEIGHTS_INIT_READY drop the
# scaling cache.
if notification_type in [
EEPNotificationType.SHUTDOWN_COMPLETE,
EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY,
]:
self.eep_scaling_cache = None
async def abort_requests_async(self, request_ids: list[str]) -> None:
if not request_ids or self.resources.engine_dead:
return
@@ -1333,6 +1505,20 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
cur_data_parallel_size, new_data_parallel_size
)
async def _eep_wait_for_setup_switch_complete(self) -> None:
    """
    Wait until the core engines have switched to the new setup.

    A future is registered in ``self.utility_results`` under
    EEP_NOTIFICATION_CALL_ID and then awaited. When a RECONFIGURE_FINISHED
    notification is handled, eep_process_engine_core_notification() processes
    a dummy UtilityOutput carrying that same call_id, which is what resolves
    the future.

    NOTE(review): this is a coroutine function, so the future is only
    created and registered once the coroutine actually starts running
    (i.e. when it is first awaited or scheduled), not when it is called.
    """
    loop = asyncio.get_running_loop()
    switch_done = loop.create_future()
    self.utility_results[EEP_NOTIFICATION_CALL_ID] = switch_done
    # Make sure the task that consumes engine outputs is running.
    self._ensure_output_queue_task()
    await switch_done
async def _scale_up_elastic_ep(
self, cur_data_parallel_size: int, new_data_parallel_size: int
) -> None:
@@ -1340,38 +1526,57 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
and reconfiguring existing ones."""
cur_data_parallel_size = len(self.core_engines)
# Phase 1: Send reconfigure messages to all existing engines and wait
# for them to be sent
self.eep_scaling_cache = ElasticScalingCache(
existing_core_engines=self.core_engines.copy(),
num_new_core_engines=new_data_parallel_size - cur_data_parallel_size,
pending_notifications=dict(),
)
parallel_config = self.vllm_config.parallel_config
allocate_stateless_group_ports(parallel_config, new_data_parallel_size)
# Phase 1: Send reconfig messages to existing engines
reconfig_futures = []
self.vllm_config.parallel_config.data_parallel_master_port = get_open_port()
for engine in self.core_engines:
reconfig_request = ReconfigureDistributedRequest(
new_data_parallel_size=new_data_parallel_size,
new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_master_ip=self.vllm_config.parallel_config.data_parallel_master_ip,
new_data_parallel_master_port=self.vllm_config.parallel_config.data_parallel_master_port,
new_data_parallel_master_ip=parallel_config.data_parallel_master_ip,
new_data_parallel_master_port=parallel_config.data_parallel_master_port,
new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list,
new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list,
new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list,
new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list,
new_stateless_eplb_group_port_list=parallel_config._stateless_eplb_group_port_list,
)
coro = self._call_utility_async(
"reinitialize_distributed", reconfig_request, engine=engine
)
reconfig_futures.append(asyncio.create_task(coro))
logger.info("All reconfigure messages sent, starting engine creation")
# Phase 2: Create new engines now that reconfig messages have been sent
# self.resources.engine_manager is guaranteed to be
# CoreEngineActorManager for RayDPClient
# Phase 2: Create new engines
assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
self.resources.engine_manager.scale_up_elastic_ep(
self.vllm_config, new_data_parallel_size
parallel_config.eplb_config.num_redundant_experts = 0
start_new_worker_future = asyncio.to_thread(
self.resources.engine_manager.scale_up_elastic_ep,
self.vllm_config,
new_data_parallel_size,
)
wait_future = self._eep_wait_for_setup_switch_complete()
# Phase 3: Wait for new engines to be created
# and reconfig messages to be received
await asyncio.gather(start_new_worker_future, *reconfig_futures)
logger.info("[Elastic EP] Successfully started new engines")
# Create new CoreEngine objects for the new engines
new_engine_identities = set()
for i in range(cur_data_parallel_size, new_data_parallel_size):
new_engine = i.to_bytes(2, "little")
self.core_engines.append(new_engine)
# NOTE(yongji): we don't update lb_engines here,
# we let run_engine_stats_update_task to update it.
new_engine_identities.add(new_engine)
# Wait for ready messages from new engines on the input socket
@@ -1387,10 +1592,11 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
identity, _ = sync_input_socket.recv_multipart()
new_engine_identities.discard(identity)
# Phase 3: Wait for all existing engines to complete reconfiguration
logger.info("Waiting for existing engines to complete reconfiguration")
await asyncio.gather(*reconfig_futures)
# NOTE(yongji): Before we schedule any requests on the new workers,
# we should wait for them to switch to the new setup.
await wait_future
# Update the parallel config
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
# Notify coordinator about scale up through existing
# stats_update_task connection
self._ensure_stats_update_task()
@@ -1399,8 +1605,6 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
)
await self.first_req_send_socket.send(scale_up_marker)
# Update the parallel config
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
logger.info(
"[Elastic EP] Scale up completed, new data parallel size: %s",
new_data_parallel_size,
@@ -1413,7 +1617,14 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
reconfiguring existing engine cores."""
cur_data_parallel_size = len(self.core_engines)
self.vllm_config.parallel_config.data_parallel_master_port = get_open_port()
self.eep_scaling_cache = ElasticScalingCache(
existing_core_engines=self.core_engines.copy(),
num_new_core_engines=new_data_parallel_size - cur_data_parallel_size,
pending_notifications=dict(),
)
parallel_config = self.vllm_config.parallel_config
allocate_stateless_group_ports(parallel_config, new_data_parallel_size)
reconfig_futures = []
for cur_dp_rank, engine in enumerate(self.core_engines):
@@ -1421,8 +1632,13 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
new_data_parallel_size=new_data_parallel_size,
new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_master_ip=self.vllm_config.parallel_config.data_parallel_master_ip,
new_data_parallel_master_port=self.vllm_config.parallel_config.data_parallel_master_port,
new_data_parallel_master_ip=parallel_config.data_parallel_master_ip,
new_data_parallel_master_port=parallel_config.data_parallel_master_port,
new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list,
new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list,
new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list,
new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list,
new_stateless_eplb_group_port_list=parallel_config._stateless_eplb_group_port_list,
)
if cur_dp_rank >= new_data_parallel_size:
reconfig_request.new_data_parallel_rank = (
@@ -1433,23 +1649,24 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
)
reconfig_futures.append(asyncio.create_task(coro))
for _ in range(new_data_parallel_size, cur_data_parallel_size):
self.core_engines.pop()
# NOTE(yongji): Immediately stop sending requests to the removing engines.
self.core_engines = self.core_engines[:new_data_parallel_size]
self.lb_engines = self.lb_engines[:new_data_parallel_size]
wait_future = self._eep_wait_for_setup_switch_complete()
await asyncio.gather(*reconfig_futures)
assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
self.resources.engine_manager.scale_down_elastic_ep(
cur_data_parallel_size, new_data_parallel_size
)
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
self._ensure_stats_update_task()
scale_down_marker = msgspec.msgpack.encode(
("SCALE_ELASTIC_EP", new_data_parallel_size)
)
await self.first_req_send_socket.send(scale_down_marker)
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
# NOTE(yongji): Unlike scaling up,
# here we don't actually need to wait for the setup switch to complete.
# We may want to remove it in the future.
await wait_future
logger.info(
"[Elastic EP] Scale down completed, new data parallel size: %s",
new_data_parallel_size,
+47 -21
View File
@@ -277,6 +277,8 @@ class CoreEngineActorManager:
else:
ray.init()
vllm_config.parallel_config.allocate_elastic_ep_ports()
if placement_groups is not None:
assert local_dp_ranks is not None, (
"local_dp_ranks must be provided if placement_groups is provided"
@@ -584,6 +586,8 @@ class CoreEngineActorManager:
node_ip = node.node_ip
node_id = node.node_id
if device_str not in available_resources[node_id]:
continue
available_gpus = int(available_resources[node_id][device_str])
# Get total GPUs on this node from the node's resources
@@ -773,11 +777,50 @@ class CoreEngineActorManager:
ray.util.remove_placement_group(pg)
def get_engine_zmq_addresses(
    vllm_config: VllmConfig,
    num_api_servers: int = 1,
) -> EngineZmqAddresses:
    """Allocate ZMQ addresses for engine-client communication.

    One input address and one output address are generated per API server,
    all input addresses first, then all output addresses.
    """
    parallel_config = vllm_config.parallel_config
    host = parallel_config.data_parallel_master_ip
    dp_size = parallel_config.data_parallel_size
    local_engine_count = parallel_config.data_parallel_size_local
    local_start_index = parallel_config.data_parallel_rank_local
    local_engines_only = parallel_config.local_engines_only

    # In offline mode there is an LLM instance per DP rank and
    # one core engine per LLM, see
    # examples/offline_inference/data_parallel.py.
    offline_mode = local_start_index is not None

    # client_local_only = True for cases where this front-end
    # sends requests only to colocated engines.
    all_engines_local = local_engine_count == dp_size
    client_local_only = offline_mode or local_engines_only or all_engines_local

    # NOTE(yongji): handling scaling from intra-node to inter-node
    if parallel_config.enable_elastic_ep:
        client_local_only = False

    def _allocate_per_api_server():
        return [
            get_engine_client_zmq_addr(client_local_only, host)
            for _ in range(num_api_servers)
        ]

    input_addresses = _allocate_per_api_server()
    output_addresses = _allocate_per_api_server()
    return EngineZmqAddresses(inputs=input_addresses, outputs=output_addresses)
@contextlib.contextmanager
def launch_core_engines(
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
addresses: EngineZmqAddresses,
num_api_servers: int = 1,
) -> Iterator[
tuple[
@@ -796,29 +839,8 @@ def launch_core_engines(
host = parallel_config.data_parallel_master_ip
local_engines_only = parallel_config.local_engines_only
# In offline mode there is an LLM instance per DP rank and
# one core engine per LLM, see
# examples/offline_inference/data_parallel.py.
offline_mode = local_start_index is not None
# client_local_only = True for cases where this front-end
# sends requests only to colocated engines.
client_local_only = (
offline_mode or local_engines_only or (local_engine_count == dp_size)
)
# Set up input and output addresses.
addresses = EngineZmqAddresses(
inputs=[
get_engine_client_zmq_addr(client_local_only, host)
for _ in range(num_api_servers)
],
outputs=[
get_engine_client_zmq_addr(client_local_only, host)
for _ in range(num_api_servers)
],
)
# Run the DP Coordinator process with rank 0 when in online DP mode.
# The coordinator is needed for:
# 1. Internal/hybrid LB: collecting and publishing queue stats for load balancing
@@ -885,6 +907,10 @@ def launch_core_engines(
# will be False.
handshake_local_only = offline_mode or local_engine_count == dp_size
# NOTE(yongji): handling scaling from intra-node to inter-node
if parallel_config.enable_elastic_ep:
handshake_local_only = False
handshake_address = get_engine_client_zmq_addr(
handshake_local_only, host, parallel_config.data_parallel_rpc_port
)
+16 -5
View File
@@ -38,6 +38,7 @@ from vllm.distributed.parallel_state import (
get_pcp_group,
get_pp_group,
get_tp_group,
model_parallel_is_initialized,
)
from vllm.envs import enable_envs_cache
from vllm.logger import init_logger
@@ -580,17 +581,20 @@ class WorkerProc:
)
self.async_output_copy_thread.start()
# Initialize device
self.worker.init_device()
# Set process title and log prefix
self.setup_proc_title_and_log_prefix(
enable_ep=vllm_config.parallel_config.enable_expert_parallel
)
# Load model
self._init_message_queues(input_shm_handle, vllm_config)
self.worker.load_model()
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
if not is_eep_new_worker:
self.worker.init_device()
# Update process title now that parallel groups are initialized
self.setup_proc_title_and_log_prefix(
enable_ep=vllm_config.parallel_config.enable_expert_parallel
)
self.worker.load_model()
# Enable environment variable cache (e.g. assume no more
# environment variable overrides after this point)
@@ -885,6 +889,13 @@ class WorkerProc:
@staticmethod
def setup_proc_title_and_log_prefix(enable_ep: bool) -> None:
# Check if parallel groups are initialized first
if not model_parallel_is_initialized():
# Parallel groups not yet initialized, use default process name
set_process_title(name="Worker")
decorate_logs("Worker")
return
dp_size = get_dp_group().world_size
dp_rank = get_dp_group().rank_in_group
pp_size = get_pp_group().world_size
+4 -2
View File
@@ -382,8 +382,10 @@ class RayDistributedExecutor(Executor):
all_kwargs.append(kwargs)
self.collective_rpc("init_worker", args=(all_kwargs,))
self.collective_rpc("init_device")
self.collective_rpc("load_model")
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
if not is_eep_new_worker:
self.collective_rpc("init_device")
self.collective_rpc("load_model")
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([])
+4 -13
View File
@@ -14,7 +14,6 @@ import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
from vllm.v1.serial_utils import run_method
@@ -43,9 +42,11 @@ class UniProcExecutor(Executor):
max_workers=1, thread_name_prefix="WorkerAsyncOutput"
)
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
self.driver_worker.init_worker(all_kwargs=[kwargs])
self.driver_worker.init_device()
self.driver_worker.load_model()
if not is_eep_new_worker:
self.driver_worker.init_device()
self.driver_worker.load_model()
def _distributed_args(self) -> tuple[str, int, int]:
"""Return (distributed_init_method, rank, local_rank)."""
@@ -122,16 +123,6 @@ class UniProcExecutor(Executor):
# it's running.
return
def reinitialize_distributed(
    self, reconfig_request: ReconfigureDistributedRequest
) -> None:
    """Forward the reconfiguration request to the driver worker.

    If the request marks the current rank for shutdown
    (SHUTDOWN_CURRENT_RANK), the executor is shut down afterwards.
    """
    self.driver_worker.reinitialize_distributed(reconfig_request)
    rank_is_being_removed = (
        reconfig_request.new_data_parallel_rank
        == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
    )
    if rank_is_being_removed:
        self.shutdown()
def shutdown(self) -> None:
if worker := self.driver_worker:
worker.shutdown()
+6 -1
View File
@@ -53,7 +53,12 @@ class CPUModelRunner(GPUModelRunner):
v.gpu = v.cpu
@instrument(span_name="Loading (CPU)")
def load_model(self, eep_scale_up: bool = False) -> None:
def load_model(self, load_dummy_weights: bool = False) -> None:
if load_dummy_weights:
raise ValueError(
"Loading dummy weights (needed for elastic EP scale-up) "
"Is not supported by the CPU Model Runner."
)
logger.info("Starting to load model %s...", self.model_config.model)
self.model = get_model(vllm_config=self.vllm_config)
+39 -40
View File
@@ -461,6 +461,8 @@ class GPUModelRunner(
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
self.eplb_state: EplbState | None = None
# NOTE(yongji): flag to temporarily disable EPLB during scaling up/down
self.eep_eplb_suppressed = False
"""
State of the expert parallelism load balancer.
@@ -2702,7 +2704,7 @@ class GPUModelRunner(
"""
Step for the EPLB (Expert Parallelism Load Balancing) state.
"""
if not self.parallel_config.enable_eplb:
if not self.parallel_config.enable_eplb or self.eep_eplb_suppressed:
return
assert self.eplb_state is not None
@@ -2714,6 +2716,23 @@ class GPUModelRunner(
log_stats=self.parallel_config.eplb_config.log_balancedness,
)
def setup_eplb_from_mapping(
    self,
    expanded_physical_to_logical: torch.Tensor,
    old_num_physical_experts: int,
) -> None:
    """Create ``self.eplb_state`` from an existing expert mapping.

    The loaded model must be a mixture-of-experts model.
    ``old_num_physical_experts`` is handed to ``EplbState.from_mapping`` as
    ``num_valid_physical_experts``.
    """
    moe_model = self.get_model()
    assert is_mixture_of_experts(moe_model)

    self.eplb_state = EplbState.from_mapping(
        model=moe_model,
        model_config=self.model_config,
        device=self.device,
        parallel_config=self.parallel_config,
        expanded_physical_to_logical=expanded_physical_to_logical,
        num_valid_physical_experts=old_num_physical_experts,
    )
def _pool(
self,
hidden_states: torch.Tensor,
@@ -4175,21 +4194,16 @@ class GPUModelRunner(
setattr(self, config_name, new_config)
@instrument(span_name="Loading (GPU)")
def load_model(self, eep_scale_up: bool = False) -> None:
def load_model(self, load_dummy_weights: bool = False) -> None:
"""
Args:
eep_scale_up: the model loading is for elastic EP scale up.
load_dummy_weights: load dummy weights instead of real weights.
"""
logger.info_once(
"Starting to load model %s...",
self.model_config.model,
scope="global",
)
global_expert_loads, old_global_expert_indices_per_model, rank_mapping = (
EplbState.get_eep_state(self.parallel_config)
if eep_scale_up
else (None, None, None)
)
if self.parallel_config.enable_eplb:
self.eplb_state = EplbState(self.parallel_config, self.device)
@@ -4198,6 +4212,8 @@ class GPUModelRunner(
try:
with DeviceMemoryProfiler() as m:
time_before_load = time.perf_counter()
if load_dummy_weights:
self.load_config.load_format = "dummy"
model_loader = get_model_loader(self.load_config)
self.model = model_loader.load_model(
vllm_config=self.vllm_config, model_config=self.model_config
@@ -4214,6 +4230,9 @@ class GPUModelRunner(
and is_mixture_of_experts(self.drafter.model)
and self.parallel_config.enable_eplb
):
assert not self.parallel_config.enable_elastic_ep, (
"Elastic EP is not supported with drafter model."
)
spec_config = self.vllm_config.speculative_config
assert spec_config is not None
assert spec_config.draft_model_config is not None
@@ -4221,17 +4240,6 @@ class GPUModelRunner(
"EPLB is enabled for drafter model %s.",
spec_config.draft_model_config.model,
)
global_expert_load = (
global_expert_loads[eplb_models]
if global_expert_loads
else None
)
old_global_expert_indices = (
old_global_expert_indices_per_model[eplb_models]
if old_global_expert_indices_per_model
else None
)
if self.eplb_state is None:
self.eplb_state = EplbState(
self.parallel_config, self.device
@@ -4239,9 +4247,6 @@ class GPUModelRunner(
self.eplb_state.add_model(
self.drafter.model,
spec_config.draft_model_config,
global_expert_load,
old_global_expert_indices,
rank_mapping,
)
eplb_models += 1
@@ -4283,11 +4288,12 @@ class GPUModelRunner(
time_after_load - time_before_load,
scope="local",
)
prepare_communication_buffer_for_model(self.model)
if (drafter := getattr(self, "drafter", None)) and (
drafter_model := getattr(drafter, "model", None)
):
prepare_communication_buffer_for_model(drafter_model)
if not load_dummy_weights:
prepare_communication_buffer_for_model(self.model)
if (drafter := getattr(self, "drafter", None)) and (
drafter_model := getattr(drafter, "model", None)
):
prepare_communication_buffer_for_model(drafter_model)
mm_config = self.model_config.multimodal_config
self.is_multimodal_pruning_enabled = (
supports_multimodal_pruning(self.get_model())
@@ -4295,26 +4301,19 @@ class GPUModelRunner(
and mm_config.is_multimodal_pruning_enabled()
)
if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb:
if (
is_mixture_of_experts(self.model)
and self.parallel_config.enable_eplb
and not load_dummy_weights
):
logger.info_once("EPLB is enabled for model %s.", self.model_config.model)
global_expert_load = (
global_expert_loads[eplb_models] if global_expert_loads else None
)
old_global_expert_indices = (
old_global_expert_indices_per_model[eplb_models]
if old_global_expert_indices_per_model
else None
)
assert self.eplb_state is not None
self.eplb_state.add_model(
self.model,
self.model_config,
global_expert_load,
old_global_expert_indices,
rank_mapping,
)
if self.eplb_state.is_async:
self.eplb_state.start_async_loop(rank_mapping=rank_mapping)
self.eplb_state.start_async_loop()
if (
self.vllm_config.compilation_config.mode
+27 -228
View File
@@ -7,11 +7,10 @@ import os
from collections.abc import Callable
from contextlib import AbstractContextManager, nullcontext
from types import NoneType
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
import vllm.envs as envs
@@ -32,14 +31,12 @@ from vllm.distributed.kv_transfer import (
)
from vllm.distributed.parallel_state import (
Handle,
get_pcp_group,
get_pp_group,
get_tp_group,
)
from vllm.distributed.weight_transfer import WeightTransferEngineFactory
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.models.interfaces import is_mixture_of_experts
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
from vllm.platforms import current_platform
from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
@@ -49,7 +46,6 @@ from vllm.tracing import instrument
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (
AsyncModelRunnerOutput,
@@ -124,6 +120,10 @@ class Worker(WorkerBase):
precision = envs.VLLM_FLOAT32_MATMUL_PRECISION
torch.set_float32_matmul_precision(precision)
from vllm.distributed.elastic_ep.elastic_execute import ElasticEPScalingExecutor
self.elastic_ep_executor = ElasticEPScalingExecutor(self)
# Buffers saved before sleep
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
@@ -317,12 +317,29 @@ class Worker(WorkerBase):
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
# to hijack tensor allocation.
def load_model(self) -> None:
eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
dummy_weights = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
if dummy_weights:
(
expanded_physical_to_logical,
num_logical_experts,
old_num_physical_experts,
) = self.elastic_ep_executor.receive_expert_mapping()
num_physical_experts = expanded_physical_to_logical.shape[1]
self.parallel_config.eplb_config.num_redundant_experts = (
num_physical_experts - num_logical_experts
)
with (
self._maybe_get_memory_pool_context(tag="weights"),
set_current_vllm_config(self.vllm_config),
):
self.model_runner.load_model(eep_scale_up=eep_scale_up)
self.model_runner.load_model(load_dummy_weights=dummy_weights)
if dummy_weights:
self.model_runner.setup_eplb_from_mapping(
expanded_physical_to_logical, old_num_physical_experts
)
self.model_runner.eep_eplb_suppressed = True
def update_config(self, overrides: dict[str, Any]) -> None:
self.model_runner.update_config(overrides)
@@ -801,227 +818,6 @@ class Worker(WorkerBase):
# worker will always be healthy as long as it's running.
return
def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
    """Reshard experts ahead of a scale-down.

    Old EP ranks below ``new_ep_size`` map to themselves; the remaining old
    ranks map to -1. The EPLB state is rearranged with that mapping and the
    shuffle is executed, followed by a CUDA synchronize.
    """
    from vllm.distributed.parallel_state import get_ep_group

    def _log_on_ep_rank0(message: str) -> None:
        # Only EP rank 0 logs progress.
        if get_ep_group().rank == 0:
            logger.info(message)

    _log_on_ep_rank0(
        "[Elastic EP] Starting expert resharding before scaling down..."
    )

    rank_mapping = {}
    for old_ep_rank in range(old_ep_size):
        survives = old_ep_rank < new_ep_size
        rank_mapping[old_ep_rank] = old_ep_rank if survives else -1

    eplb_state = self.model_runner.eplb_state
    assert eplb_state is not None
    eplb_state.rearrange(
        execute_shuffle=True,
        global_expert_loads=None,
        rank_mapping=rank_mapping,
    )
    torch.cuda.synchronize()

    _log_on_ep_rank0("[Elastic EP] Expert resharding completed!")
def _eplb_after_scale_up(
    self,
    old_ep_size: int,
    new_ep_size: int,
    global_expert_loads: list[torch.Tensor] | None,
) -> None:
    """Reshard experts once a scale-up has taken place.

    Every old EP rank maps to itself. The EPLB state is rearranged with
    that mapping and the given global expert loads, executing the shuffle.
    ``new_ep_size`` is accepted but not used here.
    """
    from vllm.distributed.parallel_state import get_ep_group

    def _log_on_ep_rank0(message: str) -> None:
        # Only EP rank 0 logs progress.
        if get_ep_group().rank == 0:
            logger.info(message)

    _log_on_ep_rank0("[Elastic EP] Starting expert resharding after scaling up...")

    old_ranks = range(old_ep_size)
    rank_mapping = dict(zip(old_ranks, old_ranks))

    eplb_state = self.model_runner.eplb_state
    assert eplb_state is not None
    eplb_state.rearrange(
        execute_shuffle=True,
        global_expert_loads=global_expert_loads,
        rank_mapping=rank_mapping,
    )

    _log_on_ep_rank0("[Elastic EP] Expert resharding completed!")
def _reconfigure_parallel_config(
    self, reconfig_request: ReconfigureDistributedRequest
) -> None:
    """
    Copy the new data-parallel settings from reconfig_request into
    this worker's parallel config.

    The DP size, master IP and master port are always overwritten. The DP
    rank and local DP rank are overwritten only when the request does not
    ask to keep the current value (KEEP_CURRENT_RANK).
    """
    parallel_config = self.vllm_config.parallel_config
    keep_current = ReconfigureRankType.KEEP_CURRENT_RANK

    parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size

    requested_rank = reconfig_request.new_data_parallel_rank
    if requested_rank != keep_current:
        parallel_config.data_parallel_rank = requested_rank

    requested_local_rank = reconfig_request.new_data_parallel_rank_local
    if requested_local_rank != keep_current:
        parallel_config.data_parallel_rank_local = requested_local_rank

    parallel_config.data_parallel_master_ip = (
        reconfig_request.new_data_parallel_master_ip
    )
    parallel_config.data_parallel_master_port = (
        reconfig_request.new_data_parallel_master_port
    )
def _reconfigure_moe(
self, old_ep_size: int, new_ep_size: int
) -> list[torch.Tensor] | None:
"""
Reconfigure the MoE modules of the model (and of the drafter model, if
it is a mixture-of-experts model) for the new expert-parallel size.
Returns None when new_ep_size < old_ep_size; otherwise returns the
global expert loads obtained from eplb_state.rearrange(execute_shuffle=False).
Also updates parallel_config.eplb_config.num_redundant_experts.
"""
from vllm.distributed.parallel_state import (
get_dp_group,
get_ep_group,
prepare_communication_buffer_for_model,
)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEParallelConfig,
)
parallel_config = self.vllm_config.parallel_config
# Helper: collect every submodule whose class is named FusedMoE or
# SharedFusedMoE (matched by class name, not isinstance).
def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]:
return [
module
for module in model.modules()
if (
module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE"
)
]
# Helper: set the global expert count to num_local_experts * new_ep_size
# on each module and rebuild its FusedMoEParallelConfig from the current
# TP/PCP/DP group sizes.
def update_moe_modules(moe_modules: list[FusedMoE], num_local_experts: int):
assert all(
module.moe_config.num_local_experts == num_local_experts
for module in moe_modules
), "All MoE modules must have the same number of experts"
for module in moe_modules:
module.moe_config.num_experts = num_local_experts * new_ep_size
module.global_num_experts = module.moe_config.num_experts
tp_size = get_tp_group().world_size
is_sequence_parallel = parallel_config.use_sequence_parallel_moe
# Sequence-parallel size equals the TP size only when sequence-parallel
# MoE is enabled; otherwise it is 1.
sp_size = tp_size if is_sequence_parallel else 1
module.moe_parallel_config = FusedMoEParallelConfig.make(
tp_size_=tp_size,
pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size,
sp_size_=sp_size,
vllm_parallel_config=parallel_config,
)
module.moe_config.moe_parallel_config = module.moe_parallel_config
return moe_modules
# The first MoE module of the main model defines the per-rank expert count
# that every other module is checked against.
model_moe_modules = get_moe_modules(self.model_runner.model)
num_local_experts = model_moe_modules[0].moe_config.num_local_experts
update_moe_modules(model_moe_modules, num_local_experts)
drafter_model = None
if hasattr(self.model_runner, "drafter") and hasattr(
self.model_runner.drafter, "model"
):
drafter_model = self.model_runner.drafter.model
if drafter_model is not None and is_mixture_of_experts(drafter_model):
drafter_moe_modules = get_moe_modules(drafter_model)
# Check if drafter and model have matching configs
assert (
drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts
), "Drafter and model configs should be the same"
update_moe_modules(drafter_moe_modules, num_local_experts)
# Scale-down: derive the redundant-expert count from the shapes of the
# existing EPLB maps; no expert loads are returned.
if new_ep_size < old_ep_size:
num_local_physical_experts = num_local_experts
assert self.model_runner.eplb_state is not None
new_physical_experts = (
self.model_runner.eplb_state.physical_to_logical_map.shape[1]  # type: ignore[attr-defined]
)
parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts
- self.model_runner.eplb_state.logical_replica_count.shape[1]  # type: ignore[attr-defined]
)
global_expert_loads = None
else:
# Otherwise: take the per-rank physical expert count broadcast from
# source index 0 of the EP CPU group, then fetch the global expert loads
# without executing a shuffle.
num_local_physical_experts_tensor = torch.tensor(
[num_local_experts], dtype=torch.int32, device="cpu"
)
torch.distributed.broadcast(
num_local_physical_experts_tensor,
group=get_ep_group().cpu_group,
group_src=0,
)
num_local_physical_experts = int(num_local_physical_experts_tensor.item())
new_physical_experts = num_local_physical_experts * new_ep_size
assert self.model_runner.eplb_state is not None
global_expert_loads_any = self.model_runner.eplb_state.rearrange(
execute_shuffle=False
)
global_expert_loads = cast(list[torch.Tensor], global_expert_loads_any)
parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts - global_expert_loads[0].shape[1]
)
# Re-prepare communication buffers for the model(s) and publish the new
# physical expert counts to the main model.
prepare_communication_buffer_for_model(self.model_runner.model)
if drafter_model is not None:
prepare_communication_buffer_for_model(drafter_model)
self.model_runner.model.update_physical_experts_metadata(
num_physical_experts=new_physical_experts,
num_local_physical_experts=num_local_physical_experts,
)
return global_expert_loads
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
"""
Tear down and re-create this worker's distributed environment according
to reconfig_request.
The new EP size is computed as new DP size * TP world size * PP world
size. Experts are resharded before teardown when scaling down and after
re-initialization when scaling up. A rank marked SHUTDOWN_CURRENT_RANK
returns right after the teardown without re-initializing.
"""
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (
cleanup_dist_env_and_memory,
get_ep_group,
)
# Record the old EP topology before cleanup_dist_env_and_memory() runs.
old_ep_size = get_ep_group().world_size
old_ep_rank = get_ep_group().rank
new_ep_size = (
reconfig_request.new_data_parallel_size
* get_tp_group().world_size
* get_pp_group().world_size
)
# Scaling down: reshard experts before the groups are cleaned up.
if new_ep_size < old_ep_size:
self._eplb_before_scale_down(old_ep_size, new_ep_size)
cleanup_dist_env_and_memory()
if (
reconfig_request.new_data_parallel_rank
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
):
# Only ranks outside the new EP range are expected to be shut down.
assert old_ep_rank >= new_ep_size
# shutdown
return
self._reconfigure_parallel_config(reconfig_request)
# Re-create the distributed environment with the updated config set as
# the current vllm config.
with set_current_vllm_config(self.vllm_config):
init_worker_distributed_environment(
self.vllm_config,
self.rank,
self.distributed_init_method,
self.local_rank,
)
global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)
# Scaling up: reshard experts using the loads returned by _reconfigure_moe().
if new_ep_size > old_ep_size:
assert global_expert_loads is not None
self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)
def save_sharded_state(
self,
path: str,
@@ -1118,6 +914,9 @@ class Worker(WorkerBase):
if weight_transfer_engine := getattr(self, "weight_transfer_engine", None):
weight_transfer_engine.shutdown()
def elastic_ep_execute(self, execute_method: str, *args, **kwargs):
    """Dispatch ``execute_method`` to this worker's elastic-EP executor.

    Positional and keyword arguments are passed through unchanged, and the
    executor's return value is returned as-is.
    """
    executor = self.elastic_ep_executor
    return executor.execute(execute_method, *args, **kwargs)
def init_worker_distributed_environment(
vllm_config: VllmConfig,
+28
View File
@@ -66,6 +66,23 @@ class WorkspaceManager:
],
)
def unlock(self) -> None:
    """Clear the lock flag so the workspace may grow again.

    This is used during elastic EP scaling when the workspace size
    needs to grow due to changes in the number of experts. When
    VLLM_DEBUG_WORKSPACE is set, the sizes of the current (non-None)
    workspaces are logged, each divided by ``_MB``.
    """
    self._locked = False
    if not envs.VLLM_DEBUG_WORKSPACE:
        return

    current_sizes = [
        self._workspace_size_bytes(workspace) / _MB
        for workspace in self._current_workspaces
        if workspace is not None
    ]
    logger.info(
        "[WORKSPACE DEBUG] Workspace unlocked. Current sizes: %s",
        current_sizes,
    )
def is_locked(self) -> bool:
"""Check if workspace is locked."""
return self._locked
@@ -242,6 +259,17 @@ def lock_workspace() -> None:
current_workspace_manager().lock()
def unlock_workspace() -> None:
    """Unlock the current workspace manager so its workspace can grow.

    Intended for elastic EP scaling, where a change in the number of
    experts can require a larger workspace. Once the scaling operation
    has finished, call lock_workspace() again to prevent unexpected
    allocations.
    """
    manager = current_workspace_manager()
    manager.unlock()
def reset_workspace_manager() -> None:
"""Reset the workspace manager to uninitialized state.