mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[1/N] Elastic EP Milestone 2 (#34861)
Signed-off-by: Yongji Wu <wuyongji317@gmail.com> Signed-off-by: Itay Alroy <ialroy@nvidia.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Signed-off-by: Ron Tourgeman <rtourgeman@nvidia.com> Co-authored-by: Yongji Wu <wuyongji317@gmail.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>
This commit is contained in:
@@ -20,4 +20,19 @@ steps:
|
||||
- tests/distributed/test_eplb_execute.py
|
||||
commands:
|
||||
- pytest -v -s distributed/test_eplb_execute.py
|
||||
- pytest -v -s distributed/test_eplb_spec_decode.py
|
||||
- pytest -v -s distributed/test_eplb_spec_decode.py
|
||||
|
||||
- label: Elastic EP Scaling Test
|
||||
timeout_in_minutes: 20
|
||||
device: b200
|
||||
optional: true
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_devices: 4
|
||||
source_file_dependencies:
|
||||
- vllm/distributed/
|
||||
- vllm/engine/
|
||||
- vllm/executor/
|
||||
- vllm/compilation/
|
||||
- tests/distributed/
|
||||
commands:
|
||||
- pytest -v -s distributed/test_elastic_ep.py
|
||||
|
||||
@@ -316,7 +316,6 @@ def async_tp_pass_on_test_model(
|
||||
|
||||
# initialize distributed
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# configure vllm config for SequenceParallelismPass
|
||||
vllm_config = VllmConfig()
|
||||
@@ -334,11 +333,10 @@ def async_tp_pass_on_test_model(
|
||||
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
|
||||
)
|
||||
|
||||
async_tp_pass = AsyncTPPass(vllm_config)
|
||||
|
||||
# Set the global vllm_config for TestBackend which calls
|
||||
# get_current_vllm_config()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
async_tp_pass = AsyncTPPass(vllm_config)
|
||||
backend = TestBackend(async_tp_pass)
|
||||
|
||||
assert (
|
||||
|
||||
@@ -278,7 +278,6 @@ def all_reduce_fusion_pass_on_test_model(
|
||||
)
|
||||
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
custom_ops = []
|
||||
if enable_rms_norm_custom_op:
|
||||
@@ -304,6 +303,7 @@ def all_reduce_fusion_pass_on_test_model(
|
||||
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
|
||||
)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
|
||||
@@ -242,7 +242,6 @@ def sequence_parallelism_pass_on_test_model(
|
||||
|
||||
# initialize distributed
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# configure vllm config for SequenceParallelismPass
|
||||
custom_ops_list = custom_ops.split(",") if custom_ops else []
|
||||
@@ -272,6 +271,7 @@ def sequence_parallelism_pass_on_test_model(
|
||||
)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
+13
-9
@@ -176,16 +176,20 @@ def init_test_http_connection():
|
||||
|
||||
@pytest.fixture
|
||||
def dist_init():
|
||||
from tests.utils import ensure_current_vllm_config
|
||||
|
||||
temp_file = tempfile.mkstemp()[1]
|
||||
init_distributed_environment(
|
||||
world_size=1,
|
||||
rank=0,
|
||||
distributed_init_method=f"file://{temp_file}",
|
||||
local_rank=0,
|
||||
backend="nccl",
|
||||
)
|
||||
initialize_model_parallel(1, 1)
|
||||
yield
|
||||
|
||||
with ensure_current_vllm_config():
|
||||
init_distributed_environment(
|
||||
world_size=1,
|
||||
rank=0,
|
||||
distributed_init_method=f"file://{temp_file}",
|
||||
local_rank=0,
|
||||
backend="nccl",
|
||||
)
|
||||
initialize_model_parallel(1, 1)
|
||||
yield
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import random
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed.parallel_state import (
|
||||
init_distributed_environment,
|
||||
)
|
||||
@@ -42,7 +43,11 @@ def set_env_vars_and_device(env: dict[str, str]) -> None:
|
||||
local_rank = os.environ["LOCAL_RANK"]
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_distributed_environment()
|
||||
|
||||
# Create a minimal vllm config for init_distributed_environment
|
||||
vllm_config = VllmConfig()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
init_distributed_environment()
|
||||
|
||||
# Ensure each worker process has the same random seed
|
||||
random.seed(42)
|
||||
|
||||
@@ -0,0 +1,202 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from ..evals.gsm8k.gsm8k_eval import evaluate_gsm8k
|
||||
from ..utils import RemoteOpenAIServer, multi_gpu_test
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_ray_between_tests():
|
||||
"""Force-stop any lingering Ray processes between tests."""
|
||||
subprocess.run(["ray", "stop", "--force"], timeout=30, capture_output=True)
|
||||
time.sleep(5)
|
||||
yield
|
||||
|
||||
|
||||
MODEL_NAME = "deepseek-ai/DeepSeek-V2-Lite-Chat"
|
||||
|
||||
NUM_GSM8K_QUESTIONS = 256
|
||||
EXPECTED_ACCURACY = 0.58
|
||||
ACCURACY_TOL = 0.08
|
||||
MAX_NUM_SEQS = 32
|
||||
|
||||
|
||||
def _send_scale_command(server: RemoteOpenAIServer, new_dp_size: int) -> bool:
|
||||
url = server.url_for("scale_elastic_ep")
|
||||
payload = {"new_data_parallel_size": new_dp_size}
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
try:
|
||||
response = requests.post(url, json=payload, headers=headers, timeout=300)
|
||||
return response.status_code == 200
|
||||
except requests.exceptions.RequestException:
|
||||
return False
|
||||
|
||||
|
||||
def _run_gsm8k_eval(server: RemoteOpenAIServer, stage: str) -> float:
|
||||
assert server.port is not None
|
||||
result = evaluate_gsm8k(
|
||||
num_questions=NUM_GSM8K_QUESTIONS,
|
||||
host=f"http://{server.host}",
|
||||
port=server.port,
|
||||
)
|
||||
accuracy = result["accuracy"]
|
||||
print(
|
||||
f"[{stage}] GSM8K accuracy: {accuracy:.3f} "
|
||||
f"({result['num_questions']} questions)"
|
||||
)
|
||||
assert accuracy >= EXPECTED_ACCURACY, (
|
||||
f"[{stage}] GSM8K accuracy {accuracy:.3f} is below "
|
||||
f"expected threshold {EXPECTED_ACCURACY}"
|
||||
)
|
||||
return accuracy
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=4)
|
||||
def test_elastic_ep_scaling():
|
||||
vllm_serve_args = [
|
||||
"--trust-remote-code",
|
||||
"--tensor-parallel-size",
|
||||
"1",
|
||||
"--gpu-memory-utilization",
|
||||
"0.8",
|
||||
"--max-model-len",
|
||||
"4096",
|
||||
"--max-num-seqs",
|
||||
str(MAX_NUM_SEQS),
|
||||
"--enable-expert-parallel",
|
||||
"--all2all-backend",
|
||||
"allgather_reducescatter",
|
||||
"--enable-elastic-ep",
|
||||
"--enable-eplb",
|
||||
"--eplb-config.num_redundant_experts",
|
||||
"0",
|
||||
"--data-parallel-backend",
|
||||
"ray",
|
||||
"--data-parallel-size",
|
||||
"2",
|
||||
"--api-server-count",
|
||||
"1",
|
||||
]
|
||||
|
||||
leader_address = os.environ.get("LEADER_ADDRESS")
|
||||
if leader_address:
|
||||
vllm_serve_args.extend(["--data-parallel-address", leader_address])
|
||||
|
||||
with RemoteOpenAIServer(
|
||||
MODEL_NAME, vllm_serve_args, env_dict={}, max_wait_seconds=1200
|
||||
) as server:
|
||||
initial_accuracy = _run_gsm8k_eval(server, "Initial (2 GPUs)")
|
||||
|
||||
assert _send_scale_command(server, 4)
|
||||
time.sleep(10)
|
||||
scale_up_accuracy = _run_gsm8k_eval(server, "After scale up (4 GPUs)")
|
||||
|
||||
assert scale_up_accuracy >= initial_accuracy - ACCURACY_TOL, (
|
||||
f"Scale up accuracy {scale_up_accuracy:.3f} dropped more than "
|
||||
f"{ACCURACY_TOL} below initial accuracy {initial_accuracy:.3f}"
|
||||
)
|
||||
|
||||
assert _send_scale_command(server, 2)
|
||||
time.sleep(5)
|
||||
scale_down_accuracy = _run_gsm8k_eval(server, "After scale down (2 GPUs)")
|
||||
|
||||
assert scale_down_accuracy >= initial_accuracy - ACCURACY_TOL, (
|
||||
f"Scale down accuracy {scale_down_accuracy:.3f} dropped more than "
|
||||
f"{ACCURACY_TOL} below initial accuracy {initial_accuracy:.3f}"
|
||||
)
|
||||
|
||||
print("\nAccuracy Summary:")
|
||||
print(f" Initial: {initial_accuracy:.3f}")
|
||||
print(
|
||||
f" Scale up: {scale_up_accuracy:.3f} "
|
||||
f"(diff: {scale_up_accuracy - initial_accuracy:+.3f})"
|
||||
)
|
||||
print(
|
||||
f" Scale down: {scale_down_accuracy:.3f} "
|
||||
f"(diff: {scale_down_accuracy - initial_accuracy:+.3f})"
|
||||
)
|
||||
print(f" Tolerance: {ACCURACY_TOL:.3f}")
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=4)
|
||||
def test_elastic_ep_scaling_uneven():
|
||||
"""Test scale up with uneven worker distribution.
|
||||
|
||||
This tests the case where num_new_workers % old_dp_size != 0,
|
||||
specifically 2 -> 3 where remainder = 1 % 2 = 1.
|
||||
This exercises the remainder handling in sender-receiver pairing.
|
||||
"""
|
||||
vllm_serve_args = [
|
||||
"--trust-remote-code",
|
||||
"--tensor-parallel-size",
|
||||
"1",
|
||||
"--gpu-memory-utilization",
|
||||
"0.8",
|
||||
"--max-model-len",
|
||||
"4096",
|
||||
"--max-num-seqs",
|
||||
str(MAX_NUM_SEQS),
|
||||
"--enable-expert-parallel",
|
||||
"--all2all-backend",
|
||||
"allgather_reducescatter",
|
||||
"--enable-elastic-ep",
|
||||
"--enable-eplb",
|
||||
"--eplb-config.num_redundant_experts",
|
||||
"0",
|
||||
"--data-parallel-backend",
|
||||
"ray",
|
||||
"--data-parallel-size",
|
||||
"2",
|
||||
"--api-server-count",
|
||||
"1",
|
||||
]
|
||||
|
||||
leader_address = os.environ.get("LEADER_ADDRESS")
|
||||
if leader_address:
|
||||
vllm_serve_args.extend(["--data-parallel-address", leader_address])
|
||||
|
||||
with RemoteOpenAIServer(
|
||||
MODEL_NAME, vllm_serve_args, env_dict={}, max_wait_seconds=1200
|
||||
) as server:
|
||||
initial_accuracy = _run_gsm8k_eval(server, "Initial (2 GPUs)")
|
||||
|
||||
# Scale 2 -> 3: This has remainder = 1 % 2 = 1
|
||||
# Tests uneven sender-receiver pairing
|
||||
assert _send_scale_command(server, 3)
|
||||
time.sleep(10)
|
||||
scale_up_accuracy = _run_gsm8k_eval(server, "After scale up (3 GPUs)")
|
||||
|
||||
assert scale_up_accuracy >= initial_accuracy - ACCURACY_TOL, (
|
||||
f"Scale up accuracy {scale_up_accuracy:.3f} dropped more than "
|
||||
f"{ACCURACY_TOL} below initial accuracy {initial_accuracy:.3f}"
|
||||
)
|
||||
|
||||
# Scale back down to 2
|
||||
assert _send_scale_command(server, 2)
|
||||
time.sleep(5)
|
||||
scale_down_accuracy = _run_gsm8k_eval(server, "After scale down (2 GPUs)")
|
||||
|
||||
assert scale_down_accuracy >= initial_accuracy - ACCURACY_TOL, (
|
||||
f"Scale down accuracy {scale_down_accuracy:.3f} dropped more than "
|
||||
f"{ACCURACY_TOL} below initial accuracy {initial_accuracy:.3f}"
|
||||
)
|
||||
|
||||
print("\nAccuracy Summary (Uneven Scaling):")
|
||||
print(f" Initial: {initial_accuracy:.3f}")
|
||||
print(
|
||||
f" Scale up: {scale_up_accuracy:.3f} "
|
||||
f"(diff: {scale_up_accuracy - initial_accuracy:+.3f})"
|
||||
)
|
||||
print(
|
||||
f" Scale down: {scale_down_accuracy:.3f} "
|
||||
f"(diff: {scale_down_accuracy - initial_accuracy:+.3f})"
|
||||
)
|
||||
print(f" Tolerance: {ACCURACY_TOL:.3f}")
|
||||
@@ -8,6 +8,7 @@ import pytest
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed.eplb.rebalance_execute import (
|
||||
move_from_buffer,
|
||||
rearrange_expert_weights_inplace,
|
||||
@@ -244,90 +245,95 @@ def _test_async_transfer_layer_without_mtp_worker(
|
||||
num_logical_experts: int,
|
||||
) -> None:
|
||||
set_env_vars_and_device(env)
|
||||
ensure_model_parallel_initialized(
|
||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||
)
|
||||
|
||||
tp_group = get_tp_group()
|
||||
ep_group = tp_group.device_group
|
||||
ep_rank = torch.distributed.get_rank()
|
||||
device = torch.device(f"cuda:{ep_rank}")
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.parallel_config.tensor_parallel_size = world_size
|
||||
|
||||
total_physical_experts = world_size * num_local_experts
|
||||
hidden_sizes = [16, 32]
|
||||
with set_current_vllm_config(vllm_config):
|
||||
ensure_model_parallel_initialized(
|
||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||
)
|
||||
|
||||
redundancy_config = create_redundancy_config(
|
||||
num_logical_experts,
|
||||
total_physical_experts,
|
||||
)
|
||||
old_indices = create_expert_indices_with_redundancy(
|
||||
num_layers,
|
||||
num_logical_experts,
|
||||
total_physical_experts,
|
||||
redundancy_config,
|
||||
)
|
||||
tp_group = get_tp_group()
|
||||
ep_group = tp_group.device_group
|
||||
ep_rank = torch.distributed.get_rank()
|
||||
device = torch.device(f"cuda:{ep_rank}")
|
||||
|
||||
new_redundancy_config = create_redundancy_config(
|
||||
num_logical_experts,
|
||||
total_physical_experts,
|
||||
)
|
||||
new_indices = create_expert_indices_with_redundancy(
|
||||
num_layers,
|
||||
num_logical_experts,
|
||||
total_physical_experts,
|
||||
new_redundancy_config,
|
||||
)
|
||||
total_physical_experts = world_size * num_local_experts
|
||||
hidden_sizes = [16, 32]
|
||||
|
||||
expert_weights = create_expert_weights(
|
||||
num_layers,
|
||||
num_local_experts,
|
||||
hidden_sizes,
|
||||
ep_rank,
|
||||
device,
|
||||
old_indices,
|
||||
)
|
||||
old_indices_cpu = old_indices.cpu()
|
||||
new_indices_cpu = new_indices.cpu()
|
||||
redundancy_config = create_redundancy_config(
|
||||
num_logical_experts,
|
||||
total_physical_experts,
|
||||
)
|
||||
old_indices = create_expert_indices_with_redundancy(
|
||||
num_layers,
|
||||
num_logical_experts,
|
||||
total_physical_experts,
|
||||
redundancy_config,
|
||||
)
|
||||
|
||||
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
|
||||
cuda_stream = torch.cuda.Stream(device=device)
|
||||
new_redundancy_config = create_redundancy_config(
|
||||
num_logical_experts,
|
||||
total_physical_experts,
|
||||
)
|
||||
new_indices = create_expert_indices_with_redundancy(
|
||||
num_layers,
|
||||
num_logical_experts,
|
||||
total_physical_experts,
|
||||
new_redundancy_config,
|
||||
)
|
||||
|
||||
for layer_idx in range(num_layers):
|
||||
is_unchanged, is_received_locally, recv_metadata = asyncio.run(
|
||||
transfer_layer(
|
||||
old_layer_indices=old_indices_cpu[layer_idx],
|
||||
new_layer_indices=new_indices_cpu[layer_idx],
|
||||
expert_weights=expert_weights[layer_idx],
|
||||
expert_weights_buffer=expert_buffer,
|
||||
ep_group=ep_group,
|
||||
cuda_stream=cuda_stream,
|
||||
expert_weights = create_expert_weights(
|
||||
num_layers,
|
||||
num_local_experts,
|
||||
hidden_sizes,
|
||||
ep_rank,
|
||||
device,
|
||||
old_indices,
|
||||
)
|
||||
old_indices_cpu = old_indices.cpu()
|
||||
new_indices_cpu = new_indices.cpu()
|
||||
|
||||
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
|
||||
cuda_stream = torch.cuda.Stream(device=device)
|
||||
|
||||
for layer_idx in range(num_layers):
|
||||
is_unchanged, is_received_locally, recv_metadata = asyncio.run(
|
||||
transfer_layer(
|
||||
old_layer_indices=old_indices_cpu[layer_idx],
|
||||
new_layer_indices=new_indices_cpu[layer_idx],
|
||||
expert_weights=expert_weights[layer_idx],
|
||||
expert_weights_buffer=expert_buffer,
|
||||
ep_group=ep_group,
|
||||
cuda_stream=cuda_stream,
|
||||
)
|
||||
)
|
||||
cuda_stream.synchronize()
|
||||
move_from_buffer(
|
||||
expert_weights=expert_weights[layer_idx],
|
||||
expert_weights_buffers=expert_buffer,
|
||||
is_unchanged=is_unchanged,
|
||||
is_received_locally=is_received_locally,
|
||||
recv_metadata=recv_metadata,
|
||||
new_indices=new_indices_cpu[layer_idx].numpy(),
|
||||
ep_rank=ep_rank,
|
||||
)
|
||||
)
|
||||
cuda_stream.synchronize()
|
||||
move_from_buffer(
|
||||
expert_weights=expert_weights[layer_idx],
|
||||
expert_weights_buffers=expert_buffer,
|
||||
is_unchanged=is_unchanged,
|
||||
is_received_locally=is_received_locally,
|
||||
recv_metadata=recv_metadata,
|
||||
new_indices=new_indices_cpu[layer_idx].numpy(),
|
||||
ep_rank=ep_rank,
|
||||
)
|
||||
|
||||
verify_expert_weights_after_shuffle(
|
||||
expert_weights,
|
||||
new_indices,
|
||||
hidden_sizes,
|
||||
ep_rank,
|
||||
num_local_experts,
|
||||
)
|
||||
verify_redundant_experts_have_same_weights(
|
||||
expert_weights,
|
||||
new_indices,
|
||||
hidden_sizes,
|
||||
world_size,
|
||||
num_local_experts,
|
||||
)
|
||||
verify_expert_weights_after_shuffle(
|
||||
expert_weights,
|
||||
new_indices,
|
||||
hidden_sizes,
|
||||
ep_rank,
|
||||
num_local_experts,
|
||||
)
|
||||
verify_redundant_experts_have_same_weights(
|
||||
expert_weights,
|
||||
new_indices,
|
||||
hidden_sizes,
|
||||
world_size,
|
||||
num_local_experts,
|
||||
)
|
||||
|
||||
|
||||
def _test_rearrange_expert_weights_with_redundancy(
|
||||
@@ -336,71 +342,76 @@ def _test_rearrange_expert_weights_with_redundancy(
|
||||
# Initialize model parallel (using tensor parallel as an entrypoint
|
||||
# to expert parallel)
|
||||
set_env_vars_and_device(env)
|
||||
ensure_model_parallel_initialized(
|
||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||
)
|
||||
|
||||
ep_group = get_tp_group().cpu_group
|
||||
ep_rank = torch.distributed.get_rank()
|
||||
device = torch.device(f"cuda:{ep_rank}")
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.parallel_config.tensor_parallel_size = world_size
|
||||
|
||||
# Test parameters
|
||||
total_physical_experts = world_size * num_local_experts
|
||||
hidden_sizes = [32, 64] # Two different weight matrices
|
||||
with set_current_vllm_config(vllm_config):
|
||||
ensure_model_parallel_initialized(
|
||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||
)
|
||||
|
||||
# Create old expert indices (with redundancy)
|
||||
redundancy_config = create_redundancy_config(
|
||||
num_logical_experts, total_physical_experts
|
||||
)
|
||||
ep_group = get_tp_group().cpu_group
|
||||
ep_rank = torch.distributed.get_rank()
|
||||
device = torch.device(f"cuda:{ep_rank}")
|
||||
|
||||
old_indices = create_expert_indices_with_redundancy(
|
||||
num_layers,
|
||||
num_logical_experts,
|
||||
total_physical_experts,
|
||||
redundancy_config,
|
||||
)
|
||||
# Test parameters
|
||||
total_physical_experts = world_size * num_local_experts
|
||||
hidden_sizes = [32, 64] # Two different weight matrices
|
||||
|
||||
# Create new expert indices (with redundancy)
|
||||
new_redundancy_config = create_redundancy_config(
|
||||
num_logical_experts, total_physical_experts
|
||||
)
|
||||
new_indices = create_expert_indices_with_redundancy(
|
||||
num_layers,
|
||||
num_logical_experts,
|
||||
total_physical_experts,
|
||||
new_redundancy_config,
|
||||
)
|
||||
# Create old expert indices (with redundancy)
|
||||
redundancy_config = create_redundancy_config(
|
||||
num_logical_experts, total_physical_experts
|
||||
)
|
||||
|
||||
# Create expert weights
|
||||
expert_weights = create_expert_weights(
|
||||
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
|
||||
)
|
||||
old_indices = create_expert_indices_with_redundancy(
|
||||
num_layers,
|
||||
num_logical_experts,
|
||||
total_physical_experts,
|
||||
redundancy_config,
|
||||
)
|
||||
|
||||
# Execute weight rearrangement
|
||||
rearrange_expert_weights_inplace(
|
||||
old_indices,
|
||||
new_indices,
|
||||
expert_weights,
|
||||
ep_group,
|
||||
is_profile=False,
|
||||
)
|
||||
# Create new expert indices (with redundancy)
|
||||
new_redundancy_config = create_redundancy_config(
|
||||
num_logical_experts, total_physical_experts
|
||||
)
|
||||
new_indices = create_expert_indices_with_redundancy(
|
||||
num_layers,
|
||||
num_logical_experts,
|
||||
total_physical_experts,
|
||||
new_redundancy_config,
|
||||
)
|
||||
|
||||
# Verify the rearrangement result
|
||||
verify_expert_weights_after_shuffle(
|
||||
expert_weights,
|
||||
new_indices,
|
||||
hidden_sizes,
|
||||
ep_rank,
|
||||
num_local_experts,
|
||||
)
|
||||
# Create expert weights
|
||||
expert_weights = create_expert_weights(
|
||||
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
|
||||
)
|
||||
|
||||
verify_redundant_experts_have_same_weights(
|
||||
expert_weights,
|
||||
new_indices,
|
||||
hidden_sizes,
|
||||
world_size,
|
||||
num_local_experts,
|
||||
)
|
||||
# Execute weight rearrangement
|
||||
rearrange_expert_weights_inplace(
|
||||
old_indices,
|
||||
new_indices,
|
||||
expert_weights,
|
||||
ep_group,
|
||||
is_profile=False,
|
||||
)
|
||||
|
||||
# Verify the rearrangement result
|
||||
verify_expert_weights_after_shuffle(
|
||||
expert_weights,
|
||||
new_indices,
|
||||
hidden_sizes,
|
||||
ep_rank,
|
||||
num_local_experts,
|
||||
)
|
||||
|
||||
verify_redundant_experts_have_same_weights(
|
||||
expert_weights,
|
||||
new_indices,
|
||||
hidden_sizes,
|
||||
world_size,
|
||||
num_local_experts,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -444,58 +455,63 @@ def test_rearrange_expert_weights_with_redundancy(
|
||||
|
||||
def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
|
||||
set_env_vars_and_device(env)
|
||||
ensure_model_parallel_initialized(
|
||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||
)
|
||||
|
||||
ep_group = get_tp_group().cpu_group
|
||||
ep_rank = torch.distributed.get_rank()
|
||||
device = torch.device(f"cuda:{ep_rank}")
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.parallel_config.tensor_parallel_size = world_size
|
||||
|
||||
num_layers = 2
|
||||
num_local_experts = 2
|
||||
total_physical_experts = world_size * num_local_experts
|
||||
num_logical_experts = total_physical_experts // 2 # Some redundancy
|
||||
hidden_sizes = [32, 64]
|
||||
with set_current_vllm_config(vllm_config):
|
||||
ensure_model_parallel_initialized(
|
||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||
)
|
||||
|
||||
# Create redundancy configuration
|
||||
redundancy_config = [2] * num_logical_experts
|
||||
ep_group = get_tp_group().cpu_group
|
||||
ep_rank = torch.distributed.get_rank()
|
||||
device = torch.device(f"cuda:{ep_rank}")
|
||||
|
||||
# Same indices - no change
|
||||
indices = create_expert_indices_with_redundancy(
|
||||
num_layers, num_logical_experts, total_physical_experts, redundancy_config
|
||||
)
|
||||
num_layers = 2
|
||||
num_local_experts = 2
|
||||
total_physical_experts = world_size * num_local_experts
|
||||
num_logical_experts = total_physical_experts // 2 # Some redundancy
|
||||
hidden_sizes = [32, 64]
|
||||
|
||||
expert_weights = create_expert_weights(
|
||||
num_layers, num_local_experts, hidden_sizes, ep_rank, device, indices
|
||||
)
|
||||
# Create redundancy configuration
|
||||
redundancy_config = [2] * num_logical_experts
|
||||
|
||||
# Save original weights
|
||||
original_weights = []
|
||||
for layer_weights in expert_weights:
|
||||
layer_copy = []
|
||||
for weight in layer_weights:
|
||||
layer_copy.append(weight.clone())
|
||||
original_weights.append(layer_copy)
|
||||
# Same indices - no change
|
||||
indices = create_expert_indices_with_redundancy(
|
||||
num_layers, num_logical_experts, total_physical_experts, redundancy_config
|
||||
)
|
||||
|
||||
# Execute rearrangement (should be no change)
|
||||
rearrange_expert_weights_inplace(
|
||||
indices,
|
||||
indices, # Same indices
|
||||
expert_weights,
|
||||
ep_group,
|
||||
is_profile=False,
|
||||
)
|
||||
expert_weights = create_expert_weights(
|
||||
num_layers, num_local_experts, hidden_sizes, ep_rank, device, indices
|
||||
)
|
||||
|
||||
# Verify that the weights have not changed
|
||||
for layer in range(num_layers):
|
||||
for weight_idx in range(len(hidden_sizes)):
|
||||
torch.testing.assert_close(
|
||||
expert_weights[layer][weight_idx],
|
||||
original_weights[layer][weight_idx],
|
||||
msg=f"""Layer {layer}, weight {weight_idx}
|
||||
# Save original weights
|
||||
original_weights = []
|
||||
for layer_weights in expert_weights:
|
||||
layer_copy = []
|
||||
for weight in layer_weights:
|
||||
layer_copy.append(weight.clone())
|
||||
original_weights.append(layer_copy)
|
||||
|
||||
# Execute rearrangement (should be no change)
|
||||
rearrange_expert_weights_inplace(
|
||||
indices,
|
||||
indices, # Same indices
|
||||
expert_weights,
|
||||
ep_group,
|
||||
is_profile=False,
|
||||
)
|
||||
|
||||
# Verify that the weights have not changed
|
||||
for layer in range(num_layers):
|
||||
for weight_idx in range(len(hidden_sizes)):
|
||||
torch.testing.assert_close(
|
||||
expert_weights[layer][weight_idx],
|
||||
original_weights[layer][weight_idx],
|
||||
msg=f"""Layer {layer}, weight {weight_idx}
|
||||
should remain unchanged""",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -538,64 +554,69 @@ def test_rearrange_expert_weights_no_change(world_size):
|
||||
|
||||
def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
|
||||
set_env_vars_and_device(env)
|
||||
ensure_model_parallel_initialized(
|
||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||
)
|
||||
|
||||
ep_group = get_tp_group().cpu_group
|
||||
ep_rank = torch.distributed.get_rank()
|
||||
device = torch.device(f"cuda:{ep_rank}")
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.parallel_config.tensor_parallel_size = world_size
|
||||
|
||||
num_layers = 1
|
||||
num_local_experts = 2
|
||||
total_physical_experts = world_size * num_local_experts
|
||||
num_logical_experts = total_physical_experts // 2
|
||||
hidden_sizes = [32]
|
||||
with set_current_vllm_config(vllm_config):
|
||||
ensure_model_parallel_initialized(
|
||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||
)
|
||||
|
||||
# Create different index distributions
|
||||
old_redundancy = create_redundancy_config(
|
||||
num_logical_experts, total_physical_experts
|
||||
)
|
||||
new_redundancy = create_redundancy_config(
|
||||
num_logical_experts, total_physical_experts
|
||||
)
|
||||
ep_group = get_tp_group().cpu_group
|
||||
ep_rank = torch.distributed.get_rank()
|
||||
device = torch.device(f"cuda:{ep_rank}")
|
||||
|
||||
old_indices = create_expert_indices_with_redundancy(
|
||||
num_layers, num_logical_experts, total_physical_experts, old_redundancy
|
||||
)
|
||||
new_indices = create_expert_indices_with_redundancy(
|
||||
num_layers, num_logical_experts, total_physical_experts, new_redundancy
|
||||
)
|
||||
num_layers = 1
|
||||
num_local_experts = 2
|
||||
total_physical_experts = world_size * num_local_experts
|
||||
num_logical_experts = total_physical_experts // 2
|
||||
hidden_sizes = [32]
|
||||
|
||||
expert_weights = create_expert_weights(
|
||||
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
|
||||
)
|
||||
# Create different index distributions
|
||||
old_redundancy = create_redundancy_config(
|
||||
num_logical_experts, total_physical_experts
|
||||
)
|
||||
new_redundancy = create_redundancy_config(
|
||||
num_logical_experts, total_physical_experts
|
||||
)
|
||||
|
||||
# Save original weights
|
||||
original_weights = []
|
||||
for layer_weights in expert_weights:
|
||||
layer_copy = []
|
||||
for weight in layer_weights:
|
||||
layer_copy.append(weight.clone())
|
||||
original_weights.append(layer_copy)
|
||||
old_indices = create_expert_indices_with_redundancy(
|
||||
num_layers, num_logical_experts, total_physical_experts, old_redundancy
|
||||
)
|
||||
new_indices = create_expert_indices_with_redundancy(
|
||||
num_layers, num_logical_experts, total_physical_experts, new_redundancy
|
||||
)
|
||||
|
||||
# Execute profile mode rearrangement
|
||||
rearrange_expert_weights_inplace(
|
||||
old_indices,
|
||||
new_indices,
|
||||
expert_weights,
|
||||
ep_group,
|
||||
is_profile=True, # Profile mode
|
||||
)
|
||||
expert_weights = create_expert_weights(
|
||||
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
|
||||
)
|
||||
|
||||
# In profile mode, the weights should remain unchanged
|
||||
for layer in range(num_layers):
|
||||
for weight_idx in range(len(hidden_sizes)):
|
||||
torch.testing.assert_close(
|
||||
expert_weights[layer][weight_idx],
|
||||
original_weights[layer][weight_idx],
|
||||
msg="In profile mode, the weights should remain unchanged",
|
||||
)
|
||||
# Save original weights
|
||||
original_weights = []
|
||||
for layer_weights in expert_weights:
|
||||
layer_copy = []
|
||||
for weight in layer_weights:
|
||||
layer_copy.append(weight.clone())
|
||||
original_weights.append(layer_copy)
|
||||
|
||||
# Execute profile mode rearrangement
|
||||
rearrange_expert_weights_inplace(
|
||||
old_indices,
|
||||
new_indices,
|
||||
expert_weights,
|
||||
ep_group,
|
||||
is_profile=True, # Profile mode
|
||||
)
|
||||
|
||||
# In profile mode, the weights should remain unchanged
|
||||
for layer in range(num_layers):
|
||||
for weight_idx in range(len(hidden_sizes)):
|
||||
torch.testing.assert_close(
|
||||
expert_weights[layer][weight_idx],
|
||||
original_weights[layer][weight_idx],
|
||||
msg="In profile mode, the weights should remain unchanged",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_size", [2, 4])
|
||||
|
||||
@@ -10,6 +10,7 @@ import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import vllm.envs as envs
|
||||
from tests.utils import ensure_current_vllm_config
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
|
||||
from vllm.distributed.device_communicators.pynccl import register_nccl_symmetric_ops
|
||||
@@ -51,7 +52,8 @@ def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int):
|
||||
)
|
||||
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
with ensure_current_vllm_config():
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
cuda_communicator = typing.cast(
|
||||
CudaCommunicator, get_tp_group().device_communicator
|
||||
|
||||
@@ -9,6 +9,7 @@ import pytest
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from tests.utils import ensure_current_vllm_config
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
|
||||
@@ -112,7 +113,8 @@ def test_pynccl_multiple_allreduce():
|
||||
@worker_fn_wrapper
|
||||
def multiple_allreduce_with_vllm_worker_fn():
|
||||
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
|
||||
ensure_model_parallel_initialized(2, 2)
|
||||
with ensure_current_vllm_config():
|
||||
ensure_model_parallel_initialized(2, 2)
|
||||
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
|
||||
with graph_capture(device=device):
|
||||
# two tp groups can communicate independently
|
||||
|
||||
@@ -6,7 +6,7 @@ import unittest
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.utils import multi_gpu_test
|
||||
from tests.utils import ensure_current_vllm_config, multi_gpu_test
|
||||
from vllm.distributed.parallel_state import (
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
@@ -87,7 +87,8 @@ def mixer2_gated_norm_tensor_parallel(
|
||||
|
||||
# initialize distributed
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
with ensure_current_vllm_config():
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# create random weights an inputs
|
||||
weight = torch.rand((hidden_size,), dtype=dtype, device=device)
|
||||
|
||||
+12
-9
@@ -45,21 +45,24 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
|
||||
|
||||
@pytest.fixture
|
||||
def dist_init():
|
||||
from tests.utils import ensure_current_vllm_config
|
||||
|
||||
temp_file = tempfile.mkstemp()[1]
|
||||
|
||||
backend = "nccl"
|
||||
if current_platform.is_cpu() or current_platform.is_tpu():
|
||||
backend = "gloo"
|
||||
|
||||
init_distributed_environment(
|
||||
world_size=1,
|
||||
rank=0,
|
||||
distributed_init_method=f"file://{temp_file}",
|
||||
local_rank=0,
|
||||
backend=backend,
|
||||
)
|
||||
initialize_model_parallel(1, 1)
|
||||
yield
|
||||
with ensure_current_vllm_config():
|
||||
init_distributed_environment(
|
||||
world_size=1,
|
||||
rank=0,
|
||||
distributed_init_method=f"file://{temp_file}",
|
||||
local_rank=0,
|
||||
backend=backend,
|
||||
)
|
||||
initialize_model_parallel(1, 1)
|
||||
yield
|
||||
cleanup_dist_env_and_memory(shutdown_ray=True)
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import random
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.utils import multi_gpu_test
|
||||
from tests.utils import ensure_current_vllm_config, multi_gpu_test
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed import (
|
||||
init_distributed_environment,
|
||||
@@ -631,7 +631,8 @@ def use_fused_moe_lora_kernel_tensor_parallel(
|
||||
local_rank=local_rank,
|
||||
distributed_init_method=init_method,
|
||||
)
|
||||
initialize_model_parallel(world_size, 1)
|
||||
with ensure_current_vllm_config():
|
||||
initialize_model_parallel(world_size, 1)
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
input_dim = K if column_parallel else N
|
||||
|
||||
@@ -13,6 +13,7 @@ from vllm.config import (
|
||||
ParallelConfig,
|
||||
SchedulerConfig,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.config.lora import LoRAConfig
|
||||
@@ -77,8 +78,9 @@ def test_worker_apply_lora(qwen3_lora_files):
|
||||
distributed_init_method=f"file://{tempfile.mkstemp()[1]}",
|
||||
)
|
||||
|
||||
worker.init_device()
|
||||
worker.load_model()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
worker.init_device()
|
||||
worker.load_model()
|
||||
|
||||
set_active_loras(worker, [])
|
||||
assert worker.list_loras() == set()
|
||||
|
||||
@@ -6,7 +6,7 @@ import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from tests.utils import multi_gpu_test
|
||||
from tests.utils import ensure_current_vllm_config, multi_gpu_test
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import (
|
||||
init_distributed_environment,
|
||||
@@ -117,7 +117,8 @@ def run_dp_sharded_vision_model_vs_direct(
|
||||
|
||||
# initialize distributed
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
with ensure_current_vllm_config():
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# Create a test input tensor
|
||||
image_input = torch.randn(batch_size, 3, 224, 224)
|
||||
@@ -302,7 +303,8 @@ def run_dp_sharded_mrope_vision_model_vs_direct(
|
||||
|
||||
# initialize distributed
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
with ensure_current_vllm_config():
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# Create test data
|
||||
grid_thw_list = []
|
||||
@@ -377,7 +379,8 @@ def run_dp_sharded_mrope_vision_model_empty_input_worker(
|
||||
)
|
||||
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
with ensure_current_vllm_config():
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# Create empty inputs
|
||||
pixel_values = torch.empty((0, 768))
|
||||
@@ -425,7 +428,8 @@ def run_dp_sharded_mrope_vision_model_uneven_load_worker(
|
||||
)
|
||||
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
with ensure_current_vllm_config():
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# Create images with very different sizes
|
||||
grid_thw_list = [
|
||||
|
||||
+32
-1
@@ -895,6 +895,36 @@ def compare_all_settings(
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def ensure_current_vllm_config():
|
||||
"""Ensures a vllm config is set for the duration of the context.
|
||||
|
||||
If a config is already set, this is a no-op. Otherwise, it creates a default
|
||||
VllmConfig and sets it for the duration of the context.
|
||||
|
||||
Used for tests that call functions which require a vllm config but don't
|
||||
need a specific config.
|
||||
|
||||
Example:
|
||||
with ensure_current_vllm_config():
|
||||
init_distributed_environment(...)
|
||||
ensure_model_parallel_initialized(...)
|
||||
"""
|
||||
from vllm.config import (
|
||||
VllmConfig,
|
||||
get_current_vllm_config_or_none,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
|
||||
if get_current_vllm_config_or_none() is not None:
|
||||
# Config already set, just yield
|
||||
yield
|
||||
else:
|
||||
# No config set, create a default one for the duration
|
||||
with set_current_vllm_config(VllmConfig()):
|
||||
yield
|
||||
|
||||
|
||||
def init_test_distributed_environment(
|
||||
tp_size: int,
|
||||
pp_size: int,
|
||||
@@ -921,6 +951,7 @@ def init_test_distributed_environment(
|
||||
distributed_init_method=distributed_init_method,
|
||||
local_rank=local_rank,
|
||||
)
|
||||
ensure_model_parallel_initialized(tp_size, pp_size)
|
||||
else:
|
||||
# No config set, create a default one for the test
|
||||
with set_current_vllm_config(VllmConfig()):
|
||||
@@ -930,7 +961,7 @@ def init_test_distributed_environment(
|
||||
distributed_init_method=distributed_init_method,
|
||||
local_rank=local_rank,
|
||||
)
|
||||
ensure_model_parallel_initialized(tp_size, pp_size)
|
||||
ensure_model_parallel_initialized(tp_size, pp_size)
|
||||
|
||||
|
||||
def multi_process_parallel(
|
||||
|
||||
@@ -789,8 +789,11 @@ def test_hybrid_attention_mamba_tensor_shapes():
|
||||
"MASTER_PORT": "12345",
|
||||
}
|
||||
)
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=1)
|
||||
from tests.utils import ensure_current_vllm_config
|
||||
|
||||
with ensure_current_vllm_config():
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=1)
|
||||
torch.set_default_dtype(torch.float16)
|
||||
|
||||
model_config = ModelConfig(
|
||||
|
||||
@@ -10,6 +10,7 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import set_current_vllm_config
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.utils.mem_utils import MemorySnapshot
|
||||
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
|
||||
@@ -95,7 +96,12 @@ def worker_process(
|
||||
side_effect=make_operation_tracker("nccl_all_reduce", original_all_reduce),
|
||||
)
|
||||
|
||||
with init_patch, memory_patch, all_reduce_patch:
|
||||
with (
|
||||
init_patch,
|
||||
memory_patch,
|
||||
all_reduce_patch,
|
||||
set_current_vllm_config(vllm_config),
|
||||
):
|
||||
# Initialize device (this is where we test the order)
|
||||
worker.init_device()
|
||||
|
||||
|
||||
@@ -319,3 +319,52 @@ class TorchCompileWithNoGuardsWrapper:
|
||||
yield
|
||||
finally:
|
||||
self.__class__.forward.__code__ = original
|
||||
|
||||
|
||||
def reset_compile_wrapper(model: torch.nn.Module) -> None:
|
||||
"""
|
||||
Clean up compiled model and captured CUDA graphs for elastic EP.
|
||||
"""
|
||||
if not isinstance(model, TorchCompileWithNoGuardsWrapper) and hasattr(
|
||||
model, "model"
|
||||
):
|
||||
model = model.model
|
||||
if not isinstance(model, TorchCompileWithNoGuardsWrapper):
|
||||
return
|
||||
# model.do_not_compile is set by the @support_torch_compile decorator
|
||||
if hasattr(model, "do_not_compile") and model.do_not_compile:
|
||||
return
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
|
||||
# reset the compilation counter
|
||||
compilation_counter.num_models_seen = 0
|
||||
compilation_counter.num_graphs_seen = 0
|
||||
compilation_counter.num_piecewise_graphs_seen = 0
|
||||
compilation_counter.num_piecewise_capturable_graphs_seen = 0
|
||||
compilation_counter.num_backend_compilations = 0
|
||||
compilation_counter.num_gpu_runner_capture_triggers = 0
|
||||
compilation_counter.num_cudagraph_captured = 0
|
||||
compilation_counter.num_inductor_compiles = 0
|
||||
compilation_counter.num_eager_compiles = 0
|
||||
compilation_counter.num_cache_entries_updated = 0
|
||||
compilation_counter.num_compiled_artifacts_saved = 0
|
||||
compilation_counter.stock_torch_compile_count = 0
|
||||
|
||||
# Clear the AOT compiled function so the model is forced to
|
||||
# recompile on the next call. Without this, decorators.py
|
||||
# __call__ uses the stale aot_compiled_fn whose torchinductor
|
||||
# kernels have old parameters (expert_map size for example)
|
||||
# baked in as compile-time constants.
|
||||
if hasattr(model, "aot_compiled_fn"):
|
||||
model.aot_compiled_fn = None
|
||||
if hasattr(model, "was_aot_compile_fn_loaded_from_disk"):
|
||||
model.was_aot_compile_fn_loaded_from_disk = False
|
||||
|
||||
# Reset the cache_dir so VllmBackend recomputes the hash
|
||||
# (data_parallel_size changed, so the config hash differs).
|
||||
compilation_config = model.vllm_config.compilation_config
|
||||
compilation_config.cache_dir = ""
|
||||
compilation_config.local_cache_dir = ""
|
||||
|
||||
model.__class__.forward.__code__ = model.original_code_object()
|
||||
TorchCompileWithNoGuardsWrapper.__init__(model)
|
||||
|
||||
+116
-6
@@ -165,6 +165,9 @@ class ParallelConfig:
|
||||
disable_custom_all_reduce: bool = False
|
||||
"""Disable the custom all-reduce kernel and fall back to NCCL."""
|
||||
|
||||
enable_elastic_ep: bool = False
|
||||
"""Enable elastic expert parallelism with stateless NCCL groups for DP/EP."""
|
||||
|
||||
enable_dbo: bool = False
|
||||
"""Enable dual batch overlap for the model executor."""
|
||||
ubatch_size: int = 0
|
||||
@@ -244,6 +247,34 @@ class ParallelConfig:
|
||||
Set to be private as it's not intended to be configured by users.
|
||||
"""
|
||||
|
||||
_stateless_dp_group_port_list: list[list[int]] = Field(default_factory=list)
|
||||
"""List of open ports for stateless DP groups when enable_elastic_ep is True.
|
||||
Set to be private as it's not intended to be configured by users.
|
||||
It is a list of list[int], with each inner list contains a set of 3 ports
|
||||
to be used for setting up the stateless CPU/device/TCPStore groups
|
||||
in StatelessGroupCoordinator. The number of inner lists is equal to
|
||||
the number of DP groups,
|
||||
i.e., len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size,
|
||||
and len(self._stateless_dp_group_port_list[i]) == 3 for all i.
|
||||
"""
|
||||
|
||||
_stateless_ep_group_port_list: list[list[int]] = Field(default_factory=list)
|
||||
"""List of open ports for stateless EP groups when enable_elastic_ep is True.
|
||||
Set to be private as it's not intended to be configured by users.
|
||||
len(self._stateless_ep_group_port_list) == world_size_across_dp // ep_size,
|
||||
"""
|
||||
|
||||
_stateless_eplb_group_port_list: list[list[int]] = Field(default_factory=list)
|
||||
"""List of open ports for stateless EPLB groups when enable_elastic_ep is True.
|
||||
Same topology as EP but separate NCCL communicator to avoid deadlocks.
|
||||
"""
|
||||
|
||||
_stateless_world_group_port_list: list[list[int]] = Field(default_factory=list)
|
||||
"""List of open ports for stateless world group when enable_elastic_ep is True.
|
||||
Set to be private as it's not intended to be configured by users.
|
||||
len(self._stateless_world_group_port_list) == 1,
|
||||
"""
|
||||
|
||||
decode_context_parallel_size: int = 1
|
||||
"""Number of decode context parallel groups, because the world size does
|
||||
not change by dcp, it simply reuse the GPUs of TP group, and tp_size
|
||||
@@ -402,7 +433,67 @@ class ParallelConfig:
|
||||
|
||||
return answer
|
||||
|
||||
def stateless_init_dp_group(self) -> ProcessGroup:
|
||||
def allocate_elastic_ep_ports(self) -> None:
|
||||
"""Allocate all ports for elastic EP (stateless groups + DP master).
|
||||
|
||||
Must be called AFTER ray.init() so that ports claimed by Ray's
|
||||
idle worker pool are already in use and won't be returned by
|
||||
get_open_ports_list().
|
||||
"""
|
||||
if not self.enable_elastic_ep:
|
||||
return
|
||||
if self._stateless_world_group_port_list:
|
||||
return
|
||||
|
||||
num_world_groups = 1
|
||||
dp_size = self.data_parallel_size
|
||||
ep_size = self.data_parallel_size * self.world_size_across_dp
|
||||
num_dp_groups = max(1, self.world_size_across_dp // dp_size)
|
||||
num_ep_groups = max(1, self.world_size_across_dp // ep_size)
|
||||
num_eplb_groups = num_ep_groups
|
||||
total_stateless_ports = (
|
||||
num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups
|
||||
) * 3
|
||||
num_dp_master_ports = 5
|
||||
|
||||
all_ports = get_open_ports_list(total_stateless_ports + num_dp_master_ports)
|
||||
|
||||
self._data_parallel_master_port_list = all_ports[-num_dp_master_ports:]
|
||||
self.data_parallel_master_port = self._data_parallel_master_port_list.pop()
|
||||
all_ports = all_ports[:-num_dp_master_ports]
|
||||
|
||||
self._stateless_world_group_port_list = [
|
||||
all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)
|
||||
]
|
||||
start_idx = num_world_groups * 3
|
||||
self._stateless_dp_group_port_list = [
|
||||
all_ports[i : i + 3]
|
||||
for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
|
||||
]
|
||||
start_idx += num_dp_groups * 3
|
||||
self._stateless_ep_group_port_list = [
|
||||
all_ports[i : i + 3]
|
||||
for i in range(start_idx, start_idx + num_ep_groups * 3, 3)
|
||||
]
|
||||
start_idx += num_ep_groups * 3
|
||||
self._stateless_eplb_group_port_list = [
|
||||
all_ports[i : i + 3]
|
||||
for i in range(start_idx, start_idx + num_eplb_groups * 3, 3)
|
||||
]
|
||||
|
||||
def get_next_stateless_world_group_port(self) -> list[int]:
|
||||
return self._stateless_world_group_port_list.pop()
|
||||
|
||||
def get_next_stateless_dp_group_port(self) -> list[int]:
|
||||
return self._stateless_dp_group_port_list.pop()
|
||||
|
||||
def get_next_stateless_ep_group_port(self) -> list[int]:
|
||||
return self._stateless_ep_group_port_list.pop()
|
||||
|
||||
def get_next_stateless_eplb_group_port(self) -> list[int]:
|
||||
return self._stateless_eplb_group_port_list.pop()
|
||||
|
||||
def stateless_init_dp_group(self, return_store: bool = False) -> ProcessGroup:
|
||||
# NOTE: In high-concurrency scenarios multiple processes
|
||||
# can pick the same (currently free) port through a race
|
||||
# condition when calling `get_open_port()`. When the first
|
||||
@@ -426,7 +517,8 @@ class ParallelConfig:
|
||||
self.get_next_dp_init_port(),
|
||||
self.data_parallel_rank,
|
||||
self.data_parallel_size,
|
||||
backend=current_platform.dist_backend,
|
||||
backend="gloo",
|
||||
return_store=return_store,
|
||||
)
|
||||
except DistNetworkError as e:
|
||||
# We only want to retry when the root cause is EADDRINUSE.
|
||||
@@ -561,6 +653,21 @@ class ParallelConfig:
|
||||
logger.info("Using external launcher for distributed inference.")
|
||||
self.world_size *= self.data_parallel_size
|
||||
|
||||
if self.enable_elastic_ep:
|
||||
if not self.enable_eplb:
|
||||
raise ValueError("Elastic EP is only supported with enable_eplb=True.")
|
||||
if self.pipeline_parallel_size > 1:
|
||||
raise ValueError(
|
||||
"Elastic EP is not supported with pipeline parallelism "
|
||||
f"(pipeline_parallel_size={self.pipeline_parallel_size})."
|
||||
)
|
||||
if self.data_parallel_external_lb or self.data_parallel_hybrid_lb:
|
||||
raise NotImplementedError(
|
||||
"Elastic EP is not compatible with data_parallel_external_lb "
|
||||
"or data_parallel_hybrid_lb. Elastic EP relies on a single API "
|
||||
"server and core client to coordinate scale up/down."
|
||||
)
|
||||
|
||||
if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
|
||||
# Data parallel was specified in the engine args.
|
||||
if self.distributed_executor_backend == "external_launcher":
|
||||
@@ -573,9 +680,12 @@ class ParallelConfig:
|
||||
"Set data_parallel_rank to %d automatically.",
|
||||
self.data_parallel_rank,
|
||||
)
|
||||
if not self._data_parallel_master_port_list:
|
||||
self._data_parallel_master_port_list = get_open_ports_list(5)
|
||||
self.data_parallel_master_port = self._data_parallel_master_port_list.pop()
|
||||
if not self.enable_elastic_ep:
|
||||
if not self._data_parallel_master_port_list:
|
||||
self._data_parallel_master_port_list = get_open_ports_list(5)
|
||||
self.data_parallel_master_port = (
|
||||
self._data_parallel_master_port_list.pop()
|
||||
)
|
||||
|
||||
if not (0 <= self.data_parallel_rank < self.data_parallel_size):
|
||||
raise ValueError(
|
||||
@@ -602,7 +712,7 @@ class ParallelConfig:
|
||||
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
|
||||
logger.info("Disabling V1 multiprocessing for external launcher.")
|
||||
|
||||
if self.distributed_executor_backend is None and self.world_size > 1:
|
||||
if self.distributed_executor_backend is None and self.world_size_across_dp > 1:
|
||||
# We use multiprocessing by default if world_size fits on the
|
||||
# current node and we aren't in a ray placement group.
|
||||
|
||||
|
||||
@@ -31,8 +31,8 @@ class NaiveAll2AllManager(All2AllManagerBase):
|
||||
debugging.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
|
||||
def naive_multicast(
|
||||
self,
|
||||
@@ -138,8 +138,8 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
||||
all-gather (dispatch) and reduce-scatter (combine).
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
@@ -239,12 +239,12 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
All2All communication based on DeepEP High-Throughput kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
assert has_deep_ep(), (
|
||||
"DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
|
||||
" to install DeepEP kernels."
|
||||
) # noqa
|
||||
super().__init__(cpu_group)
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
self.handle_cache = Cache()
|
||||
|
||||
# This is the DeepEP default. Stick to it till we can establish
|
||||
@@ -282,7 +282,10 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
with self.handle_cache._lock:
|
||||
for _, handle in self.handle_cache._cache.items():
|
||||
handle.destroy()
|
||||
self.handle_cache._cache.clear()
|
||||
|
||||
|
||||
class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
@@ -290,8 +293,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
All2All communication based on DeepEP High-Throughput kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
|
||||
def _make_all2all_kwargs(self) -> dict[Any, Any]:
|
||||
# Defaults for internode and intranode are taken from DeepEP tests.
|
||||
@@ -314,6 +317,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=False,
|
||||
num_qps_per_rank=num_qps_per_rank,
|
||||
explicitly_destroy=True,
|
||||
)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
@@ -347,8 +351,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
All2All communication based on DeepEP Low-Latency kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
|
||||
def _make_all2all_kwargs(
|
||||
self,
|
||||
@@ -387,6 +391,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
num_qps_per_rank=num_qps_per_rank,
|
||||
allow_nvlink_for_low_latency_mode=True,
|
||||
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
|
||||
explicitly_destroy=True,
|
||||
)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
@@ -418,11 +423,11 @@ class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
rank: int
|
||||
world_size: int
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
assert has_flashinfer_all2all(), (
|
||||
"flashinfer all2all module not found. Please install/check flashinfer"
|
||||
) # noqa
|
||||
super().__init__(cpu_group)
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
logger.debug(
|
||||
"Initialize for flashinfer All2All rank=%d, world size=%d",
|
||||
self.rank,
|
||||
|
||||
@@ -29,8 +29,9 @@ class All2AllManagerBase:
|
||||
rank: int
|
||||
world_size: int
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
self.cpu_group = cpu_group
|
||||
self.tcp_store_group = tcp_store_group
|
||||
|
||||
# compute some common properties
|
||||
from vllm.distributed.parallel_state import (
|
||||
@@ -47,12 +48,17 @@ class All2AllManagerBase:
|
||||
# when we create this object
|
||||
self.dp_rank = self.dp_group.rank_in_group
|
||||
self.dp_world_size = self.dp_group.world_size
|
||||
self.rank = dist.get_rank(cpu_group)
|
||||
self.world_size = dist.get_world_size(cpu_group)
|
||||
self.rank = cpu_group.rank()
|
||||
self.world_size = cpu_group.size()
|
||||
|
||||
# all2all communication often has separate implementations for
|
||||
# intra-node and inter-node communication
|
||||
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
|
||||
if tcp_store_group is None:
|
||||
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
|
||||
else:
|
||||
self.internode = not all(
|
||||
in_the_same_node_as(tcp_store_group, source_rank=0)
|
||||
)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
# get a handle for the all2all communication,
|
||||
@@ -121,17 +127,36 @@ class DeviceCommunicatorBase:
|
||||
device: torch.device | None = None,
|
||||
device_group: ProcessGroup | None = None,
|
||||
unique_name: str = "",
|
||||
global_ranks: list[int] | None = None,
|
||||
global_world_size: int | None = None,
|
||||
):
|
||||
self.device = device or torch.device("cpu")
|
||||
self.cpu_group = cpu_group
|
||||
self.device_group = device_group
|
||||
self.unique_name = unique_name
|
||||
self.rank = dist.get_rank(cpu_group)
|
||||
self.world_size = dist.get_world_size(cpu_group)
|
||||
self.ranks = dist.get_process_group_ranks(cpu_group)
|
||||
self.global_rank = dist.get_rank()
|
||||
self.global_world_size = dist.get_world_size()
|
||||
self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
|
||||
|
||||
# Check if this is a stateless process group
|
||||
from torch.distributed.distributed_c10d import _world
|
||||
|
||||
is_stateless = _world.pg_map.get(cpu_group, None) is None
|
||||
|
||||
if is_stateless:
|
||||
# For stateless groups, we can't use torch.distributed methods
|
||||
self.rank = cpu_group.rank()
|
||||
self.world_size = cpu_group.size()
|
||||
assert global_ranks is not None
|
||||
assert global_world_size is not None
|
||||
self.ranks = global_ranks
|
||||
self.global_rank = self.ranks[self.rank]
|
||||
self.global_world_size = global_world_size
|
||||
self.rank_in_group = self.rank
|
||||
else:
|
||||
self.rank = dist.get_rank(cpu_group)
|
||||
self.world_size = dist.get_world_size(cpu_group)
|
||||
self.ranks = dist.get_process_group_ranks(cpu_group)
|
||||
self.global_rank = dist.get_rank()
|
||||
self.global_world_size = dist.get_world_size()
|
||||
self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
|
||||
|
||||
use_ep = False
|
||||
all2all_backend = None
|
||||
@@ -145,7 +170,7 @@ class DeviceCommunicatorBase:
|
||||
use_ep = config.parallel_config.data_parallel_size > 1
|
||||
all2all_backend = config.parallel_config.all2all_backend
|
||||
|
||||
self.is_ep_communicator = "ep" in unique_name
|
||||
self.is_ep_communicator = unique_name.split(":")[0] == "ep"
|
||||
self.use_all2all = self.is_ep_communicator and use_ep
|
||||
self.all2all_backend = all2all_backend
|
||||
self.all2all_manager: All2AllManagerBase | None = None
|
||||
@@ -275,6 +300,13 @@ class DeviceCommunicatorBase:
|
||||
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
|
||||
return tensor
|
||||
|
||||
def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
|
||||
"""Broadcast a tensor from source rank to all ranks."""
|
||||
if self.world_size == 1:
|
||||
return tensor
|
||||
torch.distributed.broadcast(tensor, self.ranks[src], self.device_group)
|
||||
return tensor
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
@@ -343,3 +375,6 @@ class DeviceCommunicatorBase:
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
return hidden_states
|
||||
|
||||
def batch_isend_irecv(self, p2p_ops: list):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -16,6 +16,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import StatelessProcessGroup
|
||||
from .base_device_communicator import DeviceCommunicatorBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -28,8 +29,18 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
device: torch.device | None = None,
|
||||
device_group: ProcessGroup | None = None,
|
||||
unique_name: str = "",
|
||||
global_ranks: list[int] | None = None,
|
||||
global_world_size: int | None = None,
|
||||
tcp_store_group: StatelessProcessGroup | None = None,
|
||||
):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
super().__init__(
|
||||
cpu_group,
|
||||
device,
|
||||
device_group,
|
||||
unique_name,
|
||||
global_ranks,
|
||||
global_world_size,
|
||||
)
|
||||
if "tp" not in unique_name:
|
||||
# custom allreduce or torch symm mem can be used only by tp
|
||||
use_custom_allreduce = False
|
||||
@@ -62,7 +73,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
self.pynccl_comm: PyNcclCommunicator | None = None
|
||||
if self.world_size > 1:
|
||||
self.pynccl_comm = PyNcclCommunicator(
|
||||
group=self.cpu_group,
|
||||
group=self.cpu_group if tcp_store_group is None else tcp_store_group,
|
||||
device=self.device,
|
||||
)
|
||||
if is_symmetric_memory_enabled():
|
||||
@@ -107,19 +118,27 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
if self.all2all_backend == "naive":
|
||||
from .all2all import NaiveAll2AllManager
|
||||
|
||||
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
|
||||
self.all2all_manager = NaiveAll2AllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
elif self.all2all_backend == "allgather_reducescatter":
|
||||
from .all2all import AgRsAll2AllManager
|
||||
|
||||
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
|
||||
self.all2all_manager = AgRsAll2AllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
elif self.all2all_backend == "deepep_high_throughput":
|
||||
from .all2all import DeepEPHTAll2AllManager
|
||||
|
||||
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
|
||||
self.all2all_manager = DeepEPHTAll2AllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
elif self.all2all_backend == "deepep_low_latency":
|
||||
from .all2all import DeepEPLLAll2AllManager
|
||||
|
||||
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
|
||||
self.all2all_manager = DeepEPLLAll2AllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
elif self.all2all_backend == "mori":
|
||||
from .all2all import MoriAll2AllManager
|
||||
|
||||
@@ -127,7 +146,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
elif self.all2all_backend == "flashinfer_all2allv":
|
||||
from .all2all import FlashInferAllToAllManager
|
||||
|
||||
self.all2all_manager = FlashInferAllToAllManager(self.cpu_group)
|
||||
self.all2all_manager = FlashInferAllToAllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown all2all backend: {self.all2all_backend}")
|
||||
|
||||
@@ -284,6 +305,18 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
|
||||
return tensor
|
||||
|
||||
def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
|
||||
"""Broadcast a tensor from source rank to all ranks."""
|
||||
if self.world_size == 1:
|
||||
return tensor
|
||||
|
||||
pynccl_comm = self.pynccl_comm
|
||||
if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
pynccl_comm.broadcast(tensor, src)
|
||||
return tensor
|
||||
else:
|
||||
raise ValueError("No PyNCCL communicator found")
|
||||
|
||||
def destroy(self):
|
||||
if self.pynccl_comm is not None:
|
||||
self.pynccl_comm = None
|
||||
@@ -403,3 +436,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
hidden_states,
|
||||
is_sequence_parallel,
|
||||
)
|
||||
|
||||
def batch_isend_irecv(self, p2p_ops: list):
|
||||
pynccl_comm = self.pynccl_comm
|
||||
if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
pynccl_comm.batch_isend_irecv(p2p_ops)
|
||||
else:
|
||||
raise ValueError("No PyNCCL communicator found")
|
||||
|
||||
@@ -312,10 +312,19 @@ class PyNcclCommunicator:
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
if tensor.dtype in [
|
||||
torch.float8_e5m2,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e4m3fnuz,
|
||||
torch.float8_e5m2fnuz,
|
||||
]:
|
||||
nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8)
|
||||
else:
|
||||
nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
|
||||
self.nccl.ncclSend(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
nccl_dtype,
|
||||
dst,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
@@ -330,10 +339,19 @@ class PyNcclCommunicator:
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
if tensor.dtype in [
|
||||
torch.float8_e5m2,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e4m3fnuz,
|
||||
torch.float8_e5m2fnuz,
|
||||
]:
|
||||
nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8)
|
||||
else:
|
||||
nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
|
||||
self.nccl.ncclRecv(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
nccl_dtype,
|
||||
src,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
@@ -384,3 +402,17 @@ class PyNcclCommunicator:
|
||||
|
||||
def deregister_comm_window(self, window):
|
||||
return self.nccl.ncclCommWindowDeregister(self.comm, window)
|
||||
|
||||
def batch_isend_irecv(self, p2p_ops: list, stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.group_start()
|
||||
for op in p2p_ops:
|
||||
if op.op is torch.distributed.isend:
|
||||
self.send(op.tensor, op.group_peer, stream)
|
||||
elif op.op is torch.distributed.irecv:
|
||||
self.recv(op.tensor, op.group_peer, stream)
|
||||
|
||||
self.group_end()
|
||||
|
||||
@@ -0,0 +1,529 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
import gc
|
||||
import weakref
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.distributed import P2POp
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.cuda_graph import CUDAGraphWrapper
|
||||
from vllm.compilation.wrapper import reset_compile_wrapper
|
||||
from vllm.config import (
|
||||
CompilationMode,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.distributed import (
|
||||
get_dp_group,
|
||||
get_ep_group,
|
||||
get_pcp_group,
|
||||
get_tp_group,
|
||||
)
|
||||
from vllm.distributed.elastic_ep.standby_state import (
|
||||
create_standby_groups,
|
||||
get_standby_dp_group,
|
||||
get_standby_ep_group,
|
||||
pop_standby_groups,
|
||||
)
|
||||
from vllm.distributed.parallel_state import (
|
||||
_replace_active_groups,
|
||||
prepare_communication_buffer_for_model,
|
||||
)
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig
|
||||
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
||||
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
|
||||
from vllm.v1.worker.workspace import lock_workspace, unlock_workspace
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def batch_transfer_weights(
|
||||
model: nn.Module,
|
||||
is_sender: bool,
|
||||
peer_rank: int,
|
||||
dp_group: StatelessGroupCoordinator,
|
||||
expert_weights: Sequence[Iterable[torch.Tensor]],
|
||||
) -> None:
|
||||
device_comm = dp_group.device_communicator
|
||||
if device_comm is None:
|
||||
raise ValueError("No device communicator found")
|
||||
|
||||
expert_weights_set = set()
|
||||
for weight_group in expert_weights:
|
||||
for weight in weight_group:
|
||||
expert_weights_set.add(weight.data_ptr())
|
||||
|
||||
state_dict = model.state_dict()
|
||||
all_params = []
|
||||
|
||||
for name, param in state_dict.items():
|
||||
if name.endswith("expert_map"):
|
||||
continue
|
||||
if param.data_ptr() not in expert_weights_set:
|
||||
all_params.append(param.data)
|
||||
|
||||
assert len(all_params) > 0
|
||||
p2p_ops = []
|
||||
for param in all_params:
|
||||
op = object.__new__(P2POp)
|
||||
if is_sender:
|
||||
op.op = torch.distributed.isend
|
||||
op.tensor = param
|
||||
else:
|
||||
op.op = torch.distributed.irecv
|
||||
op.tensor = param
|
||||
op.group_peer = peer_rank
|
||||
p2p_ops.append(op)
|
||||
device_comm.batch_isend_irecv(p2p_ops)
|
||||
|
||||
|
||||
def broadcast_expert_mapping(
|
||||
physical_to_logical: torch.Tensor | None,
|
||||
num_local_physical_experts: int | None,
|
||||
num_logical_experts: int | None,
|
||||
dp_group: StatelessGroupCoordinator,
|
||||
device: torch.device,
|
||||
src_rank: int = 0,
|
||||
) -> tuple[torch.Tensor, int, int]:
|
||||
if dp_group.rank_in_group == src_rank:
|
||||
assert physical_to_logical is not None
|
||||
assert num_local_physical_experts is not None
|
||||
assert num_logical_experts is not None
|
||||
assert physical_to_logical.dtype == torch.int64
|
||||
shape_tensor = torch.tensor(
|
||||
list(physical_to_logical.shape), dtype=torch.int64, device="cpu"
|
||||
)
|
||||
metadata_tensor = torch.tensor(
|
||||
[num_local_physical_experts, num_logical_experts],
|
||||
dtype=torch.int64,
|
||||
device="cpu",
|
||||
)
|
||||
else:
|
||||
shape_tensor = torch.empty(2, dtype=torch.int64, device="cpu")
|
||||
metadata_tensor = torch.empty(2, dtype=torch.int64, device="cpu")
|
||||
|
||||
shape_tensor = dp_group.tcp_store_group.broadcast(shape_tensor, src_rank)
|
||||
metadata_tensor = dp_group.tcp_store_group.broadcast(metadata_tensor, src_rank)
|
||||
|
||||
if dp_group.rank_in_group != src_rank:
|
||||
assert device is not None
|
||||
physical_to_logical = torch.empty(
|
||||
tuple(shape_tensor.tolist()),
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
|
||||
assert physical_to_logical is not None
|
||||
physical_to_logical = dp_group.broadcast(physical_to_logical, src_rank)
|
||||
num_local_physical_experts = int(metadata_tensor[0].item())
|
||||
num_logical_experts = int(metadata_tensor[1].item())
|
||||
|
||||
return physical_to_logical, num_local_physical_experts, num_logical_experts
|
||||
|
||||
|
||||
class ElasticEPScalingExecutor:
|
||||
def __init__(self, worker):
|
||||
self.worker_ref = weakref.ref(worker)
|
||||
self.reconfig_request = None
|
||||
|
||||
@property
|
||||
def worker(self):
|
||||
worker = self.worker_ref()
|
||||
if worker is None:
|
||||
raise RuntimeError("Worker has been garbage collected")
|
||||
return worker
|
||||
|
||||
def execute(self, execute_method: str, *args, **kwargs):
|
||||
method = getattr(self, execute_method, None)
|
||||
if method is None:
|
||||
raise ValueError(f"Unknown execute method: {execute_method}")
|
||||
return method(*args, **kwargs)
|
||||
|
||||
def create_standby_groups(
|
||||
self, reconfig_request: ReconfigureDistributedRequest
|
||||
) -> None:
|
||||
self.reconfig_request = reconfig_request
|
||||
new_dp_size = reconfig_request.new_data_parallel_size
|
||||
world_size = self.worker.vllm_config.parallel_config.world_size
|
||||
new_world_size_across_dp = world_size * new_dp_size
|
||||
updated_config = copy.copy(self.worker.vllm_config)
|
||||
updated_config.parallel_config = copy.deepcopy(
|
||||
self.worker.vllm_config.parallel_config
|
||||
)
|
||||
updated_config.parallel_config.data_parallel_size = new_dp_size
|
||||
with set_current_vllm_config(updated_config):
|
||||
create_standby_groups(
|
||||
new_dp_size=new_dp_size,
|
||||
new_world_size_across_dp=new_world_size_across_dp,
|
||||
master_ip=reconfig_request.new_data_parallel_master_ip,
|
||||
world_group_ports=reconfig_request.new_stateless_world_group_port_list,
|
||||
dp_group_ports=reconfig_request.new_stateless_dp_group_port_list,
|
||||
ep_group_ports=reconfig_request.new_stateless_ep_group_port_list,
|
||||
eplb_group_ports=reconfig_request.new_stateless_eplb_group_port_list,
|
||||
)
|
||||
self.worker.model_runner.eep_eplb_suppressed = True
|
||||
standby_ep_group = get_standby_ep_group()
|
||||
assert standby_ep_group is not None
|
||||
if standby_ep_group.rank == 0:
|
||||
logger.info("[Elastic EP] EPLB disabled during elastic scaling transition")
|
||||
|
||||
def transfer_weights(self, old_dp_size: int, new_dp_size: int) -> None:
|
||||
standby_dp_group = get_standby_dp_group()
|
||||
assert standby_dp_group is not None
|
||||
# Broadcast old_dp_size to all workers in standby group
|
||||
if standby_dp_group.rank_in_group < old_dp_size:
|
||||
old_dp_size_tensor = torch.tensor(
|
||||
[old_dp_size], dtype=torch.int64, device="cpu"
|
||||
)
|
||||
else:
|
||||
old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu")
|
||||
old_dp_size_tensor = standby_dp_group.tcp_store_group.broadcast(
|
||||
old_dp_size_tensor, 0
|
||||
)
|
||||
|
||||
num_new_workers = new_dp_size - old_dp_size
|
||||
dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank
|
||||
|
||||
# Sender-receiver pairing: the first new_workers % old_dp_size
|
||||
# senders get (k+1) contiguous receivers, the rest get k
|
||||
# receivers.
|
||||
num_dst_per_sender = num_new_workers // old_dp_size
|
||||
remainder = num_new_workers % old_dp_size
|
||||
|
||||
if dp_rank < remainder:
|
||||
recv_begin = dp_rank * (num_dst_per_sender + 1)
|
||||
recv_end = recv_begin + num_dst_per_sender + 1
|
||||
else:
|
||||
recv_begin = (
|
||||
remainder * (num_dst_per_sender + 1)
|
||||
+ (dp_rank - remainder) * num_dst_per_sender
|
||||
)
|
||||
recv_end = recv_begin + num_dst_per_sender
|
||||
|
||||
ranks_to_send = list(range(old_dp_size + recv_begin, old_dp_size + recv_end))
|
||||
|
||||
model = self.worker.model_runner.get_model()
|
||||
for new_worker_rank in sorted(ranks_to_send):
|
||||
batch_transfer_weights(
|
||||
model=model,
|
||||
is_sender=True,
|
||||
peer_rank=new_worker_rank,
|
||||
dp_group=standby_dp_group,
|
||||
expert_weights=model.expert_weights,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def broadcast_expert_mapping(self) -> None:
|
||||
standby_dp_group = get_standby_dp_group()
|
||||
assert standby_dp_group is not None
|
||||
model_config = self.worker.model_runner.model_config
|
||||
eplb_state = self.worker.model_runner.eplb_state
|
||||
assert eplb_state is not None
|
||||
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
|
||||
physical_to_logical = eplb_model_state.physical_to_logical_map
|
||||
num_physical_experts = physical_to_logical.shape[1]
|
||||
num_local_physical_experts = num_physical_experts // get_ep_group().world_size
|
||||
num_logical_experts = eplb_model_state.logical_replica_count.shape[1]
|
||||
broadcast_expert_mapping(
|
||||
physical_to_logical=physical_to_logical,
|
||||
num_local_physical_experts=num_local_physical_experts,
|
||||
num_logical_experts=num_logical_experts,
|
||||
dp_group=standby_dp_group,
|
||||
src_rank=0,
|
||||
device=self.worker.device,
|
||||
)
|
||||
|
||||
def switch_and_remove(self) -> None:
|
||||
_replace_active_groups(world=None, dp=None, ep=None, eplb=None, node_count=None)
|
||||
|
||||
def switch_and_prepare(self) -> None:
|
||||
old_dp_size = get_dp_group().world_size
|
||||
old_ep_size = get_ep_group().world_size
|
||||
|
||||
_replace_active_groups(**pop_standby_groups())
|
||||
|
||||
parallel_config = self.worker.vllm_config.parallel_config
|
||||
reconfig_request = self.reconfig_request
|
||||
assert reconfig_request is not None
|
||||
new_dp_size = reconfig_request.new_data_parallel_size
|
||||
new_ep_size = get_ep_group().world_size
|
||||
|
||||
parallel_config.data_parallel_size = new_dp_size
|
||||
if (
|
||||
reconfig_request.new_data_parallel_rank
|
||||
!= ReconfigureRankType.KEEP_CURRENT_RANK
|
||||
):
|
||||
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
|
||||
if (
|
||||
reconfig_request.new_data_parallel_rank_local
|
||||
!= ReconfigureRankType.KEEP_CURRENT_RANK
|
||||
):
|
||||
parallel_config.data_parallel_rank_local = (
|
||||
reconfig_request.new_data_parallel_rank_local
|
||||
)
|
||||
parallel_config.data_parallel_master_ip = (
|
||||
reconfig_request.new_data_parallel_master_ip
|
||||
)
|
||||
parallel_config.data_parallel_master_port = (
|
||||
reconfig_request.new_data_parallel_master_port
|
||||
)
|
||||
|
||||
# Reconfigure MoE modules with new EP size
|
||||
moe_modules = [
|
||||
module
|
||||
for module in self.worker.model_runner.model.modules()
|
||||
if (
|
||||
module.__class__.__name__ == "FusedMoE"
|
||||
or module.__class__.__name__ == "SharedFusedMoE"
|
||||
)
|
||||
]
|
||||
num_local_experts = moe_modules[0].moe_config.num_local_experts
|
||||
assert all(
|
||||
module.moe_config.num_local_experts == num_local_experts
|
||||
for module in moe_modules
|
||||
), "All MoE modules must have the same number of experts"
|
||||
for module in moe_modules:
|
||||
module.moe_config.num_experts = num_local_experts * new_ep_size
|
||||
module.global_num_experts = module.moe_config.num_experts
|
||||
tp_size = get_tp_group().world_size
|
||||
is_sequence_parallel = parallel_config.use_sequence_parallel_moe
|
||||
sp_size = tp_size if is_sequence_parallel else 1
|
||||
module.moe_parallel_config = FusedMoEParallelConfig.make(
|
||||
tp_size_=tp_size,
|
||||
pcp_size_=get_pcp_group().world_size,
|
||||
dp_size_=get_dp_group().world_size,
|
||||
sp_size_=sp_size,
|
||||
vllm_parallel_config=parallel_config,
|
||||
)
|
||||
module.moe_config.moe_parallel_config = module.moe_parallel_config
|
||||
|
||||
# Update EPLB state
|
||||
eplb_state = self.worker.model_runner.eplb_state
|
||||
assert eplb_state is not None
|
||||
model_config = self.worker.model_runner.model_config
|
||||
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
|
||||
|
||||
num_physical_experts = num_local_experts * new_ep_size
|
||||
num_logical_experts = eplb_model_state.logical_replica_count.shape[1]
|
||||
parallel_config.eplb_config.num_redundant_experts = (
|
||||
num_physical_experts - num_logical_experts
|
||||
)
|
||||
old_physical_to_logical = eplb_model_state.physical_to_logical_map
|
||||
num_moe_layers = old_physical_to_logical.shape[0]
|
||||
num_local_experts = eplb_model_state.expert_load_pass.shape[1] // old_ep_size
|
||||
if new_dp_size > old_dp_size:
|
||||
expanded_physical_to_logical = torch.full(
|
||||
(num_moe_layers, num_local_experts * new_ep_size),
|
||||
-1,
|
||||
dtype=old_physical_to_logical.dtype,
|
||||
device=old_physical_to_logical.device,
|
||||
)
|
||||
expanded_physical_to_logical[:, : num_local_experts * old_ep_size] = (
|
||||
old_physical_to_logical
|
||||
)
|
||||
eplb_model_state.physical_to_logical_map = expanded_physical_to_logical
|
||||
|
||||
old_num_physical_experts = eplb_model_state.expert_load_pass.shape[1]
|
||||
pad_size = num_physical_experts - old_num_physical_experts
|
||||
if new_dp_size > old_dp_size:
|
||||
assert pad_size > 0
|
||||
expanded_expert_load_pass = F.pad(
|
||||
eplb_model_state.expert_load_pass, (0, pad_size), value=0
|
||||
)
|
||||
expanded_expert_load_window = F.pad(
|
||||
eplb_model_state.expert_load_window, (0, pad_size), value=0
|
||||
)
|
||||
eplb_model_state.expert_load_pass = expanded_expert_load_pass
|
||||
eplb_model_state.expert_load_window = expanded_expert_load_window
|
||||
eplb_state.num_valid_physical_experts = old_num_physical_experts
|
||||
else:
|
||||
assert pad_size < 0
|
||||
eplb_model_state.expert_load_pass = eplb_model_state.expert_load_pass[
|
||||
:, :num_physical_experts
|
||||
]
|
||||
eplb_model_state.expert_load_window = eplb_model_state.expert_load_window[
|
||||
:, :, :num_physical_experts
|
||||
]
|
||||
eplb_state.num_valid_physical_experts = num_physical_experts
|
||||
|
||||
model = self.worker.model_runner.get_model()
|
||||
model.expert_weights = []
|
||||
with set_current_vllm_config(self.worker.vllm_config):
|
||||
model.set_eplb_state(
|
||||
eplb_model_state.expert_load_pass,
|
||||
eplb_model_state.logical_to_physical_map,
|
||||
eplb_model_state.logical_replica_count,
|
||||
)
|
||||
model.update_physical_experts_metadata(
|
||||
num_physical_experts=num_physical_experts,
|
||||
num_local_physical_experts=num_local_experts,
|
||||
)
|
||||
# Force re-creation of the modular kernel (and all2all manager)
|
||||
# for the new EP size by resetting quant_method to base
|
||||
for module in moe_modules:
|
||||
if hasattr(module.quant_method, "old_quant_method"):
|
||||
module.quant_method = module.quant_method.old_quant_method
|
||||
module.runner = module._init_runner()
|
||||
prepare_communication_buffer_for_model(self.worker.model_runner.model)
|
||||
if (
|
||||
self.worker.vllm_config.compilation_config.mode
|
||||
== CompilationMode.STOCK_TORCH_COMPILE
|
||||
):
|
||||
# NOTE(yongji): when using stock torch.compile,
|
||||
# torch.compile is triggered during GPUModelRunner's load_model()
|
||||
# TODO(yongji):check do we need to re-trigger torch.compile here?
|
||||
# any changes to the tensor shapes in execution should already
|
||||
# be handled internally by torch.compile.
|
||||
backend = self.worker.vllm_config.compilation_config.init_backend(
|
||||
self.worker.vllm_config
|
||||
)
|
||||
compilation_counter.stock_torch_compile_count += 1
|
||||
self.worker.model_runner.model.compile(fullgraph=True, backend=backend)
|
||||
|
||||
# release all previously captured CUDA graphs
|
||||
if isinstance(self.worker.model_runner.model, CUDAGraphWrapper):
|
||||
wrapper = self.worker.model_runner.model
|
||||
wrapper.concrete_cudagraph_entries = {}
|
||||
elif isinstance(self.worker.model_runner.model, UBatchWrapper):
|
||||
raise RuntimeError("DBO is not yet supported in elastic EP")
|
||||
|
||||
multi_block_table = self.worker.model_runner.input_batch.block_table
|
||||
saved_block_tables: list[tuple[torch.Tensor, torch.Tensor]] = []
|
||||
for bt in multi_block_table.block_tables:
|
||||
saved_block_tables.append(
|
||||
(bt.block_table.gpu.clone(), bt.block_table.cpu.clone())
|
||||
)
|
||||
multi_block_table.clear()
|
||||
|
||||
# reset the compile wrapper
|
||||
torch.compiler.reset()
|
||||
with set_current_vllm_config(self.worker.vllm_config):
|
||||
reset_compile_wrapper(self.worker.model_runner.get_model())
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
unlock_workspace()
|
||||
self.worker.compile_or_warm_up_model()
|
||||
lock_workspace()
|
||||
|
||||
for bt, (saved_gpu, saved_cpu) in zip(
|
||||
multi_block_table.block_tables, saved_block_tables
|
||||
):
|
||||
bt.block_table.gpu.copy_(saved_gpu)
|
||||
bt.block_table.cpu.copy_(saved_cpu)
|
||||
|
||||
def perform_eplb_reshuffle(self, new_dp_size: int | None = None) -> None:
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info("[Elastic EP] Starting expert resharding...")
|
||||
|
||||
eplb_state = self.worker.model_runner.eplb_state
|
||||
assert eplb_state is not None
|
||||
|
||||
model_config = self.worker.model_runner.model_config
|
||||
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
|
||||
is_async_enabled = eplb_state.is_async
|
||||
eplb_state.is_async = False
|
||||
if new_dp_size is None:
|
||||
eplb_state.rearrange()
|
||||
else:
|
||||
# scale down
|
||||
parallel_config = self.worker.vllm_config.parallel_config
|
||||
tp_size = parallel_config.tensor_parallel_size
|
||||
old_ep_size = parallel_config.data_parallel_size * tp_size
|
||||
new_ep_size = new_dp_size * tp_size
|
||||
|
||||
rank_mapping = {
|
||||
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
|
||||
for old_ep_rank in range(old_ep_size)
|
||||
}
|
||||
|
||||
eplb_state.rearrange(rank_mapping=rank_mapping)
|
||||
# NOTE(yongji): check whether we need to synchronize here
|
||||
torch.cuda.synchronize()
|
||||
# reset expert_rearrangement_step to ensure all ranks are synchronized
|
||||
eplb_state.expert_rearrangement_step = 0
|
||||
eplb_state.num_valid_physical_experts = (
|
||||
eplb_model_state.physical_to_logical_map.shape[1]
|
||||
)
|
||||
eplb_state.is_async = is_async_enabled
|
||||
self.worker.model_runner.eep_eplb_suppressed = False
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info("[Elastic EP] Expert resharding completed")
|
||||
|
||||
def receive_weights(self) -> None:
|
||||
dp_group = get_dp_group()
|
||||
assert isinstance(dp_group, StatelessGroupCoordinator)
|
||||
new_dp_size = dp_group.world_size
|
||||
dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank
|
||||
|
||||
# Receive old_dp_size broadcasted during transfer_weights
|
||||
old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu")
|
||||
old_dp_size_tensor = dp_group.tcp_store_group.broadcast(old_dp_size_tensor, 0)
|
||||
old_dp_size = int(old_dp_size_tensor[0].item())
|
||||
|
||||
# Calculate which existing worker will send to this new worker
|
||||
num_new_workers = new_dp_size - old_dp_size
|
||||
new_worker_idx = dp_rank - old_dp_size
|
||||
num_dst_per_sender = num_new_workers // old_dp_size
|
||||
remainder = num_new_workers % old_dp_size
|
||||
|
||||
if new_worker_idx < remainder * (num_dst_per_sender + 1):
|
||||
sender_rank = new_worker_idx // (num_dst_per_sender + 1)
|
||||
else:
|
||||
sender_rank = (
|
||||
remainder
|
||||
+ (new_worker_idx - remainder * (num_dst_per_sender + 1))
|
||||
// num_dst_per_sender
|
||||
)
|
||||
|
||||
model = self.worker.model_runner.get_model()
|
||||
batch_transfer_weights(
|
||||
model=model,
|
||||
is_sender=False,
|
||||
peer_rank=sender_rank,
|
||||
dp_group=dp_group,
|
||||
expert_weights=model.expert_weights,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def receive_expert_mapping(self) -> tuple[torch.Tensor, int, int]:
|
||||
dp_group = get_dp_group()
|
||||
assert isinstance(dp_group, StatelessGroupCoordinator)
|
||||
physical_to_logical, num_local_physical_experts, num_logical_experts = (
|
||||
broadcast_expert_mapping(
|
||||
physical_to_logical=None,
|
||||
num_local_physical_experts=None,
|
||||
num_logical_experts=None,
|
||||
dp_group=dp_group,
|
||||
src_rank=0,
|
||||
device=self.worker.device,
|
||||
)
|
||||
)
|
||||
num_moe_layers = physical_to_logical.shape[0]
|
||||
new_dp_size = get_dp_group().world_size
|
||||
tp_size = self.worker.vllm_config.parallel_config.tensor_parallel_size
|
||||
new_ep_size = new_dp_size * tp_size
|
||||
expanded_physical_to_logical = torch.full(
|
||||
(num_moe_layers, num_local_physical_experts * new_ep_size),
|
||||
-1,
|
||||
dtype=physical_to_logical.dtype,
|
||||
device=physical_to_logical.device,
|
||||
)
|
||||
old_num_physical_experts = physical_to_logical.shape[1]
|
||||
expanded_physical_to_logical[:, :old_num_physical_experts] = physical_to_logical
|
||||
return (
|
||||
expanded_physical_to_logical,
|
||||
num_logical_experts,
|
||||
old_num_physical_experts,
|
||||
)
|
||||
|
||||
def prepare_new_worker(self) -> None:
|
||||
with set_current_vllm_config(self.worker.vllm_config):
|
||||
prepare_communication_buffer_for_model(self.worker.model_runner.get_model())
|
||||
@@ -0,0 +1,563 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import enum
|
||||
import time
|
||||
import weakref
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import torch.distributed
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.distributed import (
|
||||
sched_yield,
|
||||
stateless_destroy_torch_distributed_process_group,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.engine import (
|
||||
EEPNotificationType,
|
||||
ReconfigureDistributedRequest,
|
||||
ReconfigureRankType,
|
||||
)
|
||||
from vllm.v1.engine.core import DPEngineCoreProc
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
WorkerType = Literal["existing", "new", "removing"]
|
||||
|
||||
|
||||
class ScaleUpExistingEngineState(enum.IntEnum):
|
||||
WAIT_NEW_CORE_ENGINES_INIT = 0
|
||||
CREATE_STANDBY_GROUPS = 1
|
||||
TRANSFER_EXPERT_MAPPING = 2
|
||||
WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT = 3
|
||||
TRANSFER_WEIGHTS = 4
|
||||
SYNC_KV_CACHE_MEMORY_SIZE = 5
|
||||
SWITCH_AND_PREPARE = 6
|
||||
EPLB_RESHUFFLE = 7
|
||||
COMPLETE = 8
|
||||
|
||||
|
||||
class ScaleUpNewEngineState(enum.IntEnum):
|
||||
PREPARE = 0
|
||||
EPLB_RESHUFFLE = 1
|
||||
COMPLETE = 2
|
||||
|
||||
|
||||
class ScaleDownRemainingEngineState(enum.IntEnum):
|
||||
PREPARE = 0
|
||||
EPLB_RESHUFFLE = 1
|
||||
SWITCH_AND_PREPARE = 2
|
||||
COMPLETE = 3
|
||||
|
||||
|
||||
class ScaleDownRemovingEngineState(enum.IntEnum):
|
||||
PREPARE = 0
|
||||
EPLB_RESHUFFLE = 1
|
||||
COMPLETE = 2
|
||||
|
||||
|
||||
class _BarrierTimeoutError(RuntimeError):
|
||||
"""
|
||||
Exception raised for timeout
|
||||
in the first stage of our two-staged
|
||||
TCPStore based barrier to synchronize the
|
||||
execution of all engines in the DP group.
|
||||
"""
|
||||
|
||||
|
||||
class ElasticEPScalingState:
|
||||
def __init__(
|
||||
self,
|
||||
model_executor: "Executor",
|
||||
engine_core: "DPEngineCoreProc",
|
||||
vllm_config: "VllmConfig",
|
||||
new_parallel_config: ParallelConfig,
|
||||
worker_type: WorkerType,
|
||||
scale_type: Literal["scale_up", "scale_down"],
|
||||
reconfig_request: ReconfigureDistributedRequest | None = None,
|
||||
):
|
||||
self.model_executor_ref = weakref.ref(model_executor)
|
||||
self.engine_core_ref = weakref.ref(engine_core)
|
||||
self.vllm_config = vllm_config
|
||||
self.old_dp_group = self.engine_core.dp_group if worker_type != "new" else None
|
||||
self.old_dp_store = self.engine_core.dp_store if worker_type != "new" else None
|
||||
self.new_parallel_config: ParallelConfig = new_parallel_config
|
||||
self.new_dp_group: torch.distributed.ProcessGroup | None = (
|
||||
self.engine_core.dp_group if worker_type == "new" else None
|
||||
)
|
||||
self.new_dp_store = self.engine_core.dp_store if worker_type == "new" else None
|
||||
self.worker_type = worker_type
|
||||
self.scale_type = scale_type
|
||||
self.reconfig_request = reconfig_request
|
||||
|
||||
if scale_type == "scale_up":
|
||||
self.state = (
|
||||
ScaleUpNewEngineState.PREPARE
|
||||
if worker_type == "new"
|
||||
else ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
|
||||
)
|
||||
else:
|
||||
self.state = (
|
||||
ScaleDownRemovingEngineState.PREPARE
|
||||
if worker_type == "removing"
|
||||
else ScaleDownRemainingEngineState.PREPARE
|
||||
)
|
||||
|
||||
@property
|
||||
def model_executor(self) -> "Executor":
|
||||
model_executor = self.model_executor_ref()
|
||||
if model_executor is None:
|
||||
raise RuntimeError("Model executor has been garbage collected")
|
||||
return model_executor
|
||||
|
||||
@property
|
||||
def engine_core(self) -> "DPEngineCoreProc":
|
||||
engine_core = self.engine_core_ref()
|
||||
if engine_core is None:
|
||||
raise RuntimeError("Engine core has been garbage collected")
|
||||
return engine_core
|
||||
|
||||
def progress(self) -> bool:
|
||||
if self.scale_type == "scale_up":
|
||||
return (
|
||||
self._progress_new_engine()
|
||||
if self.worker_type == "new"
|
||||
else self._progress_existing_engine()
|
||||
)
|
||||
return (
|
||||
self._progress_removing_engine()
|
||||
if self.worker_type == "removing"
|
||||
else self._progress_remaining_engine()
|
||||
)
|
||||
|
||||
def _execute_tcp_store_barrier(
|
||||
self, dp_store, group_rank, group_size, barrier_id, timeout=None
|
||||
):
|
||||
arrival_key = f"arrival_{barrier_id}_{group_rank}"
|
||||
dp_store.set(arrival_key, b"1")
|
||||
|
||||
start_time = time.time()
|
||||
processes_arrived: set[int] = set()
|
||||
|
||||
while len(processes_arrived) < group_size:
|
||||
if (
|
||||
timeout is not None
|
||||
and time.time() - start_time > timeout.total_seconds()
|
||||
):
|
||||
raise _BarrierTimeoutError(
|
||||
f"Barrier timed out after {timeout.total_seconds()} seconds"
|
||||
)
|
||||
|
||||
for i in range(group_size):
|
||||
if i in processes_arrived:
|
||||
continue
|
||||
|
||||
key = f"arrival_{barrier_id}_{i}"
|
||||
present = dp_store.check([key])
|
||||
if present:
|
||||
processes_arrived.add(i)
|
||||
|
||||
if len(processes_arrived) < group_size:
|
||||
sched_yield()
|
||||
|
||||
def _staged_barrier(self, use_new_group: bool, barrier_name: str) -> bool:
|
||||
"""
|
||||
Execute a two-staged barrier to synchronize all engines in the DP group.
|
||||
|
||||
Some DP EngineCores may receive the reconfiguration notifications
|
||||
later than others, and already proceed to engine step (model forward)
|
||||
in the busy loop.
|
||||
In this case, EngineCores that already proceed to reconfiguration
|
||||
should skip reconfiguration and execute model forward for one more
|
||||
step, so in the next step, all EngineCores will be synchronized.
|
||||
We use a two-staged barrier to achieve this. The first time each
|
||||
EngineCore executes the barrier, if a timeout is reached before the
|
||||
barrier completes, that means some EngineCores have already entered
|
||||
engine step. The EngineCores that timed out will then proceed to
|
||||
engine step, and will synchronize with the other EngineCores in the
|
||||
next step with a barrier without timeout.
|
||||
"""
|
||||
dp_store = self.new_dp_store if use_new_group else self.old_dp_store
|
||||
dp_group = self.new_dp_group if use_new_group else self.old_dp_group
|
||||
assert dp_group is not None
|
||||
|
||||
group_rank = dp_group.rank()
|
||||
group_size = dp_group.size()
|
||||
barrier_id = f"eep_barrier_{barrier_name}"
|
||||
sync_key = f"{barrier_id}_sync"
|
||||
|
||||
# TODO(yongji): figure out appropriate timeout for the barrier
|
||||
timeout = None if dp_store.check([sync_key]) else timedelta(seconds=5)
|
||||
|
||||
try:
|
||||
self._execute_tcp_store_barrier(
|
||||
dp_store, group_rank, group_size, barrier_id, timeout=timeout
|
||||
)
|
||||
torch.distributed.barrier(dp_group)
|
||||
if group_rank == 0:
|
||||
dp_store.delete_key(sync_key)
|
||||
for i in range(group_size):
|
||||
dp_store.delete_key(f"arrival_{barrier_id}_{i}")
|
||||
return True
|
||||
except _BarrierTimeoutError as e:
|
||||
if timeout is None:
|
||||
raise RuntimeError("Unexpected timeout encountered") from e
|
||||
dp_store.compare_set(sync_key, "", b"1")
|
||||
return False
|
||||
|
||||
def _progress_existing_engine(self) -> bool:
|
||||
state = self.state
|
||||
|
||||
if state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT:
|
||||
return False
|
||||
|
||||
elif state == ScaleUpExistingEngineState.CREATE_STANDBY_GROUPS:
|
||||
# NOTE(yongji): wait for all existing workers to receive the request
|
||||
if (
|
||||
int(self.old_dp_store.get("eep_barrier_engine_count"))
|
||||
< self.old_dp_group.size()
|
||||
):
|
||||
return False
|
||||
if not self._staged_barrier(
|
||||
use_new_group=False, barrier_name="create_standby_groups"
|
||||
):
|
||||
return False
|
||||
if self.old_dp_group.rank() == 0:
|
||||
self.old_dp_store.delete_key("eep_barrier_engine_count")
|
||||
self._create_standby_groups()
|
||||
self.state = ScaleUpExistingEngineState.TRANSFER_EXPERT_MAPPING
|
||||
return True
|
||||
|
||||
elif state == ScaleUpExistingEngineState.TRANSFER_EXPERT_MAPPING:
|
||||
self._transfer_expert_mapping()
|
||||
self.state = ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT
|
||||
return True
|
||||
|
||||
elif state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT:
|
||||
return False
|
||||
|
||||
elif state == ScaleUpExistingEngineState.TRANSFER_WEIGHTS:
|
||||
if (
|
||||
int(self.old_dp_store.get("eep_barrier_engine_count"))
|
||||
< self.old_dp_group.size()
|
||||
):
|
||||
return False
|
||||
if not self._staged_barrier(
|
||||
use_new_group=False, barrier_name="transfer_weights"
|
||||
):
|
||||
return False
|
||||
if self.old_dp_group.rank() == 0:
|
||||
self.old_dp_store.delete_key("eep_barrier_engine_count")
|
||||
self._transfer_weights()
|
||||
self.state = ScaleUpExistingEngineState.SYNC_KV_CACHE_MEMORY_SIZE
|
||||
return True
|
||||
|
||||
elif state == ScaleUpExistingEngineState.SYNC_KV_CACHE_MEMORY_SIZE:
|
||||
self._sync_kv_cache_memory_size()
|
||||
self.state = ScaleUpExistingEngineState.SWITCH_AND_PREPARE
|
||||
return True
|
||||
|
||||
elif state == ScaleUpExistingEngineState.SWITCH_AND_PREPARE:
|
||||
self._switch_and_prepare()
|
||||
self.state = ScaleUpExistingEngineState.EPLB_RESHUFFLE
|
||||
self.new_dp_store.add("eep_barrier_engine_count", 1)
|
||||
return True
|
||||
|
||||
elif state == ScaleUpExistingEngineState.EPLB_RESHUFFLE:
|
||||
assert self.new_dp_group is not None
|
||||
if (
|
||||
int(self.new_dp_store.get("eep_barrier_engine_count"))
|
||||
< self.new_dp_group.size()
|
||||
):
|
||||
return False
|
||||
if not self._staged_barrier(
|
||||
use_new_group=True, barrier_name="eplb_reshuffle"
|
||||
):
|
||||
return False
|
||||
if self.new_dp_group.rank() == 0:
|
||||
self.new_dp_store.delete_key("eep_barrier_engine_count")
|
||||
self._eplb_reshuffle()
|
||||
self.state = ScaleUpExistingEngineState.COMPLETE
|
||||
self._update_parallel_config()
|
||||
return True
|
||||
|
||||
else:
|
||||
assert self.state == ScaleUpExistingEngineState.COMPLETE
|
||||
return True
|
||||
|
||||
def _progress_new_engine(self) -> bool:
|
||||
state = self.state
|
||||
assert self.new_dp_group is not None
|
||||
|
||||
if state == ScaleUpNewEngineState.PREPARE:
|
||||
tensor = torch.tensor([0, 0, 0], dtype=torch.int32, device="cpu")
|
||||
torch.distributed.all_reduce(
|
||||
tensor,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
group=self.new_dp_group,
|
||||
)
|
||||
data = tensor.tolist()
|
||||
self.engine_core.engines_running = bool(data[0])
|
||||
self.engine_core.current_wave = int(data[1])
|
||||
self.engine_core.step_counter = int(data[2])
|
||||
self.state = ScaleUpNewEngineState.EPLB_RESHUFFLE
|
||||
self.new_dp_store.add("eep_barrier_engine_count", 1)
|
||||
return True
|
||||
|
||||
elif state == ScaleUpNewEngineState.EPLB_RESHUFFLE:
|
||||
if (
|
||||
int(self.new_dp_store.get("eep_barrier_engine_count"))
|
||||
< self.new_dp_group.size()
|
||||
):
|
||||
return False
|
||||
if not self._staged_barrier(
|
||||
use_new_group=True, barrier_name="eplb_reshuffle"
|
||||
):
|
||||
return False
|
||||
assert self.new_dp_group.rank() > 0
|
||||
self._eplb_reshuffle()
|
||||
self.state = ScaleUpNewEngineState.COMPLETE
|
||||
return True
|
||||
|
||||
else:
|
||||
assert self.state == ScaleUpNewEngineState.COMPLETE
|
||||
return True
|
||||
|
||||
def _progress_remaining_engine(self) -> bool:
|
||||
state = self.state
|
||||
|
||||
if state == ScaleDownRemainingEngineState.PREPARE:
|
||||
self.state = ScaleDownRemainingEngineState.EPLB_RESHUFFLE
|
||||
self.old_dp_store.add("eep_barrier_engine_count", 1)
|
||||
return True
|
||||
|
||||
elif state == ScaleDownRemainingEngineState.EPLB_RESHUFFLE:
|
||||
if (
|
||||
int(self.old_dp_store.get("eep_barrier_engine_count"))
|
||||
< self.old_dp_group.size()
|
||||
):
|
||||
return False
|
||||
if not self._staged_barrier(
|
||||
use_new_group=False, barrier_name="eplb_reshuffle"
|
||||
):
|
||||
return False
|
||||
if self.old_dp_group.rank() == 0:
|
||||
self.old_dp_store.delete_key("eep_barrier_engine_count")
|
||||
self._eplb_reshuffle_before_scale_down()
|
||||
self.state = ScaleDownRemainingEngineState.SWITCH_AND_PREPARE
|
||||
# NOTE(yongji): currently, after EPLB reshuffle
|
||||
# that redistributes experts to remaining workers, workers
|
||||
# to be removed will immediately initiate shutdown;
|
||||
# existing workers can no longer execute forward steps using
|
||||
# the old setup. In the future, we may keep
|
||||
# the removing workers alive a bit longer,
|
||||
# e.g., to drain in-batch requests.
|
||||
self._create_standby_groups()
|
||||
self._switch_and_prepare()
|
||||
self._update_parallel_config()
|
||||
self.state = ScaleDownRemainingEngineState.COMPLETE
|
||||
return True
|
||||
|
||||
else:
|
||||
assert self.state == ScaleDownRemainingEngineState.COMPLETE
|
||||
return True
|
||||
|
||||
def _progress_removing_engine(self) -> bool:
|
||||
state = self.state
|
||||
|
||||
if state == ScaleDownRemovingEngineState.PREPARE:
|
||||
self.state = ScaleDownRemovingEngineState.EPLB_RESHUFFLE
|
||||
self.old_dp_store.add("eep_barrier_engine_count", 1)
|
||||
return True
|
||||
|
||||
if state == ScaleDownRemovingEngineState.EPLB_RESHUFFLE:
|
||||
if (
|
||||
int(self.old_dp_store.get("eep_barrier_engine_count"))
|
||||
< self.old_dp_group.size()
|
||||
):
|
||||
return False
|
||||
if not self._staged_barrier(
|
||||
use_new_group=False, barrier_name="eplb_reshuffle"
|
||||
):
|
||||
return False
|
||||
assert self.old_dp_group.rank() > 0
|
||||
self._eplb_reshuffle_before_scale_down()
|
||||
self._switch_and_remove()
|
||||
self.state = ScaleDownRemovingEngineState.COMPLETE
|
||||
self.engine_core._eep_send_engine_core_notification(
|
||||
EEPNotificationType.SHUTDOWN_COMPLETE
|
||||
)
|
||||
self.engine_core.shutdown()
|
||||
return True
|
||||
|
||||
else:
|
||||
assert self.state == ScaleDownRemovingEngineState.COMPLETE
|
||||
return True
|
||||
|
||||
def handle_notification(self, notification_type: EEPNotificationType):
|
||||
assert self.worker_type != "new"
|
||||
if (
|
||||
notification_type == EEPNotificationType.NEW_CORE_ENGINES_INIT_READY
|
||||
and self.state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
|
||||
):
|
||||
self.old_dp_store.add("eep_barrier_engine_count", 1)
|
||||
self.state = ScaleUpExistingEngineState.CREATE_STANDBY_GROUPS
|
||||
elif (
|
||||
notification_type == EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY
|
||||
and self.state
|
||||
== ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT
|
||||
):
|
||||
self.old_dp_store.add("eep_barrier_engine_count", 1)
|
||||
self.state = ScaleUpExistingEngineState.TRANSFER_WEIGHTS
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
if self.scale_type == "scale_up":
|
||||
return (
|
||||
self.state == ScaleUpNewEngineState.COMPLETE
|
||||
if self.worker_type == "new"
|
||||
else self.state == ScaleUpExistingEngineState.COMPLETE
|
||||
)
|
||||
return (
|
||||
self.state == ScaleDownRemovingEngineState.COMPLETE
|
||||
if self.worker_type == "removing"
|
||||
else self.state == ScaleDownRemainingEngineState.COMPLETE
|
||||
)
|
||||
|
||||
def _create_standby_groups(self):
|
||||
self.new_dp_group, self.new_dp_store = (
|
||||
self.new_parallel_config.stateless_init_dp_group(return_store=True)
|
||||
)
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute", args=("create_standby_groups", self.reconfig_request)
|
||||
)
|
||||
if self.old_dp_group.rank() == 0:
|
||||
logger.info("[Elastic EP] Created standby communication groups")
|
||||
|
||||
def _transfer_weights(self):
|
||||
assert self.reconfig_request is not None
|
||||
old_dp_size = self.old_dp_group.size()
|
||||
new_dp_size = self.reconfig_request.new_data_parallel_size
|
||||
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute", args=("transfer_weights", old_dp_size, new_dp_size)
|
||||
)
|
||||
if self.old_dp_group.rank() == 0:
|
||||
logger.info("[Elastic EP] Transferred weights to new workers")
|
||||
|
||||
def _transfer_expert_mapping(self):
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute", args=("broadcast_expert_mapping",)
|
||||
)
|
||||
if self.old_dp_group.rank() == 0:
|
||||
logger.info("[Elastic EP] Broadcasted expert mapping to new workers")
|
||||
|
||||
def _sync_kv_cache_memory_size(self):
|
||||
assert self.engine_core.available_gpu_memory_for_kv_cache > 0
|
||||
assert self.new_dp_group is not None
|
||||
ParallelConfig.sync_kv_cache_memory_size(
|
||||
self.new_dp_group,
|
||||
self.engine_core.available_gpu_memory_for_kv_cache,
|
||||
)
|
||||
if self.old_dp_group.rank() == 0:
|
||||
logger.info("[Elastic EP] Synced KV cache memory size to new workers")
|
||||
|
||||
def _switch_and_prepare(self):
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute", args=("switch_and_prepare",)
|
||||
)
|
||||
old_dp_group = self.old_dp_group
|
||||
stateless_destroy_torch_distributed_process_group(old_dp_group)
|
||||
assert self.new_dp_group is not None
|
||||
new_dp_group = self.new_dp_group
|
||||
self.engine_core.dp_group = new_dp_group
|
||||
self.engine_core.dp_rank = new_dp_group.rank()
|
||||
self.engine_core.dp_store = self.new_dp_store
|
||||
engines_running = int(self.engine_core.engines_running)
|
||||
current_wave = self.engine_core.current_wave
|
||||
step_counter = self.engine_core.step_counter
|
||||
tensor = torch.tensor(
|
||||
[engines_running, current_wave, step_counter],
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
)
|
||||
torch.distributed.all_reduce(
|
||||
tensor, op=torch.distributed.ReduceOp.MAX, group=new_dp_group
|
||||
)
|
||||
data = tensor.tolist()
|
||||
self.engine_core.engines_running = bool(data[0])
|
||||
self.engine_core.current_wave = int(data[1])
|
||||
self.engine_core.step_counter = int(data[2])
|
||||
if new_dp_group.rank() == 0:
|
||||
self.engine_core._eep_send_engine_core_notification(
|
||||
EEPNotificationType.RECONFIGURE_FINISHED
|
||||
)
|
||||
logger.info("[Elastic EP] Switched to new setup")
|
||||
|
||||
def _eplb_reshuffle(self):
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute", args=("perform_eplb_reshuffle",)
|
||||
)
|
||||
assert self.new_dp_group is not None
|
||||
if self.new_dp_group.rank() == 0:
|
||||
logger.info("[Elastic EP] EPLB reshuffle completed")
|
||||
|
||||
def _eplb_reshuffle_before_scale_down(self):
|
||||
assert self.reconfig_request is not None
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute",
|
||||
args=(
|
||||
"perform_eplb_reshuffle",
|
||||
self.reconfig_request.new_data_parallel_size,
|
||||
),
|
||||
)
|
||||
if self.old_dp_group.rank() == 0:
|
||||
logger.info("[Elastic EP] EPLB reshuffle completed")
|
||||
|
||||
def _switch_and_remove(self):
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute", args=("switch_and_remove",)
|
||||
)
|
||||
|
||||
def _update_parallel_config(self):
|
||||
assert self.reconfig_request is not None
|
||||
reconfig_request = self.reconfig_request
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
|
||||
if (
|
||||
reconfig_request.new_data_parallel_rank
|
||||
!= ReconfigureRankType.KEEP_CURRENT_RANK
|
||||
):
|
||||
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
|
||||
if (
|
||||
reconfig_request.new_data_parallel_rank_local
|
||||
!= ReconfigureRankType.KEEP_CURRENT_RANK
|
||||
):
|
||||
parallel_config.data_parallel_rank_local = (
|
||||
reconfig_request.new_data_parallel_rank_local
|
||||
)
|
||||
parallel_config.data_parallel_master_ip = (
|
||||
reconfig_request.new_data_parallel_master_ip
|
||||
)
|
||||
parallel_config.data_parallel_master_port = (
|
||||
reconfig_request.new_data_parallel_master_port
|
||||
)
|
||||
parallel_config._data_parallel_master_port_list = (
|
||||
reconfig_request.new_data_parallel_master_port_list
|
||||
)
|
||||
parallel_config._stateless_world_group_port_list = (
|
||||
reconfig_request.new_stateless_world_group_port_list
|
||||
)
|
||||
parallel_config._stateless_dp_group_port_list = (
|
||||
reconfig_request.new_stateless_dp_group_port_list
|
||||
)
|
||||
parallel_config._stateless_ep_group_port_list = (
|
||||
reconfig_request.new_stateless_ep_group_port_list
|
||||
)
|
||||
parallel_config._stateless_eplb_group_port_list = (
|
||||
reconfig_request.new_stateless_eplb_group_port_list
|
||||
)
|
||||
@@ -0,0 +1,117 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.distributed.parallel_state import (
|
||||
_init_stateless_group,
|
||||
_node_count,
|
||||
get_pp_group,
|
||||
get_tp_group,
|
||||
get_world_group,
|
||||
)
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
|
||||
_STANDBY_WORLD: StatelessGroupCoordinator | None = None
|
||||
_STANDBY_WORLD_NODE_COUNT: int | None = None
|
||||
_STANDBY_DP: StatelessGroupCoordinator | None = None
|
||||
_STANDBY_EP: StatelessGroupCoordinator | None = None
|
||||
_STANDBY_EPLB: StatelessGroupCoordinator | None = None
|
||||
|
||||
|
||||
def get_standby_dp_group() -> StatelessGroupCoordinator | None:
|
||||
return _STANDBY_DP
|
||||
|
||||
|
||||
def get_standby_ep_group() -> StatelessGroupCoordinator | None:
|
||||
return _STANDBY_EP
|
||||
|
||||
|
||||
def get_standby_eplb_group() -> StatelessGroupCoordinator | None:
|
||||
return _STANDBY_EPLB
|
||||
|
||||
|
||||
def get_standby_world_group() -> StatelessGroupCoordinator | None:
|
||||
return _STANDBY_WORLD
|
||||
|
||||
|
||||
def create_standby_groups(
|
||||
new_dp_size: int,
|
||||
new_world_size_across_dp: int,
|
||||
master_ip: str,
|
||||
world_group_ports: list[list[int]],
|
||||
dp_group_ports: list[list[int]],
|
||||
ep_group_ports: list[list[int]],
|
||||
eplb_group_ports: list[list[int]] | None = None,
|
||||
backend: str | None = None,
|
||||
) -> None:
|
||||
global \
|
||||
_STANDBY_WORLD, \
|
||||
_STANDBY_WORLD_NODE_COUNT, \
|
||||
_STANDBY_DP, \
|
||||
_STANDBY_EP, \
|
||||
_STANDBY_EPLB
|
||||
|
||||
assert new_world_size_across_dp == torch.distributed.get_world_size() * new_dp_size
|
||||
world_group = get_world_group()
|
||||
assert isinstance(world_group, StatelessGroupCoordinator)
|
||||
backend = backend or world_group.backend
|
||||
|
||||
standby_world_ranks = [list(range(new_world_size_across_dp))]
|
||||
_STANDBY_WORLD = _init_stateless_group(
|
||||
standby_world_ranks,
|
||||
"world",
|
||||
world_group_ports,
|
||||
master_ip,
|
||||
backend,
|
||||
use_device_communicator=False,
|
||||
)
|
||||
_STANDBY_WORLD_NODE_COUNT = _node_count(_STANDBY_WORLD.tcp_store_group)
|
||||
|
||||
tp_size = get_tp_group().world_size
|
||||
pp_size = get_pp_group().world_size
|
||||
|
||||
all_ranks = torch.arange(new_world_size_across_dp).reshape(
|
||||
-1, new_dp_size, pp_size, tp_size
|
||||
)
|
||||
standby_dp_ranks = all_ranks.transpose(1, 3).reshape(-1, new_dp_size).unbind(0)
|
||||
standby_dp_ranks = [x.tolist() for x in standby_dp_ranks]
|
||||
_STANDBY_DP = _init_stateless_group(
|
||||
standby_dp_ranks, "dp", dp_group_ports, master_ip, backend
|
||||
)
|
||||
|
||||
standby_ep_ranks = (
|
||||
all_ranks.transpose(1, 2).reshape(-1, new_dp_size * tp_size).unbind(0)
|
||||
)
|
||||
standby_ep_ranks = [x.tolist() for x in standby_ep_ranks]
|
||||
_STANDBY_EP = _init_stateless_group(
|
||||
standby_ep_ranks, "ep", ep_group_ports, master_ip, backend
|
||||
)
|
||||
|
||||
if eplb_group_ports is not None:
|
||||
_STANDBY_EPLB = _init_stateless_group(
|
||||
standby_ep_ranks, "eplb", eplb_group_ports, master_ip, backend
|
||||
)
|
||||
|
||||
|
||||
def pop_standby_groups() -> dict:
|
||||
"""Return all standby groups and clear the standby state."""
|
||||
global \
|
||||
_STANDBY_WORLD, \
|
||||
_STANDBY_WORLD_NODE_COUNT, \
|
||||
_STANDBY_DP, \
|
||||
_STANDBY_EP, \
|
||||
_STANDBY_EPLB
|
||||
|
||||
result = dict(
|
||||
world=_STANDBY_WORLD,
|
||||
dp=_STANDBY_DP,
|
||||
ep=_STANDBY_EP,
|
||||
eplb=_STANDBY_EPLB,
|
||||
node_count=_STANDBY_WORLD_NODE_COUNT,
|
||||
)
|
||||
_STANDBY_WORLD = None
|
||||
_STANDBY_WORLD_NODE_COUNT = None
|
||||
_STANDBY_DP = None
|
||||
_STANDBY_EP = None
|
||||
_STANDBY_EPLB = None
|
||||
return result
|
||||
@@ -24,7 +24,6 @@ logger = init_logger(__name__)
|
||||
|
||||
def start_async_worker(
|
||||
state: "EplbState",
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
is_profile: bool = False,
|
||||
) -> threading.Thread:
|
||||
eplb_group = get_eplb_group().device_group
|
||||
@@ -45,7 +44,6 @@ def start_async_worker(
|
||||
eplb_group=eplb_group,
|
||||
cuda_stream=cuda_stream,
|
||||
is_profile=is_profile,
|
||||
rank_mapping=rank_mapping,
|
||||
)
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - diagnostic path
|
||||
@@ -107,7 +105,6 @@ async def transfer_run_periodically(
|
||||
eplb_group: ProcessGroup,
|
||||
cuda_stream: torch.cuda.Stream,
|
||||
is_profile: bool = False,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
) -> None:
|
||||
while True:
|
||||
await asyncio.to_thread(state.rearrange_event.wait)
|
||||
@@ -176,7 +173,6 @@ async def transfer_run_periodically(
|
||||
ep_group=eplb_group,
|
||||
is_profile=is_profile,
|
||||
cuda_stream=cuda_stream,
|
||||
rank_mapping=rank_mapping,
|
||||
)
|
||||
event = torch.cuda.Event(blocking=False)
|
||||
cuda_stream.record_event(event)
|
||||
|
||||
@@ -40,6 +40,7 @@ from vllm.distributed.parallel_state import (
|
||||
get_node_count,
|
||||
in_the_same_node_as,
|
||||
)
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.interfaces import MixtureOfExperts
|
||||
@@ -302,6 +303,14 @@ class EplbState:
|
||||
"""
|
||||
CUDA device index for the async EPLB worker thread.
|
||||
"""
|
||||
self.num_valid_physical_experts: int = 0
|
||||
"""
|
||||
Number of valid physical experts.
|
||||
This is the number of physical experts that are
|
||||
actually mapped to logical experts. In elastic EP,
|
||||
newly started EP ranks may not have physical experts
|
||||
mapped yet.
|
||||
"""
|
||||
if self.device.type == "cuda":
|
||||
self.cuda_device_index = self.device.index
|
||||
if self.cuda_device_index is None and torch.cuda.is_available():
|
||||
@@ -367,9 +376,6 @@ class EplbState:
|
||||
self,
|
||||
model: MixtureOfExperts,
|
||||
model_config: ModelConfig,
|
||||
global_expert_load: torch.Tensor | None = None,
|
||||
old_global_expert_indices: torch.Tensor | None = None,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
):
|
||||
"""
|
||||
Build the initial EPLB state.
|
||||
@@ -462,75 +468,15 @@ class EplbState:
|
||||
)
|
||||
self.expert_rearrangement_step_interval = eplb_step_interval
|
||||
|
||||
# Set the policy based on the selected eplb algorithm type.
|
||||
policy_type = self.parallel_config.eplb_config.policy
|
||||
self.policy = EPLB_POLICIES[policy_type]
|
||||
logger.debug("Selected EPLB policy: %s", policy_type)
|
||||
if global_expert_load is not None:
|
||||
ep_group = get_ep_group().device_group
|
||||
assert global_expert_load.shape == (
|
||||
model.num_moe_layers,
|
||||
model.num_logical_experts,
|
||||
)
|
||||
assert global_expert_load.dtype == torch.int64
|
||||
|
||||
num_replicas = model.num_physical_experts
|
||||
num_groups = model.num_expert_groups
|
||||
num_nodes = get_node_count()
|
||||
num_gpus = ep_group.size()
|
||||
|
||||
if num_gpus % num_nodes != 0:
|
||||
num_nodes = 1
|
||||
logger.warning_once(
|
||||
f"num_gpus % num_nodes != 0, "
|
||||
"not using hierarchical rearrangement algorithm.\n"
|
||||
f"{num_gpus=}, {num_nodes=}"
|
||||
)
|
||||
|
||||
# Get new expert mappings
|
||||
(
|
||||
new_physical_to_logical_map,
|
||||
new_logical_to_physical_map,
|
||||
new_logical_replica_count,
|
||||
) = self.policy.rebalance_experts(
|
||||
global_expert_load,
|
||||
num_replicas,
|
||||
num_groups,
|
||||
num_nodes,
|
||||
num_gpus,
|
||||
)
|
||||
|
||||
max_physical_slots = new_logical_to_physical_map.shape[-1]
|
||||
assert max_physical_slots <= logical_to_physical_map.shape[-1]
|
||||
new_logical_to_physical_map = torch.nn.functional.pad(
|
||||
new_logical_to_physical_map,
|
||||
(0, logical_to_physical_map.shape[-1] - max_physical_slots),
|
||||
value=-1,
|
||||
)
|
||||
physical_to_logical_map = new_physical_to_logical_map.to(self.device)
|
||||
logical_to_physical_map.copy_(new_logical_to_physical_map)
|
||||
logical_replica_count.copy_(new_logical_replica_count)
|
||||
else:
|
||||
new_physical_to_logical_map = None
|
||||
|
||||
new_logical_to_physical_map = None
|
||||
|
||||
new_logical_replica_count = None
|
||||
model.set_eplb_state(
|
||||
expert_load_pass,
|
||||
logical_to_physical_map,
|
||||
logical_replica_count,
|
||||
)
|
||||
if global_expert_load is not None:
|
||||
rearrange_expert_weights_inplace(
|
||||
old_global_expert_indices,
|
||||
new_physical_to_logical_map,
|
||||
model.expert_weights,
|
||||
ep_group,
|
||||
False,
|
||||
rank_mapping,
|
||||
)
|
||||
self.expert_rearrangement_step = 0
|
||||
|
||||
expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]]
|
||||
|
||||
@@ -561,11 +507,12 @@ class EplbState:
|
||||
recv_dst_rows=np.array([]),
|
||||
),
|
||||
cuda_device_index=self.cuda_device_index,
|
||||
new_physical_to_logical_map=new_physical_to_logical_map,
|
||||
new_logical_to_physical_map=new_logical_to_physical_map,
|
||||
new_logical_replica_count=new_logical_replica_count,
|
||||
new_physical_to_logical_map=None,
|
||||
new_logical_to_physical_map=None,
|
||||
new_logical_replica_count=None,
|
||||
)
|
||||
self.model_states[model_config.compute_hash()] = model_state
|
||||
self.num_valid_physical_experts = model.num_physical_experts
|
||||
|
||||
def step(
|
||||
self,
|
||||
@@ -696,8 +643,6 @@ class EplbState:
|
||||
def rearrange(
|
||||
self,
|
||||
is_profile: bool = False,
|
||||
execute_shuffle: bool = True,
|
||||
global_expert_loads: list[torch.Tensor] | None = None,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
) -> torch.Tensor | None:
|
||||
"""
|
||||
@@ -707,12 +652,6 @@ class EplbState:
|
||||
is_profile (bool): If `True`, perform a dummy rearrangement.
|
||||
This is used in `profile_run` to reserve enough memory,
|
||||
no memory movement will be performed. Default is False.
|
||||
execute_shuffle (bool): If `True`, execute the shuffle
|
||||
in elastic expert parallel (EEP). Default is True.
|
||||
global_expert_loads (list[torch.Tensor] | None): The global expert
|
||||
loads when scaling is done in EEP.
|
||||
List of expert loads for the main and drafter
|
||||
(when spec decode is used) models.
|
||||
rank_mapping (dict[int, int] | None): The rank mapping
|
||||
when scaling is done in EEP.
|
||||
"""
|
||||
@@ -734,67 +673,34 @@ class EplbState:
|
||||
"(profile)" if is_profile else "",
|
||||
)
|
||||
|
||||
if global_expert_loads is None:
|
||||
# Map the physical expert load to global logical experts
|
||||
global_expert_load_windows = []
|
||||
if not execute_shuffle:
|
||||
num_models = torch.tensor(
|
||||
[len(self.model_states)], dtype=torch.int32, device="cpu"
|
||||
)
|
||||
torch.distributed.broadcast(
|
||||
num_models, group=get_ep_group().cpu_group, group_src=0
|
||||
)
|
||||
|
||||
for eplb_model_state in self.model_states.values():
|
||||
logical_expert_load_window = torch.zeros(
|
||||
self.expert_load_window_size,
|
||||
eplb_model_state.model.num_moe_layers,
|
||||
eplb_model_state.model.num_logical_experts,
|
||||
dtype=eplb_model_state.expert_load_window.dtype,
|
||||
device=eplb_model_state.expert_load_window.device,
|
||||
)
|
||||
logical_expert_load_window.scatter_add_(
|
||||
dim=-1,
|
||||
index=eplb_model_state.physical_to_logical_map.unsqueeze(0)
|
||||
.expand_as(eplb_model_state.expert_load_window)
|
||||
.long(),
|
||||
src=eplb_model_state.expert_load_window,
|
||||
)
|
||||
|
||||
if not execute_shuffle:
|
||||
metadata = torch.tensor(
|
||||
[
|
||||
eplb_model_state.model.num_moe_layers,
|
||||
eplb_model_state.model.num_logical_experts,
|
||||
eplb_model_state.physical_to_logical_map.shape[1],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
)
|
||||
torch.distributed.broadcast(
|
||||
metadata, group=get_ep_group().cpu_group, group_src=0
|
||||
)
|
||||
|
||||
global_expert_load_window = logical_expert_load_window.sum(dim=0)
|
||||
global_expert_load_windows.append(global_expert_load_window)
|
||||
# Perform all-reduce to get the expert load across all ranks for each model
|
||||
global_expert_load_windows = self._allreduce_list(
|
||||
global_expert_load_windows
|
||||
# Map the physical expert load to global logical experts
|
||||
global_expert_load_windows = []
|
||||
for eplb_model_state in self.model_states.values():
|
||||
expert_load_window = eplb_model_state.expert_load_window[
|
||||
:, :, : self.num_valid_physical_experts
|
||||
]
|
||||
logical_expert_load_window = torch.zeros(
|
||||
self.expert_load_window_size,
|
||||
eplb_model_state.model.num_moe_layers,
|
||||
eplb_model_state.model.num_logical_experts,
|
||||
dtype=eplb_model_state.expert_load_window.dtype,
|
||||
device=eplb_model_state.expert_load_window.device,
|
||||
)
|
||||
if not execute_shuffle:
|
||||
for eplb_model_state, global_expert_load_window in zip(
|
||||
self.model_states.values(), global_expert_load_windows
|
||||
):
|
||||
# (num_moe_layers, old_num_physical_experts)
|
||||
old_global_expert_indices = eplb_model_state.physical_to_logical_map
|
||||
torch.distributed.broadcast(
|
||||
old_global_expert_indices, group=ep_group, group_src=0
|
||||
)
|
||||
if not execute_shuffle:
|
||||
return global_expert_load_windows
|
||||
else:
|
||||
assert execute_shuffle
|
||||
global_expert_load_windows = global_expert_loads
|
||||
logical_expert_load_window.scatter_add_(
|
||||
dim=-1,
|
||||
index=eplb_model_state.physical_to_logical_map[
|
||||
:, : self.num_valid_physical_experts
|
||||
]
|
||||
.unsqueeze(0)
|
||||
.expand_as(expert_load_window)
|
||||
.long(),
|
||||
src=expert_load_window,
|
||||
)
|
||||
|
||||
global_expert_load_window = logical_expert_load_window.sum(dim=0)
|
||||
global_expert_load_windows.append(global_expert_load_window)
|
||||
# Perform all-reduce to get the expert load across all ranks for each model
|
||||
global_expert_load_windows = self._allreduce_list(global_expert_load_windows)
|
||||
|
||||
# TODO(bowen): Treat differently for prefill and decode nodes
|
||||
eplb_model_state = next(iter(self.model_states.values()))
|
||||
@@ -806,8 +712,10 @@ class EplbState:
|
||||
# NOTE(yongji): scale down, we need to rebalance the experts on
|
||||
# remaining GPUs, transfer the experts while we haven't shutdown
|
||||
# the GPUs to be released.
|
||||
cpu_group = get_ep_group().cpu_group
|
||||
num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping)
|
||||
coordinator = get_ep_group()
|
||||
assert isinstance(coordinator, StatelessGroupCoordinator)
|
||||
tcp_store_group = coordinator.tcp_store_group
|
||||
num_nodes = _node_count_with_rank_mapping(tcp_store_group, rank_mapping)
|
||||
num_gpus = sum(new_rank != -1 for new_rank in rank_mapping.values())
|
||||
num_replicas = (
|
||||
num_replicas // ep_group.size() * num_gpus
|
||||
@@ -933,7 +841,6 @@ class EplbState:
|
||||
if self.async_worker is None:
|
||||
self.async_worker = start_async_worker(
|
||||
self,
|
||||
rank_mapping=rank_mapping,
|
||||
is_profile=is_profile,
|
||||
)
|
||||
|
||||
@@ -1089,83 +996,6 @@ class EplbState:
|
||||
model_state.new_logical_to_physical_map = None
|
||||
model_state.new_logical_replica_count = None
|
||||
|
||||
@staticmethod
|
||||
def recv_state() -> tuple[list[torch.Tensor], list[torch.Tensor]]:
|
||||
"""
|
||||
Receive the expert load and old placement from the master rank.
|
||||
"""
|
||||
ep_group = get_ep_group()
|
||||
num_models = torch.empty(1, dtype=torch.int32, device="cpu")
|
||||
torch.distributed.broadcast(num_models, group=ep_group.cpu_group, group_src=0)
|
||||
num_models = num_models.item()
|
||||
global_expert_loads = []
|
||||
old_global_expert_indices_per_model = []
|
||||
for _ in range(num_models):
|
||||
metadata = torch.empty(3, dtype=torch.int32, device="cpu")
|
||||
torch.distributed.broadcast(metadata, group=ep_group.cpu_group, group_src=0)
|
||||
num_moe_layers, num_logical_experts, num_old_physical_experts = (
|
||||
metadata.tolist()
|
||||
)
|
||||
global_expert_load = torch.zeros(
|
||||
(num_moe_layers, num_logical_experts),
|
||||
dtype=torch.int64,
|
||||
device=ep_group.device,
|
||||
)
|
||||
all_reduce(global_expert_load, group=ep_group.device_group)
|
||||
old_global_expert_indices = torch.empty(
|
||||
(num_moe_layers, num_old_physical_experts),
|
||||
dtype=torch.int64,
|
||||
device=ep_group.device,
|
||||
)
|
||||
torch.distributed.broadcast(
|
||||
old_global_expert_indices,
|
||||
group=ep_group.device_group,
|
||||
group_src=0,
|
||||
)
|
||||
global_expert_loads.append(global_expert_load)
|
||||
old_global_expert_indices_per_model.append(old_global_expert_indices)
|
||||
return global_expert_loads, old_global_expert_indices_per_model
|
||||
|
||||
@classmethod
|
||||
def get_eep_state(
|
||||
cls, parallel_config: ParallelConfig
|
||||
) -> tuple[
|
||||
list[torch.Tensor] | None,
|
||||
list[torch.Tensor] | None,
|
||||
dict[int, int] | None,
|
||||
]:
|
||||
num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu")
|
||||
torch.distributed.broadcast(
|
||||
num_local_physical_experts,
|
||||
group=get_ep_group().cpu_group,
|
||||
group_src=0,
|
||||
)
|
||||
num_local_physical_experts = int(num_local_physical_experts.item())
|
||||
new_ep_size = get_ep_group().world_size
|
||||
global_expert_loads, old_global_expert_indices_per_model = (
|
||||
EplbState.recv_state()
|
||||
)
|
||||
|
||||
# EP configuration for all models has to be the same so as eplb config
|
||||
num_logical_experts = global_expert_loads[0].shape[1]
|
||||
parallel_config.eplb_config.num_redundant_experts = (
|
||||
num_local_physical_experts * new_ep_size - num_logical_experts
|
||||
)
|
||||
assert (
|
||||
old_global_expert_indices_per_model[0].shape[1] % num_local_physical_experts
|
||||
== 0
|
||||
)
|
||||
old_ep_size = (
|
||||
old_global_expert_indices_per_model[0].shape[1]
|
||||
// num_local_physical_experts
|
||||
)
|
||||
rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
|
||||
return (
|
||||
global_expert_loads,
|
||||
old_global_expert_indices_per_model,
|
||||
rank_mapping,
|
||||
)
|
||||
|
||||
def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]:
|
||||
"""
|
||||
All-reduce a list of tensors.
|
||||
@@ -1203,6 +1033,60 @@ class EplbState:
|
||||
load_pass_list.append(eplb_model_state.expert_load_pass.clone())
|
||||
return self._allreduce_list(load_pass_list)
|
||||
|
||||
@classmethod
|
||||
def from_mapping(
|
||||
cls,
|
||||
model: MixtureOfExperts,
|
||||
model_config: ModelConfig,
|
||||
device: torch.device,
|
||||
parallel_config: ParallelConfig,
|
||||
expanded_physical_to_logical: torch.Tensor,
|
||||
num_valid_physical_experts: int,
|
||||
) -> "EplbState":
|
||||
eplb_state = cls(
|
||||
parallel_config=parallel_config,
|
||||
device=device,
|
||||
)
|
||||
eplb_state.add_model(
|
||||
model=model,
|
||||
model_config=model_config,
|
||||
)
|
||||
eplb_state.num_valid_physical_experts = num_valid_physical_experts
|
||||
num_moe_layers = expanded_physical_to_logical.shape[0]
|
||||
num_physical_experts = expanded_physical_to_logical.shape[1]
|
||||
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
|
||||
eplb_model_state.physical_to_logical_map.copy_(expanded_physical_to_logical)
|
||||
|
||||
logical_to_physical_map = torch.full(
|
||||
(
|
||||
num_moe_layers,
|
||||
model.num_logical_experts,
|
||||
eplb_model_state.logical_to_physical_map.shape[2],
|
||||
),
|
||||
-1,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
logical_replica_count = torch.zeros(
|
||||
(num_moe_layers, model.num_logical_experts),
|
||||
dtype=torch.int64,
|
||||
)
|
||||
expanded_physical_to_logical_numpy = expanded_physical_to_logical.cpu().numpy()
|
||||
for layer_idx in range(num_moe_layers):
|
||||
for phys_idx in range(num_physical_experts):
|
||||
logical_idx = expanded_physical_to_logical_numpy[layer_idx, phys_idx]
|
||||
if logical_idx >= 0:
|
||||
replica_idx = logical_replica_count[layer_idx, logical_idx]
|
||||
logical_to_physical_map[layer_idx, logical_idx, replica_idx] = (
|
||||
phys_idx
|
||||
)
|
||||
logical_replica_count[layer_idx, logical_idx] += 1
|
||||
|
||||
logical_to_physical_map = logical_to_physical_map.to(device)
|
||||
logical_replica_count = logical_replica_count.to(device)
|
||||
eplb_model_state.logical_to_physical_map.copy_(logical_to_physical_map)
|
||||
eplb_model_state.logical_replica_count.copy_(logical_replica_count)
|
||||
return eplb_state
|
||||
|
||||
|
||||
@dataclass
|
||||
class EplbLayerState:
|
||||
|
||||
@@ -19,6 +19,8 @@ from torch.distributed import (
|
||||
get_global_rank,
|
||||
)
|
||||
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -249,10 +251,18 @@ def move_to_buffer(
|
||||
b[dst].copy_(w[src_local], non_blocking=True)
|
||||
|
||||
p2p_ops: list[P2POp] = []
|
||||
if isinstance(get_ep_group(), StatelessGroupCoordinator):
|
||||
ep_group = get_ep_group()
|
||||
is_stateless = True
|
||||
else:
|
||||
is_stateless = False
|
||||
|
||||
# Pre-compute global ranks mapping
|
||||
# Pre-compute global ranks mapping (only needed for non-stateless groups)
|
||||
ep_size = ep_group.size()
|
||||
rank_to_global = {rank: get_global_rank(ep_group, rank) for rank in range(ep_size)}
|
||||
if not is_stateless:
|
||||
rank_to_global = {
|
||||
rank: get_global_rank(ep_group, rank) for rank in range(ep_size)
|
||||
}
|
||||
|
||||
# 2. Post sends
|
||||
if send_count > 0:
|
||||
@@ -284,15 +294,23 @@ def move_to_buffer(
|
||||
if recver_pos < len(ranks_to_recv):
|
||||
recv_ranks.append(ranks_to_recv[recver_pos])
|
||||
for dst in recv_ranks:
|
||||
dst_global = rank_to_global[dst]
|
||||
p2p_ops += [
|
||||
P2POp(
|
||||
torch.distributed.isend,
|
||||
w[src],
|
||||
dst_global,
|
||||
)
|
||||
for w in expert_weights
|
||||
]
|
||||
if is_stateless:
|
||||
for w in expert_weights:
|
||||
op = object.__new__(P2POp)
|
||||
op.op = torch.distributed.isend
|
||||
op.tensor = w[src]
|
||||
op.group_peer = dst
|
||||
p2p_ops.append(op)
|
||||
else:
|
||||
dst_global = rank_to_global[dst]
|
||||
p2p_ops += [
|
||||
P2POp(
|
||||
torch.distributed.isend,
|
||||
w[src],
|
||||
dst_global,
|
||||
)
|
||||
for w in expert_weights
|
||||
]
|
||||
|
||||
# 3. Post recvs
|
||||
if recv_count > 0:
|
||||
@@ -321,26 +339,40 @@ def move_to_buffer(
|
||||
src = ranks_to_send[recver_pos // num_dst_per_sender]
|
||||
else:
|
||||
src = ranks_to_send[recver_pos - remainder_start]
|
||||
src_global = rank_to_global[src]
|
||||
p2p_ops += [
|
||||
P2POp(
|
||||
torch.distributed.irecv,
|
||||
b[dst],
|
||||
src_global,
|
||||
)
|
||||
for b in expert_weights_buffers
|
||||
]
|
||||
if is_stateless:
|
||||
for b in expert_weights_buffers:
|
||||
op = object.__new__(P2POp)
|
||||
op.op = torch.distributed.irecv
|
||||
op.tensor = b[dst]
|
||||
op.group_peer = src
|
||||
p2p_ops.append(op)
|
||||
else:
|
||||
src_global = rank_to_global[src]
|
||||
p2p_ops += [
|
||||
P2POp(
|
||||
torch.distributed.irecv,
|
||||
b[dst],
|
||||
src_global,
|
||||
)
|
||||
for b in expert_weights_buffers
|
||||
]
|
||||
|
||||
# 4. Execute the P2P operations. The real communication happens here.
|
||||
if p2p_ops and cuda_stream is not None:
|
||||
with torch.cuda.stream(cuda_stream):
|
||||
if is_stateless:
|
||||
ep_group.device_communicator.batch_isend_irecv(p2p_ops)
|
||||
else:
|
||||
reqs = batch_isend_irecv(p2p_ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
elif p2p_ops:
|
||||
if is_stateless:
|
||||
ep_group.device_communicator.batch_isend_irecv(p2p_ops)
|
||||
else:
|
||||
reqs = batch_isend_irecv(p2p_ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
elif p2p_ops:
|
||||
reqs = batch_isend_irecv(p2p_ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
# wait for the communication to finish
|
||||
return (
|
||||
is_unchanged,
|
||||
|
||||
@@ -33,7 +33,7 @@ from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from multiprocessing import shared_memory
|
||||
from typing import Any, Protocol
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
@@ -55,6 +55,9 @@ from vllm.utils.torch_utils import (
|
||||
direct_register_custom_op,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphCaptureContext:
|
||||
@@ -1157,6 +1160,55 @@ def init_model_parallel_group(
|
||||
)
|
||||
|
||||
|
||||
def _init_stateless_group(
|
||||
group_ranks: list[list[int]],
|
||||
group_name: str,
|
||||
group_ports: list[list[int]],
|
||||
host: str,
|
||||
backend: str,
|
||||
use_device_communicator: bool = True,
|
||||
) -> "StatelessGroupCoordinator":
|
||||
"""Create a StatelessGroupCoordinator with the given parameters."""
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
|
||||
world = get_world_group()
|
||||
return StatelessGroupCoordinator(
|
||||
group_ranks=group_ranks,
|
||||
local_rank=world.local_rank,
|
||||
torch_distributed_backend=backend,
|
||||
use_device_communicator=use_device_communicator,
|
||||
group_name=group_name,
|
||||
host=host,
|
||||
group_ports=group_ports,
|
||||
global_rank=world.rank,
|
||||
global_world_size=world.world_size,
|
||||
)
|
||||
|
||||
|
||||
def _replace_active_groups(
|
||||
*,
|
||||
world: GroupCoordinator | None,
|
||||
dp: GroupCoordinator | None,
|
||||
ep: GroupCoordinator | None,
|
||||
eplb: GroupCoordinator | None,
|
||||
node_count: int | None,
|
||||
) -> None:
|
||||
"""Destroy the current DP/EP/WORLD/EPLB groups and replace them.
|
||||
|
||||
Destruction is collective — all ranks in the old groups must call this
|
||||
function together. Pass all-``None`` to tear down without replacement.
|
||||
"""
|
||||
global _WORLD, _DP, _EP, _EPLB, _NODE_COUNT
|
||||
for group in (_DP, _EP, _WORLD, _EPLB):
|
||||
if group is not None:
|
||||
group.destroy()
|
||||
_WORLD = world
|
||||
_DP = dp
|
||||
_EP = ep
|
||||
_EPLB = eplb
|
||||
_NODE_COUNT = node_count
|
||||
|
||||
|
||||
_TP: GroupCoordinator | None = None
|
||||
|
||||
|
||||
@@ -1254,6 +1306,39 @@ def set_custom_all_reduce(enable: bool):
|
||||
_ENABLE_CUSTOM_ALL_REDUCE = enable
|
||||
|
||||
|
||||
def _init_elastic_ep_world(
|
||||
config, local_rank: int, backend: str, rank: int, world_size: int
|
||||
) -> None:
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
|
||||
global _WORLD, _NODE_COUNT
|
||||
assert _WORLD is None, "world group already initialized"
|
||||
parallel_config = config.parallel_config
|
||||
global_rank = parallel_config.data_parallel_rank * world_size + rank
|
||||
global_world_size = parallel_config.world_size_across_dp
|
||||
all_ranks = list(range(global_world_size))
|
||||
group_ranks = [all_ranks[i : i + 1] for i in range(global_world_size)]
|
||||
if global_rank in all_ranks:
|
||||
group_ranks = [all_ranks]
|
||||
group_ports = [parallel_config.get_next_stateless_world_group_port()]
|
||||
world = StatelessGroupCoordinator(
|
||||
group_ranks=group_ranks,
|
||||
local_rank=local_rank,
|
||||
torch_distributed_backend=backend,
|
||||
use_device_communicator=False,
|
||||
group_name="world",
|
||||
host=parallel_config.data_parallel_master_ip,
|
||||
group_ports=group_ports,
|
||||
global_rank=global_rank,
|
||||
global_world_size=global_world_size,
|
||||
)
|
||||
assert parallel_config.nnodes_within_dp == 1, (
|
||||
"Elastic EP is not supported with multi-node TP/PP"
|
||||
)
|
||||
_NODE_COUNT = _node_count(world.tcp_store_group)
|
||||
_WORLD = world
|
||||
|
||||
|
||||
def init_distributed_environment(
|
||||
world_size: int = -1,
|
||||
rank: int = -1,
|
||||
@@ -1273,6 +1358,7 @@ def init_distributed_environment(
|
||||
from vllm.config import get_current_vllm_config_or_none
|
||||
|
||||
config = get_current_vllm_config_or_none()
|
||||
enable_elastic_ep = config is not None and config.parallel_config.enable_elastic_ep
|
||||
if (
|
||||
config is not None
|
||||
and config.parallel_config.distributed_executor_backend != "external_launcher"
|
||||
@@ -1280,6 +1366,7 @@ def init_distributed_environment(
|
||||
config.parallel_config.nnodes > 1
|
||||
or config.parallel_config.data_parallel_size > 1
|
||||
)
|
||||
and not enable_elastic_ep
|
||||
):
|
||||
parallel_config = config.parallel_config
|
||||
# adjust to take into account data parallelism
|
||||
@@ -1333,6 +1420,18 @@ def init_distributed_environment(
|
||||
rank=rank,
|
||||
timeout=timeout,
|
||||
)
|
||||
if enable_elastic_ep:
|
||||
tp_pp_cpu_group = torch.distributed.new_group(
|
||||
backend="gloo", timeout=timeout
|
||||
)
|
||||
if _node_count(tp_pp_cpu_group) > 1:
|
||||
# NOTE(yongji): StatelessGroupCoordinator uses data_parallel_master_ip
|
||||
# to initialize all DP/EP groups, hence all ranks within TP/PP group
|
||||
# must reside on the same node
|
||||
raise RuntimeError(
|
||||
"Elastic EP is not yet supported with multi-node TP/PP"
|
||||
)
|
||||
|
||||
# set the local rank
|
||||
# local_rank is not available in torch ProcessGroup,
|
||||
# see https://github.com/pytorch/pytorch/issues/122816
|
||||
@@ -1341,6 +1440,9 @@ def init_distributed_environment(
|
||||
# setting, where we can use rank as local rank
|
||||
local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
|
||||
global _WORLD, _NODE_COUNT, _INNER_DP_WORLD
|
||||
if enable_elastic_ep:
|
||||
_init_elastic_ep_world(config, local_rank, backend, rank, world_size)
|
||||
return
|
||||
if _WORLD is None:
|
||||
ranks = list(range(torch.distributed.get_world_size()))
|
||||
_WORLD = init_world_group(ranks, local_rank, backend)
|
||||
@@ -1404,16 +1506,33 @@ def initialize_model_parallel(
|
||||
"""
|
||||
# Get world size and rank. Ensure some consistencies.
|
||||
assert torch.distributed.is_initialized()
|
||||
world_size: int = torch.distributed.get_world_size()
|
||||
rank = torch.distributed.get_rank()
|
||||
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
|
||||
|
||||
data_parallel_size = 1
|
||||
from vllm.config import get_current_vllm_config_or_none
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
config = get_current_vllm_config_or_none()
|
||||
if config is not None:
|
||||
data_parallel_size = config.parallel_config.data_parallel_size
|
||||
config = get_current_vllm_config()
|
||||
data_parallel_size = config.parallel_config.data_parallel_size
|
||||
enable_elastic_ep = config.parallel_config.enable_elastic_ep
|
||||
if enable_elastic_ep:
|
||||
# Use stateless world group for global information
|
||||
world_size = get_world_group().world_size
|
||||
rank = get_world_group().rank
|
||||
backend = backend or "nccl"
|
||||
tp_pp_pcp_size = (
|
||||
tensor_model_parallel_size
|
||||
* pipeline_model_parallel_size
|
||||
* prefill_context_model_parallel_size
|
||||
)
|
||||
local_all_ranks = torch.arange(tp_pp_pcp_size).reshape(
|
||||
pipeline_model_parallel_size,
|
||||
prefill_context_model_parallel_size,
|
||||
tensor_model_parallel_size,
|
||||
)
|
||||
else:
|
||||
world_size = torch.distributed.get_world_size()
|
||||
rank = torch.distributed.get_rank()
|
||||
backend = backend or torch.distributed.get_backend(
|
||||
get_world_group().device_group
|
||||
)
|
||||
|
||||
# the layout order is: ExternalDP x DP x PP x TP
|
||||
# ExternalDP is the data parallel group that is not part of the model,
|
||||
@@ -1437,7 +1556,9 @@ def initialize_model_parallel(
|
||||
assert _TP is None, "tensor model parallel group is already initialized"
|
||||
group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
|
||||
if enable_elastic_ep:
|
||||
group_ranks = local_all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
# message queue broadcaster is only used in tensor model parallel group
|
||||
_TP = init_model_parallel_group(
|
||||
group_ranks,
|
||||
@@ -1456,6 +1577,11 @@ def initialize_model_parallel(
|
||||
# TP group into tp_size//dcp_size DCP groups.
|
||||
group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
if enable_elastic_ep:
|
||||
group_ranks = local_all_ranks.reshape(
|
||||
-1, decode_context_model_parallel_size
|
||||
).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
_DCP = init_model_parallel_group(
|
||||
group_ranks,
|
||||
get_world_group().local_rank,
|
||||
@@ -1472,6 +1598,13 @@ def initialize_model_parallel(
|
||||
.unbind(0)
|
||||
)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
if enable_elastic_ep:
|
||||
group_ranks = (
|
||||
local_all_ranks.transpose(1, 2)
|
||||
.reshape(-1, prefill_context_model_parallel_size)
|
||||
.unbind(0)
|
||||
)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
_PCP = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="pcp"
|
||||
)
|
||||
@@ -1483,6 +1616,13 @@ def initialize_model_parallel(
|
||||
all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0)
|
||||
)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
if enable_elastic_ep:
|
||||
group_ranks = (
|
||||
local_all_ranks.transpose(0, 2)
|
||||
.reshape(-1, pipeline_model_parallel_size)
|
||||
.unbind(0)
|
||||
)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
_PP = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="pp"
|
||||
)
|
||||
@@ -1491,14 +1631,27 @@ def initialize_model_parallel(
|
||||
assert _DP is None, "data parallel group is already initialized"
|
||||
group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
_DP = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="dp"
|
||||
)
|
||||
if enable_elastic_ep:
|
||||
parallel_config = config.parallel_config
|
||||
dp_ports = [
|
||||
parallel_config.get_next_stateless_dp_group_port() for _ in group_ranks
|
||||
]
|
||||
_DP = _init_stateless_group(
|
||||
group_ranks,
|
||||
"dp",
|
||||
dp_ports,
|
||||
parallel_config.data_parallel_master_ip,
|
||||
backend,
|
||||
)
|
||||
else:
|
||||
_DP = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="dp"
|
||||
)
|
||||
|
||||
global _EP
|
||||
assert _EP is None, "expert parallel group is already initialized"
|
||||
# Don't create EP group for dense models.
|
||||
if config is None or config.model_config is None or config.model_config.is_moe:
|
||||
if config.model_config is None or config.model_config.is_moe:
|
||||
group_ranks = (
|
||||
all_ranks.transpose(1, 2)
|
||||
.reshape(
|
||||
@@ -1510,9 +1663,22 @@ def initialize_model_parallel(
|
||||
.unbind(0)
|
||||
)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
_EP = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="ep"
|
||||
)
|
||||
if enable_elastic_ep:
|
||||
parallel_config = config.parallel_config
|
||||
ep_ports = [
|
||||
parallel_config.get_next_stateless_ep_group_port() for _ in group_ranks
|
||||
]
|
||||
_EP = _init_stateless_group(
|
||||
group_ranks,
|
||||
"ep",
|
||||
ep_ports,
|
||||
parallel_config.data_parallel_master_ip,
|
||||
backend,
|
||||
)
|
||||
else:
|
||||
_EP = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="ep"
|
||||
)
|
||||
|
||||
# Create EPLB group with the same ranks as EP if EPLB is enabled.
|
||||
# This is a separate process group to isolate EPLB communications
|
||||
@@ -1525,10 +1691,25 @@ def initialize_model_parallel(
|
||||
and config.parallel_config is not None
|
||||
and config.parallel_config.enable_eplb
|
||||
):
|
||||
# Reuse the same group_ranks from EP
|
||||
_EPLB = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="eplb"
|
||||
)
|
||||
if enable_elastic_ep:
|
||||
eplb_ports = [
|
||||
parallel_config.get_next_stateless_eplb_group_port()
|
||||
for _ in group_ranks
|
||||
]
|
||||
_EPLB = _init_stateless_group(
|
||||
group_ranks,
|
||||
"eplb",
|
||||
eplb_ports,
|
||||
parallel_config.data_parallel_master_ip,
|
||||
backend,
|
||||
)
|
||||
else:
|
||||
_EPLB = init_model_parallel_group(
|
||||
group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
group_name="eplb",
|
||||
)
|
||||
# If no EP group needed, _EP remains None
|
||||
# If no EPLB group needed, _EPLB remains None
|
||||
|
||||
@@ -1558,7 +1739,11 @@ def ensure_model_parallel_initialized(
|
||||
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
|
||||
values if the model parallel groups are initialized.
|
||||
"""
|
||||
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
|
||||
world_group = get_world_group()
|
||||
if hasattr(world_group, "backend"):
|
||||
backend = backend or world_group.backend
|
||||
else:
|
||||
backend = backend or torch.distributed.get_backend(world_group.device_group)
|
||||
if not model_parallel_is_initialized():
|
||||
initialize_model_parallel(
|
||||
tensor_model_parallel_size,
|
||||
|
||||
@@ -0,0 +1,322 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import Backend, ProcessGroup
|
||||
|
||||
from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
|
||||
from vllm.distributed.parallel_state import (
|
||||
GroupCoordinator,
|
||||
TensorMetadata,
|
||||
_get_unique_name,
|
||||
_register_group,
|
||||
_split_tensor_dict,
|
||||
)
|
||||
from vllm.distributed.utils import (
|
||||
StatelessProcessGroup,
|
||||
stateless_destroy_torch_distributed_process_group,
|
||||
stateless_init_torch_distributed_process_group,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class StatelessGroupCoordinator(GroupCoordinator):
|
||||
"""
|
||||
A stateless version of the GroupCoordinator class in parallel_state,
|
||||
It will create CPU, device and TCPStore based communication groups
|
||||
that are independent of PyTorch's WORLD group. Hence,
|
||||
communication groups with a different set of participants GPUs
|
||||
can be created without destroying the existing ones.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
group_ranks: list[list[int]],
|
||||
local_rank: int,
|
||||
torch_distributed_backend: str | Backend,
|
||||
use_device_communicator: bool,
|
||||
use_message_queue_broadcaster: bool = False,
|
||||
group_name: str | None = None,
|
||||
host: str = "127.0.0.1",
|
||||
group_ports: list[list[int]] | None = None,
|
||||
global_rank: int = 0,
|
||||
global_world_size: int = 1,
|
||||
):
|
||||
group_name = group_name or "anonymous"
|
||||
self.unique_name = _get_unique_name(group_name)
|
||||
_register_group(self)
|
||||
|
||||
self.rank = global_rank
|
||||
self.local_rank = local_rank
|
||||
|
||||
self_device_group = None
|
||||
self_cpu_group = None
|
||||
self_tcp_store_group = None
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
backend = str(torch_distributed_backend)
|
||||
self.backend = backend
|
||||
assert group_ports is not None, "group_ports is not provided"
|
||||
for idx, ranks in enumerate(group_ranks):
|
||||
if self.rank in ranks:
|
||||
self.ranks = ranks
|
||||
self.world_size = len(ranks)
|
||||
self.rank_in_group = ranks.index(self.rank)
|
||||
|
||||
ports = group_ports[idx]
|
||||
device_port = ports[0]
|
||||
cpu_port = ports[1]
|
||||
tcp_store_port = ports[2]
|
||||
|
||||
device_group = stateless_init_torch_distributed_process_group(
|
||||
host=host,
|
||||
port=device_port,
|
||||
rank=self.rank_in_group,
|
||||
world_size=self.world_size,
|
||||
backend=backend,
|
||||
group_name=f"{self.unique_name}_device",
|
||||
)
|
||||
cpu_group = stateless_init_torch_distributed_process_group(
|
||||
host=host,
|
||||
port=cpu_port,
|
||||
rank=self.rank_in_group,
|
||||
world_size=self.world_size,
|
||||
backend="gloo",
|
||||
group_name=f"{self.unique_name}_cpu",
|
||||
)
|
||||
tcp_store_group = StatelessProcessGroup.create(
|
||||
host=host,
|
||||
port=tcp_store_port,
|
||||
rank=self.rank_in_group,
|
||||
world_size=self.world_size,
|
||||
)
|
||||
|
||||
self_device_group = device_group
|
||||
self_cpu_group = cpu_group
|
||||
self_tcp_store_group = tcp_store_group
|
||||
|
||||
assert self_cpu_group is not None
|
||||
assert self_device_group is not None
|
||||
assert self_tcp_store_group is not None
|
||||
|
||||
self.cpu_group = self_cpu_group
|
||||
self.device_group = self_device_group
|
||||
self.tcp_store_group = self_tcp_store_group
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
self.device = torch.device(f"cuda:{local_rank}")
|
||||
elif current_platform.is_xpu():
|
||||
self.device = torch.device(f"xpu:{local_rank}")
|
||||
elif current_platform.is_out_of_tree():
|
||||
self.device = torch.device(f"{current_platform.device_name}:{local_rank}")
|
||||
else:
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
self.use_device_communicator = use_device_communicator
|
||||
self.device_communicator = None
|
||||
if use_device_communicator and self.world_size > 1:
|
||||
device_comm_cls = resolve_obj_by_qualname(
|
||||
current_platform.get_device_communicator_cls()
|
||||
)
|
||||
assert device_comm_cls == CudaCommunicator
|
||||
self.device_communicator = CudaCommunicator(
|
||||
cpu_group=self.cpu_group,
|
||||
device=self.device,
|
||||
device_group=self.device_group,
|
||||
unique_name=self.unique_name,
|
||||
global_ranks=self.ranks,
|
||||
global_world_size=global_world_size,
|
||||
tcp_store_group=self.tcp_store_group,
|
||||
)
|
||||
|
||||
self.mq_broadcaster = None
|
||||
|
||||
self.use_custom_op_call = (
|
||||
current_platform.is_cuda_alike() or current_platform.is_tpu()
|
||||
)
|
||||
self.use_cpu_custom_send_recv = False
|
||||
|
||||
def destroy(self):
|
||||
if self.device_communicator:
|
||||
self.device_communicator.destroy()
|
||||
if self.device_group:
|
||||
stateless_destroy_torch_distributed_process_group(self.device_group)
|
||||
if self.cpu_group:
|
||||
stateless_destroy_torch_distributed_process_group(self.cpu_group)
|
||||
|
||||
def size(self) -> int:
|
||||
"""Return the world size of this group."""
|
||||
return self.world_size
|
||||
|
||||
def broadcast(self, input_: torch.Tensor, src: int = 0):
|
||||
if self.world_size == 1:
|
||||
return input_
|
||||
|
||||
if self.device_communicator and input_.is_cuda:
|
||||
return self.device_communicator.broadcast(input_, src)
|
||||
else:
|
||||
return self.tcp_store_group.broadcast(input_, src)
|
||||
|
||||
def broadcast_object(self, obj=None, src: int = 0):
|
||||
if self.world_size == 1:
|
||||
return obj
|
||||
return self.tcp_store_group.broadcast_obj(obj, src)
|
||||
|
||||
def broadcast_object_list(
|
||||
self, obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None
|
||||
):
|
||||
assert src < self.world_size
|
||||
|
||||
if self.world_size == 1:
|
||||
return obj_list
|
||||
|
||||
if self.rank_in_group == src:
|
||||
for obj in obj_list:
|
||||
self.tcp_store_group.broadcast_obj(obj, src)
|
||||
else:
|
||||
for i in range(len(obj_list)):
|
||||
obj_list[i] = self.tcp_store_group.broadcast_obj(None, src)
|
||||
|
||||
return obj_list
|
||||
|
||||
def broadcast_tensor_dict(
|
||||
self,
|
||||
tensor_dict: dict[str, torch.Tensor | Any] | None = None,
|
||||
src: int = 0,
|
||||
group: ProcessGroup | None = None,
|
||||
metadata_group: ProcessGroup | None = None,
|
||||
) -> dict[str, torch.Tensor | Any] | None:
|
||||
if self.world_size == 1:
|
||||
return tensor_dict
|
||||
|
||||
if self.rank_in_group == src:
|
||||
assert isinstance(tensor_dict, dict), (
|
||||
f"Expecting a dictionary, got {type(tensor_dict)}"
|
||||
)
|
||||
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
|
||||
else:
|
||||
metadata_list = None
|
||||
tensor_list = []
|
||||
|
||||
recv_metadata_list: list[tuple[str, Any]] = self.tcp_store_group.broadcast_obj(
|
||||
metadata_list, src
|
||||
)
|
||||
|
||||
if self.rank_in_group != src:
|
||||
tensor_dict = {}
|
||||
for key, value in recv_metadata_list:
|
||||
if isinstance(value, TensorMetadata):
|
||||
tensor = torch.empty(
|
||||
value.size, dtype=value.dtype, device=value.device
|
||||
)
|
||||
tensor_list.append(tensor)
|
||||
tensor_dict[key] = tensor
|
||||
else:
|
||||
tensor_dict[key] = value
|
||||
|
||||
for tensor in tensor_list:
|
||||
if tensor.numel() == 0:
|
||||
continue
|
||||
if self.device_communicator and tensor.is_cuda:
|
||||
tensor.copy_(self.device_communicator.broadcast(tensor, src))
|
||||
else:
|
||||
tensor.copy_(self.tcp_store_group.broadcast(tensor, src))
|
||||
|
||||
return tensor_dict
|
||||
|
||||
def send_object(self, obj, dst: int) -> None:
|
||||
assert dst < self.world_size
|
||||
assert dst != self.rank_in_group
|
||||
self.tcp_store_group.send_obj(obj, dst)
|
||||
|
||||
def recv_object(self, src: int):
|
||||
assert src < self.world_size
|
||||
assert src != self.rank_in_group
|
||||
return self.tcp_store_group.recv_obj(src)
|
||||
|
||||
def send_tensor_dict(
|
||||
self,
|
||||
tensor_dict: dict[str, torch.Tensor | Any],
|
||||
dst: int | None = None,
|
||||
all_gather_group: Optional["GroupCoordinator"] = None,
|
||||
all_gather_tensors: dict[str, bool] | None = None,
|
||||
) -> dict[str, torch.Tensor | Any] | None:
|
||||
if self.world_size == 1:
|
||||
return tensor_dict
|
||||
|
||||
if dst is None:
|
||||
dst = (self.rank_in_group + 1) % self.world_size
|
||||
assert dst < self.world_size
|
||||
|
||||
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
|
||||
self.tcp_store_group.send_obj(metadata_list, dst)
|
||||
|
||||
for tensor in tensor_list:
|
||||
if tensor.numel() == 0:
|
||||
continue
|
||||
if self.device_communicator and tensor.is_cuda:
|
||||
self.device_communicator.send(tensor, dst)
|
||||
else:
|
||||
self.tcp_store_group.send(tensor, dst)
|
||||
|
||||
return None
|
||||
|
||||
def recv_tensor_dict(
|
||||
self,
|
||||
src: int | None = None,
|
||||
all_gather_group: Optional["GroupCoordinator"] = None,
|
||||
all_gather_tensors: dict[str, bool] | None = None,
|
||||
) -> dict[str, torch.Tensor | Any] | None:
|
||||
if self.world_size == 1:
|
||||
return None
|
||||
|
||||
if src is None:
|
||||
src = (self.rank_in_group - 1) % self.world_size
|
||||
assert src < self.world_size
|
||||
|
||||
recv_metadata_list = self.tcp_store_group.recv_obj(src)
|
||||
tensor_dict = {}
|
||||
for key, value in recv_metadata_list:
|
||||
if isinstance(value, TensorMetadata):
|
||||
tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
|
||||
if tensor.numel() > 0:
|
||||
if self.device_communicator and tensor.is_cuda:
|
||||
tensor = self.device_communicator.recv(
|
||||
tensor.size(), tensor.dtype, src
|
||||
)
|
||||
else:
|
||||
tensor = self.tcp_store_group.recv(tensor, src)
|
||||
tensor_dict[key] = tensor
|
||||
else:
|
||||
tensor_dict[key] = value
|
||||
return tensor_dict
|
||||
|
||||
def barrier(self):
|
||||
self.tcp_store_group.barrier()
|
||||
|
||||
def gather(
|
||||
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
|
||||
) -> torch.Tensor | None:
|
||||
if self.world_size == 1:
|
||||
return input_
|
||||
|
||||
if self.device_communicator is None:
|
||||
raise ValueError("No device communicator found")
|
||||
|
||||
if self.rank_in_group == dst:
|
||||
gathered_list = [torch.empty_like(input_) for _ in range(self.world_size)]
|
||||
gathered_list[self.rank_in_group] = input_
|
||||
for src_rank in range(self.world_size):
|
||||
if src_rank != self.rank_in_group:
|
||||
gathered_list[src_rank] = self.device_communicator.recv(
|
||||
input_.size(), input_.dtype, src_rank
|
||||
)
|
||||
return torch.cat(gathered_list, dim=dim)
|
||||
else:
|
||||
self.device_communicator.send(input_, dst)
|
||||
return None
|
||||
+79
-14
@@ -18,7 +18,7 @@ from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup, TCPStore
|
||||
from torch.distributed import ProcessGroup, Store, TCPStore
|
||||
from torch.distributed.distributed_c10d import (
|
||||
Backend,
|
||||
PrefixStore,
|
||||
@@ -228,6 +228,55 @@ class StatelessProcessGroup:
|
||||
gathered_objs.append(recv_obj)
|
||||
return gathered_objs
|
||||
|
||||
def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
|
||||
"""Broadcast a tensor from source rank to all other ranks."""
|
||||
if self.rank == src:
|
||||
tensor_bytes = pickle.dumps(tensor)
|
||||
self.expire_data()
|
||||
key = f"broadcast_tensor/{src}/{self.broadcast_send_counter}"
|
||||
self.store.set(key, tensor_bytes)
|
||||
self.broadcast_send_counter += 1
|
||||
self.entries.append((key, time.time()))
|
||||
return tensor
|
||||
else:
|
||||
key = f"broadcast_tensor/{src}/{self.broadcast_recv_src_counter[src]}"
|
||||
tensor = pickle.loads(self.store.get(key))
|
||||
self.broadcast_recv_src_counter[src] += 1
|
||||
return tensor
|
||||
|
||||
def send(self, tensor: torch.Tensor, dst: int):
|
||||
"""Send a tensor to a destination rank."""
|
||||
self.expire_data()
|
||||
key = f"send_tensor/{dst}/{self.send_dst_counter[dst]}"
|
||||
self.store.set(key, pickle.dumps(tensor))
|
||||
self.send_dst_counter[dst] += 1
|
||||
self.entries.append((key, time.time()))
|
||||
|
||||
def recv(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
|
||||
"""Receive a tensor from a source rank."""
|
||||
key = f"send_tensor/{self.rank}/{self.recv_src_counter[src]}"
|
||||
received = pickle.loads(self.store.get(key))
|
||||
self.recv_src_counter[src] += 1
|
||||
tensor.copy_(received)
|
||||
return tensor
|
||||
|
||||
def all_reduce(
|
||||
self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM
|
||||
) -> torch.Tensor:
|
||||
"""All-reduce a tensor across all ranks."""
|
||||
tensors = self.all_gather_obj(tensor)
|
||||
result = tensors[0].clone()
|
||||
for t in tensors[1:]:
|
||||
if op == torch.distributed.ReduceOp.SUM:
|
||||
result.add_(t)
|
||||
elif op == torch.distributed.ReduceOp.PRODUCT:
|
||||
result.mul_(t)
|
||||
elif op == torch.distributed.ReduceOp.MAX:
|
||||
result = torch.maximum(result, t)
|
||||
elif op == torch.distributed.ReduceOp.MIN:
|
||||
result = torch.minimum(result, t)
|
||||
return result
|
||||
|
||||
def barrier(self, timeout: float = 30.0):
|
||||
"""A robust barrier to synchronize all ranks.
|
||||
|
||||
@@ -448,8 +497,14 @@ def init_gloo_process_group(
|
||||
|
||||
|
||||
def stateless_init_torch_distributed_process_group(
|
||||
host: str, port: int, rank: int, world_size: int, backend: str
|
||||
) -> ProcessGroup:
|
||||
host: str,
|
||||
port: int,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
backend: str,
|
||||
group_name: str | None = None,
|
||||
return_store: bool = False,
|
||||
) -> ProcessGroup | tuple[ProcessGroup, Store]:
|
||||
"""
|
||||
A replacement for `torch.distributed.init_process_group` that does not
|
||||
pollute the global state. The created ProcessGroup object can be used for
|
||||
@@ -496,25 +551,35 @@ def stateless_init_torch_distributed_process_group(
|
||||
# Use a PrefixStore to avoid accidental overrides of keys used by
|
||||
# different systems (e.g. RPC) in case the store is multi-tenant.
|
||||
prefix_store = PrefixStore(init_method, store)
|
||||
try:
|
||||
|
||||
if backend == "gloo":
|
||||
pg = init_gloo_process_group(
|
||||
prefix_store=prefix_store,
|
||||
group_rank=group_rank,
|
||||
group_size=group_size,
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
return current_platform.stateless_init_device_torch_dist_pg(
|
||||
pg = current_platform.stateless_init_device_torch_dist_pg(
|
||||
backend=backend,
|
||||
prefix_store=prefix_store,
|
||||
group_rank=group_rank,
|
||||
group_size=group_size,
|
||||
timeout=timeout,
|
||||
)
|
||||
except NotImplementedError:
|
||||
# If platform doesn't implement stateless_init_device_torch_dist_pg, it
|
||||
# will raise a NotImplementedError. In this case, we fall back to gloo.
|
||||
return init_gloo_process_group(
|
||||
prefix_store=prefix_store,
|
||||
group_rank=group_rank,
|
||||
group_size=group_size,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
if group_name is not None:
|
||||
from torch._C._distributed_c10d import _register_process_group
|
||||
|
||||
pg._set_group_name(group_name)
|
||||
_register_process_group(group_name, pg)
|
||||
|
||||
if return_store:
|
||||
return pg, store
|
||||
else:
|
||||
return pg
|
||||
|
||||
|
||||
def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:
|
||||
|
||||
@@ -419,6 +419,7 @@ class EngineArgs:
|
||||
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
|
||||
moe_backend: MoEBackend = KernelConfig.moe_backend
|
||||
all2all_backend: All2AllBackend = ParallelConfig.all2all_backend
|
||||
enable_elastic_ep: bool = ParallelConfig.enable_elastic_ep
|
||||
enable_dbo: bool = ParallelConfig.enable_dbo
|
||||
ubatch_size: int = ParallelConfig.ubatch_size
|
||||
dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
|
||||
@@ -896,6 +897,9 @@ class EngineArgs:
|
||||
"--ubatch-size",
|
||||
**parallel_kwargs["ubatch_size"],
|
||||
)
|
||||
parallel_group.add_argument(
|
||||
"--enable-elastic-ep", **parallel_kwargs["enable_elastic_ep"]
|
||||
)
|
||||
parallel_group.add_argument(
|
||||
"--dbo-decode-token-threshold",
|
||||
**parallel_kwargs["dbo_decode_token_threshold"],
|
||||
@@ -1698,6 +1702,7 @@ class EngineArgs:
|
||||
is_moe_model=model_config.is_moe,
|
||||
enable_expert_parallel=self.enable_expert_parallel,
|
||||
all2all_backend=self.all2all_backend,
|
||||
enable_elastic_ep=self.enable_elastic_ep,
|
||||
enable_dbo=self.enable_dbo,
|
||||
ubatch_size=self.ubatch_size,
|
||||
dbo_decode_token_threshold=self.dbo_decode_token_threshold,
|
||||
|
||||
@@ -246,8 +246,12 @@ def run_multi_api_server(args: argparse.Namespace):
|
||||
|
||||
api_server_manager: APIServerProcessManager | None = None
|
||||
|
||||
from vllm.v1.engine.utils import get_engine_zmq_addresses
|
||||
|
||||
addresses = get_engine_zmq_addresses(vllm_config, num_api_servers)
|
||||
|
||||
with launch_core_engines(
|
||||
vllm_config, executor_class, log_stats, num_api_servers
|
||||
vllm_config, executor_class, log_stats, addresses, num_api_servers
|
||||
) as (local_engine_manager, coordinator, addresses):
|
||||
# Construct common args for the APIServerProcessManager up-front.
|
||||
api_server_manager_kwargs = dict(
|
||||
|
||||
@@ -243,6 +243,8 @@ if TYPE_CHECKING:
|
||||
VLLM_LORA_DISABLE_PDL: bool = False
|
||||
VLLM_ENABLE_CUDA_COMPATIBILITY: bool = False
|
||||
VLLM_CUDA_COMPATIBILITY_PATH: str | None = None
|
||||
VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: bool = False
|
||||
VLLM_ELASTIC_EP_DRAIN_REQUESTS: bool = False
|
||||
|
||||
|
||||
def get_default_cache_root():
|
||||
@@ -1617,6 +1619,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_CUDA_COMPATIBILITY_PATH": lambda: os.environ.get(
|
||||
"VLLM_CUDA_COMPATIBILITY_PATH", None
|
||||
),
|
||||
# Whether it is a scale up launch engine for elastic EP,
|
||||
# Should only be set by EngineCoreClient.
|
||||
"VLLM_ELASTIC_EP_SCALE_UP_LAUNCH": lambda: bool(
|
||||
int(os.getenv("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH", "0"))
|
||||
),
|
||||
# Whether to wait for all requests to drain before sending the
|
||||
# scaling command in elastic EP.
|
||||
"VLLM_ELASTIC_EP_DRAIN_REQUESTS": lambda: bool(
|
||||
int(os.getenv("VLLM_ELASTIC_EP_DRAIN_REQUESTS", "0"))
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -627,6 +627,7 @@ class FusedMoE(CustomOp):
|
||||
moe_quant_params["intermediate_size_full"] = intermediate_size
|
||||
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
self.base_quant_method = self.quant_method
|
||||
|
||||
# Disable shared expert overlap if:
|
||||
# - we are using eplb with non-default backend, because of correctness issues
|
||||
@@ -683,7 +684,7 @@ class FusedMoE(CustomOp):
|
||||
# routing_tables only needed for round-robin expert placement with
|
||||
# DeepEP all2all backend.
|
||||
routing_tables = self._maybe_init_expert_routing_tables()
|
||||
prepare_finalize = self.quant_method.maybe_make_prepare_finalize(
|
||||
prepare_finalize = self.base_quant_method.maybe_make_prepare_finalize(
|
||||
routing_tables=routing_tables
|
||||
)
|
||||
if prepare_finalize is not None:
|
||||
@@ -693,7 +694,7 @@ class FusedMoE(CustomOp):
|
||||
self._replace_quant_method(
|
||||
FusedMoEModularMethod.make(
|
||||
self,
|
||||
self.quant_method,
|
||||
self.base_quant_method,
|
||||
prepare_finalize,
|
||||
self.shared_experts,
|
||||
inplace=not self.moe_config.disable_inplace,
|
||||
|
||||
@@ -6,10 +6,13 @@ pynvml. However, it should not initialize cuda context.
|
||||
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from functools import cache, wraps
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
import torch
|
||||
from torch.distributed import PrefixStore, ProcessGroup
|
||||
from torch.distributed.distributed_c10d import is_nccl_available
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
# import custom ops, trigger op registration
|
||||
@@ -482,6 +485,37 @@ class CudaPlatformBase(Platform):
|
||||
def get_static_graph_wrapper_cls(cls) -> str:
|
||||
return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
|
||||
|
||||
@classmethod
|
||||
def stateless_init_device_torch_dist_pg(
|
||||
cls,
|
||||
backend: str,
|
||||
prefix_store: PrefixStore,
|
||||
group_rank: int,
|
||||
group_size: int,
|
||||
timeout: timedelta,
|
||||
) -> ProcessGroup:
|
||||
assert is_nccl_available()
|
||||
pg: ProcessGroup = ProcessGroup(
|
||||
prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
)
|
||||
from torch.distributed.distributed_c10d import ProcessGroupNCCL
|
||||
|
||||
backend_options = ProcessGroupNCCL.Options()
|
||||
backend_options._timeout = timeout
|
||||
|
||||
backend_class = ProcessGroupNCCL(
|
||||
prefix_store, group_rank, group_size, backend_options
|
||||
)
|
||||
backend_type = ProcessGroup.BackendType.NCCL
|
||||
device = torch.device("cuda")
|
||||
pg._set_default_backend(backend_type)
|
||||
backend_class._set_sequence_number_for_group()
|
||||
|
||||
pg._register_backend(device, backend_type, backend_class)
|
||||
return pg
|
||||
|
||||
@classmethod
|
||||
def device_count(cls) -> int:
|
||||
return cuda_device_count_stateless()
|
||||
|
||||
@@ -2,10 +2,13 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from functools import cache, lru_cache, wraps
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch.distributed import PrefixStore, ProcessGroup
|
||||
from torch.distributed.distributed_c10d import is_nccl_available
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
@@ -656,6 +659,37 @@ class RocmPlatform(Platform):
|
||||
def get_static_graph_wrapper_cls(cls) -> str:
|
||||
return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
|
||||
|
||||
@classmethod
|
||||
def stateless_init_device_torch_dist_pg(
|
||||
cls,
|
||||
backend: str,
|
||||
prefix_store: PrefixStore,
|
||||
group_rank: int,
|
||||
group_size: int,
|
||||
timeout: timedelta,
|
||||
) -> ProcessGroup:
|
||||
assert is_nccl_available()
|
||||
pg: ProcessGroup = ProcessGroup(
|
||||
prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
)
|
||||
from torch.distributed.distributed_c10d import ProcessGroupNCCL
|
||||
|
||||
backend_options = ProcessGroupNCCL.Options()
|
||||
backend_options._timeout = timeout
|
||||
|
||||
backend_class = ProcessGroupNCCL(
|
||||
prefix_store, group_rank, group_size, backend_options
|
||||
)
|
||||
backend_type = ProcessGroup.BackendType.NCCL
|
||||
device = torch.device("cuda")
|
||||
pg._set_default_backend(backend_type)
|
||||
backend_class._set_sequence_number_for_group()
|
||||
|
||||
pg._register_backend(device, backend_type, backend_class)
|
||||
return pg
|
||||
|
||||
@classmethod
|
||||
def device_count(cls) -> int:
|
||||
return cuda_device_count_stateless()
|
||||
|
||||
@@ -29,6 +29,15 @@ PauseMode = Literal["abort", "wait", "keep"]
|
||||
# so form part of the external API.
|
||||
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error")
|
||||
|
||||
EEP_NOTIFICATION_CALL_ID = -1
|
||||
|
||||
|
||||
class EEPNotificationType(enum.Enum):
|
||||
NEW_CORE_ENGINES_INIT_READY = "NEW_CORE_ENGINES_INIT_READY"
|
||||
NEW_CORE_ENGINES_WEIGHTS_INIT_READY = "NEW_CORE_ENGINES_WEIGHTS_INIT_READY"
|
||||
RECONFIGURE_FINISHED = "RECONFIGURE_FINISHED"
|
||||
SHUTDOWN_COMPLETE = "SHUTDOWN_COMPLETE"
|
||||
|
||||
|
||||
class FinishReason(enum.IntEnum):
|
||||
"""
|
||||
@@ -235,6 +244,11 @@ class ReconfigureDistributedRequest(msgspec.Struct):
|
||||
new_data_parallel_rank_local: int
|
||||
new_data_parallel_master_ip: str
|
||||
new_data_parallel_master_port: int
|
||||
new_data_parallel_master_port_list: list[int]
|
||||
new_stateless_world_group_port_list: list[list[int]]
|
||||
new_stateless_dp_group_port_list: list[list[int]]
|
||||
new_stateless_ep_group_port_list: list[list[int]]
|
||||
new_stateless_eplb_group_port_list: list[list[int]]
|
||||
|
||||
|
||||
class ReconfigureRankType(enum.IntEnum):
|
||||
|
||||
+27
-14
@@ -20,6 +20,7 @@ from vllm.distributed.weight_transfer.base import (
|
||||
)
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.protocol import EngineClient, StreamingInput
|
||||
from vllm.entrypoints.serve.elastic_ep.middleware import set_scaling_elastic_ep
|
||||
from vllm.inputs import ProcessorInputs, PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@@ -647,7 +648,11 @@ class AsyncLLM(EngineClient):
|
||||
engine_core = self.engine_core
|
||||
output_processor = self.output_processor
|
||||
log_stats = self.log_stats
|
||||
logger_manager = self.logger_manager
|
||||
# We use a mutable list for logger_manager so that it can be updated
|
||||
# during elastic EP scaling (see scale_elastic_ep) without creating
|
||||
# a circular reference via self.
|
||||
self._logger_ref = [self.logger_manager]
|
||||
logger_ref = self._logger_ref
|
||||
renderer = self.renderer
|
||||
chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
|
||||
|
||||
@@ -691,8 +696,8 @@ class AsyncLLM(EngineClient):
|
||||
# 4) Logging.
|
||||
# TODO(rob): make into a coroutine and launch it in
|
||||
# background thread once Prometheus overhead is non-trivial.
|
||||
if logger_manager:
|
||||
logger_manager.record(
|
||||
if logger_ref[0]:
|
||||
logger_ref[0].record(
|
||||
engine_idx=outputs.engine_index,
|
||||
scheduler_stats=outputs.scheduler_stats,
|
||||
iteration_stats=iteration_stats,
|
||||
@@ -976,17 +981,13 @@ class AsyncLLM(EngineClient):
|
||||
new_data_parallel_size,
|
||||
)
|
||||
return
|
||||
logger.info(
|
||||
"Waiting for requests to drain before scaling up to %s engines...",
|
||||
new_data_parallel_size,
|
||||
)
|
||||
await self.wait_for_requests_to_drain(drain_timeout)
|
||||
logger.info(
|
||||
"Requests have been drained, proceeding with scale to %s engines",
|
||||
new_data_parallel_size,
|
||||
)
|
||||
await self.engine_core.scale_elastic_ep(new_data_parallel_size)
|
||||
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
|
||||
|
||||
if envs.VLLM_ELASTIC_EP_DRAIN_REQUESTS:
|
||||
logger.info(
|
||||
"VLLM_ELASTIC_EP_DRAIN_REQUESTS is set, "
|
||||
"waiting for requests to drain before scaling"
|
||||
)
|
||||
await self.wait_for_requests_to_drain(drain_timeout)
|
||||
|
||||
# recreate stat loggers
|
||||
if new_data_parallel_size > old_data_parallel_size and self.log_stats:
|
||||
@@ -999,6 +1000,18 @@ class AsyncLLM(EngineClient):
|
||||
engine_idxs=list(range(new_data_parallel_size)),
|
||||
custom_stat_loggers=None,
|
||||
)
|
||||
# Update the mutable ref so output_handler picks up the
|
||||
# new logger without creating a circular reference via self.
|
||||
if hasattr(self, "_logger_ref"):
|
||||
self._logger_ref[0] = self.logger_manager
|
||||
self.logger_manager.log_engine_initialized()
|
||||
|
||||
set_scaling_elastic_ep(True)
|
||||
try:
|
||||
await self.engine_core.scale_elastic_ep(new_data_parallel_size)
|
||||
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
|
||||
finally:
|
||||
set_scaling_elastic_ep(False)
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
|
||||
@@ -71,6 +71,9 @@ class DPCoordinator:
|
||||
)
|
||||
|
||||
local_only_eng = dp_size == parallel_config.data_parallel_size_local
|
||||
# NOTE(yongji): handling scaling from intra-node to inter-node
|
||||
if parallel_config.enable_elastic_ep:
|
||||
local_only_eng = False
|
||||
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
|
||||
back_output_address = get_engine_client_zmq_addr(local_only_eng, host)
|
||||
|
||||
@@ -201,6 +204,7 @@ class DPCoordinatorProc:
|
||||
|
||||
poller = zmq.Poller()
|
||||
poller.register(publish_front, zmq.POLLIN)
|
||||
poller.register(publish_back, zmq.POLLIN)
|
||||
poller.register(output_back, zmq.POLLIN)
|
||||
last_publish_time = 0
|
||||
while True:
|
||||
@@ -231,6 +235,22 @@ class DPCoordinatorProc:
|
||||
events = dict(events)
|
||||
wave_state_changed = False
|
||||
|
||||
if publish_back in events:
|
||||
buffer = publish_back.recv()
|
||||
if buffer == b"\x01":
|
||||
# NOTE(yongji): newly started engine subscribed
|
||||
# We need to send READY message here instead of receiving
|
||||
# SCALE_ELASTIC_EP notification from engine core client
|
||||
# as SCALE_ELASTIC_EP is only sent when
|
||||
# new engines finished initialization.
|
||||
# Subscription message, on the other hand, is sent
|
||||
# by each engine during initialization
|
||||
publish_back.send(b"READY")
|
||||
else:
|
||||
logger.error(
|
||||
"DP Coordinator receives unexpected message from engines"
|
||||
)
|
||||
|
||||
if publish_front in events:
|
||||
buffer = publish_front.recv()
|
||||
if buffer in (b"\x01", b"\x00"):
|
||||
@@ -259,7 +279,6 @@ class DPCoordinatorProc:
|
||||
# current_wave
|
||||
# we note that 0 is the wave number for the new
|
||||
# engine
|
||||
engines_running = False
|
||||
logger.info(
|
||||
"DPCoordinator scaled up from %s to %s engines",
|
||||
current_count,
|
||||
|
||||
+171
-52
@@ -17,6 +17,7 @@ from typing import Any, TypeVar, cast
|
||||
import msgspec
|
||||
import zmq
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.distributed import stateless_destroy_torch_distributed_process_group
|
||||
from vllm.envs import enable_envs_cache
|
||||
@@ -44,6 +45,8 @@ from vllm.v1.core.kv_cache_utils import (
|
||||
from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.engine import (
|
||||
EEP_NOTIFICATION_CALL_ID,
|
||||
EEPNotificationType,
|
||||
EngineCoreOutput,
|
||||
EngineCoreOutputs,
|
||||
EngineCoreRequest,
|
||||
@@ -110,6 +113,9 @@ class EngineCore:
|
||||
|
||||
self.available_gpu_memory_for_kv_cache = -1
|
||||
|
||||
if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
|
||||
self._eep_scale_up_before_kv_init()
|
||||
|
||||
# Setup KV Caches and update CacheConfig after profiling.
|
||||
num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
|
||||
vllm_config
|
||||
@@ -233,12 +239,10 @@ class EngineCore:
|
||||
|
||||
has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
|
||||
if has_kv_cache:
|
||||
if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
|
||||
dp_group = getattr(self, "dp_group", None)
|
||||
assert dp_group is not None
|
||||
self.available_gpu_memory_for_kv_cache = (
|
||||
ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
|
||||
)
|
||||
if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
|
||||
# NOTE(yongji): should already be set
|
||||
# during _eep_scale_up_before_kv_init
|
||||
assert self.available_gpu_memory_for_kv_cache > 0
|
||||
available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
|
||||
kv_cache_specs
|
||||
)
|
||||
@@ -752,11 +756,22 @@ class EngineCore:
|
||||
self.structured_output_manager.grammar_init(req)
|
||||
return req, request.current_wave
|
||||
|
||||
def _eep_scale_up_before_kv_init(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def _eep_send_engine_core_notification(
|
||||
self,
|
||||
notification_type: EEPNotificationType,
|
||||
vllm_config: VllmConfig | None = None,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class EngineCoreProc(EngineCore):
|
||||
"""ZMQ-wrapper for running EngineCore in background process."""
|
||||
|
||||
ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"
|
||||
addresses: EngineZmqAddresses
|
||||
|
||||
@instrument(span_name="EngineCoreProc init")
|
||||
def __init__(
|
||||
@@ -807,6 +822,13 @@ class EngineCoreProc(EngineCore):
|
||||
# and "hybrid" LB modes.
|
||||
self.publish_dp_lb_stats = internal_dp_balancing
|
||||
|
||||
self.addresses = addresses
|
||||
self.process_input_queue_block = True
|
||||
if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
|
||||
self._eep_send_engine_core_notification(
|
||||
EEPNotificationType.NEW_CORE_ENGINES_INIT_READY,
|
||||
vllm_config=vllm_config,
|
||||
)
|
||||
self._init_data_parallel(vllm_config)
|
||||
|
||||
super().__init__(
|
||||
@@ -1119,8 +1141,14 @@ class EngineCoreProc(EngineCore):
|
||||
if logger.isEnabledFor(DEBUG):
|
||||
logger.debug("EngineCore waiting for work.")
|
||||
waited = True
|
||||
req = self.input_queue.get()
|
||||
self._handle_client_request(*req)
|
||||
block = self.process_input_queue_block
|
||||
try:
|
||||
req = self.input_queue.get(block=block)
|
||||
self._handle_client_request(*req)
|
||||
except queue.Empty:
|
||||
break
|
||||
if not block:
|
||||
break
|
||||
|
||||
if waited:
|
||||
logger.debug("EngineCore loop active.")
|
||||
@@ -1290,6 +1318,11 @@ class EngineCoreProc(EngineCore):
|
||||
for input_socket, _ in poller.poll():
|
||||
# (RequestType, RequestData)
|
||||
type_frame, *data_frames = input_socket.recv_multipart(copy=False)
|
||||
# NOTE(yongji): ignore READY message sent by DP coordinator
|
||||
# that is used to notify newly started engines
|
||||
if type_frame.buffer == b"READY":
|
||||
assert input_socket == coord_socket
|
||||
continue
|
||||
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
|
||||
|
||||
# Deserialize the request data.
|
||||
@@ -1488,6 +1521,10 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
self.current_wave = 0
|
||||
self.last_counts = (0, 0)
|
||||
|
||||
from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState
|
||||
|
||||
self.eep_scaling_state: ElasticEPScalingState | None = None
|
||||
|
||||
# Initialize the engine.
|
||||
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
super().__init__(
|
||||
@@ -1511,7 +1548,9 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
assert 0 <= local_dp_rank <= dp_rank < dp_size
|
||||
|
||||
self.dp_rank = dp_rank
|
||||
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
|
||||
self.dp_group, self.dp_store = (
|
||||
vllm_config.parallel_config.stateless_init_dp_group(return_store=True)
|
||||
)
|
||||
|
||||
def shutdown(self):
|
||||
super().shutdown()
|
||||
@@ -1574,7 +1613,12 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
# 1) Poll the input queue until there is work to do.
|
||||
self._process_input_queue()
|
||||
|
||||
# 2) Step the engine core.
|
||||
if self.eep_scaling_state is not None:
|
||||
_ = self.eep_scaling_state.progress()
|
||||
if self.eep_scaling_state.is_complete():
|
||||
self.process_input_queue_block = True
|
||||
self.eep_scaling_state = None
|
||||
|
||||
executed = self._process_engine_step()
|
||||
self._maybe_publish_request_counts()
|
||||
|
||||
@@ -1624,54 +1668,129 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
def reinitialize_distributed(
|
||||
self, reconfig_request: ReconfigureDistributedRequest
|
||||
) -> None:
|
||||
stateless_destroy_torch_distributed_process_group(self.dp_group)
|
||||
self.shutdown()
|
||||
from copy import deepcopy
|
||||
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
old_dp_size = parallel_config.data_parallel_size
|
||||
parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
|
||||
if reconfig_request.new_data_parallel_rank != -1:
|
||||
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
|
||||
# local rank specifies device visibility, it should not be changed
|
||||
assert (
|
||||
reconfig_request.new_data_parallel_rank_local
|
||||
== ReconfigureRankType.KEEP_CURRENT_RANK
|
||||
)
|
||||
parallel_config.data_parallel_master_ip = (
|
||||
reconfig_request.new_data_parallel_master_ip
|
||||
)
|
||||
parallel_config.data_parallel_master_port = (
|
||||
reconfig_request.new_data_parallel_master_port
|
||||
)
|
||||
if reconfig_request.new_data_parallel_rank != -2:
|
||||
self.dp_rank = parallel_config.data_parallel_rank
|
||||
self.dp_group = parallel_config.stateless_init_dp_group()
|
||||
reconfig_request.new_data_parallel_master_port = (
|
||||
parallel_config.data_parallel_master_port
|
||||
)
|
||||
from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState
|
||||
|
||||
self.model_executor.reinitialize_distributed(reconfig_request)
|
||||
if reconfig_request.new_data_parallel_size > old_dp_size:
|
||||
assert self.available_gpu_memory_for_kv_cache > 0
|
||||
# pass available_gpu_memory_for_kv_cache from existing
|
||||
# engine-cores to new engine-cores so they can directly
|
||||
# use it in _initialize_kv_caches() rather than profiling.
|
||||
ParallelConfig.sync_kv_cache_memory_size(
|
||||
self.dp_group, self.available_gpu_memory_for_kv_cache
|
||||
)
|
||||
# NOTE(yongji): newly joined workers require dummy_run even
|
||||
# CUDA graph is not used
|
||||
self.model_executor.collective_rpc("compile_or_warm_up_model")
|
||||
new_parallel_config = deepcopy(self.vllm_config.parallel_config)
|
||||
old_dp_size = new_parallel_config.data_parallel_size
|
||||
new_parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
|
||||
if (
|
||||
reconfig_request.new_data_parallel_rank
|
||||
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
|
||||
!= ReconfigureRankType.KEEP_CURRENT_RANK
|
||||
):
|
||||
self.shutdown()
|
||||
logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
|
||||
else:
|
||||
logger.info(
|
||||
"Distributed environment reinitialized for DP rank %s", self.dp_rank
|
||||
new_parallel_config.data_parallel_rank = (
|
||||
reconfig_request.new_data_parallel_rank
|
||||
)
|
||||
new_parallel_config.data_parallel_master_ip = (
|
||||
reconfig_request.new_data_parallel_master_ip
|
||||
)
|
||||
new_parallel_config.data_parallel_master_port = (
|
||||
reconfig_request.new_data_parallel_master_port
|
||||
)
|
||||
new_parallel_config._data_parallel_master_port_list = (
|
||||
reconfig_request.new_data_parallel_master_port_list
|
||||
)
|
||||
|
||||
is_scale_down = reconfig_request.new_data_parallel_size < old_dp_size
|
||||
is_shutdown = (
|
||||
reconfig_request.new_data_parallel_rank
|
||||
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
|
||||
)
|
||||
|
||||
self.eep_scaling_state = ElasticEPScalingState(
|
||||
model_executor=self.model_executor,
|
||||
engine_core=self,
|
||||
vllm_config=self.vllm_config,
|
||||
new_parallel_config=new_parallel_config,
|
||||
worker_type="removing" if is_shutdown else "existing",
|
||||
scale_type="scale_down" if is_scale_down else "scale_up",
|
||||
reconfig_request=reconfig_request,
|
||||
)
|
||||
self.process_input_queue_block = False
|
||||
logger.info(
|
||||
"[Elastic EP] Received reconfiguration request and starting scaling up/down"
|
||||
)
|
||||
|
||||
def _eep_send_engine_core_notification(
|
||||
self,
|
||||
notification_type: EEPNotificationType,
|
||||
vllm_config: VllmConfig | None = None,
|
||||
):
|
||||
"""
|
||||
Send notifications to EngineCoreClient, which can then forward
|
||||
the notifications to other engine core processes. It is used for:
|
||||
1) In scale up: new core engines to notify exisiting core engines
|
||||
that they are ready;
|
||||
2) In scale down: removing core engines to notify EngineCoreClient
|
||||
so EngineCoreClient can release their ray placement groups;
|
||||
3) Both scale up/down: to notify EngineCoreClient that exisiting
|
||||
core engines have already switched to the new parallel setup.
|
||||
"""
|
||||
if vllm_config is None:
|
||||
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
|
||||
else:
|
||||
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
notification_data = (notification_type.value, dp_rank)
|
||||
outputs = EngineCoreOutputs(
|
||||
utility_output=UtilityOutput(
|
||||
call_id=EEP_NOTIFICATION_CALL_ID,
|
||||
result=UtilityResult(notification_data),
|
||||
)
|
||||
)
|
||||
outputs.engine_index = self.engine_index
|
||||
|
||||
if hasattr(self, "output_thread") and self.output_thread.is_alive():
|
||||
self.output_queue.put_nowait((0, outputs))
|
||||
else:
|
||||
encoder = MsgpackEncoder()
|
||||
with (
|
||||
zmq.Context() as ctx,
|
||||
make_zmq_socket(
|
||||
ctx, self.addresses.outputs[0], zmq.PUSH, linger=4000
|
||||
) as socket,
|
||||
):
|
||||
socket.send_multipart(encoder.encode(outputs))
|
||||
|
||||
def eep_handle_engine_core_notification(
|
||||
self, notification_type: str | EEPNotificationType
|
||||
):
|
||||
"""
|
||||
Handle notification received from EngineCoreClient
|
||||
(forwarded from new core engines).
|
||||
"""
|
||||
assert self.eep_scaling_state is not None
|
||||
if isinstance(notification_type, str):
|
||||
notification_type = EEPNotificationType(notification_type)
|
||||
self.eep_scaling_state.handle_notification(notification_type)
|
||||
|
||||
def _eep_scale_up_before_kv_init(self):
|
||||
from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState
|
||||
|
||||
self.eep_scaling_state = ElasticEPScalingState(
|
||||
model_executor=self.model_executor,
|
||||
engine_core=self,
|
||||
vllm_config=self.vllm_config,
|
||||
new_parallel_config=self.vllm_config.parallel_config,
|
||||
worker_type="new",
|
||||
scale_type="scale_up",
|
||||
reconfig_request=None,
|
||||
)
|
||||
self.model_executor.collective_rpc("init_device")
|
||||
self.model_executor.collective_rpc("load_model")
|
||||
self._eep_send_engine_core_notification(
|
||||
EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY
|
||||
)
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute", args=("receive_weights",)
|
||||
)
|
||||
self.available_gpu_memory_for_kv_cache = (
|
||||
ParallelConfig.sync_kv_cache_memory_size(self.dp_group, -1)
|
||||
)
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute", args=("prepare_new_worker",)
|
||||
)
|
||||
self.process_input_queue_block = False
|
||||
|
||||
|
||||
class EngineCoreActorMixin:
|
||||
|
||||
+268
-51
@@ -28,11 +28,12 @@ from vllm.tracing import instrument
|
||||
from vllm.utils.async_utils import in_loop
|
||||
from vllm.utils.network_utils import (
|
||||
close_sockets,
|
||||
get_open_port,
|
||||
get_open_zmq_inproc_path,
|
||||
make_zmq_socket,
|
||||
)
|
||||
from vllm.v1.engine import (
|
||||
EEP_NOTIFICATION_CALL_ID,
|
||||
EEPNotificationType,
|
||||
EngineCoreOutputs,
|
||||
EngineCoreRequest,
|
||||
EngineCoreRequestType,
|
||||
@@ -47,6 +48,7 @@ from vllm.v1.engine.exceptions import EngineDeadError
|
||||
from vllm.v1.engine.utils import (
|
||||
CoreEngineActorManager,
|
||||
CoreEngineProcManager,
|
||||
get_engine_zmq_addresses,
|
||||
launch_core_engines,
|
||||
)
|
||||
from vllm.v1.executor import Executor
|
||||
@@ -445,6 +447,63 @@ class BackgroundResources:
|
||||
raise EngineDeadError()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ElasticScalingCache:
|
||||
existing_core_engines: list[EngineIdentity]
|
||||
num_new_core_engines: int
|
||||
pending_notifications: dict[EEPNotificationType, set[int]]
|
||||
|
||||
|
||||
def allocate_stateless_group_ports(parallel_config, new_data_parallel_size: int):
|
||||
"""
|
||||
Allocate stateless group ports for elastic EP.
|
||||
"""
|
||||
from vllm.utils.network_utils import get_open_ports_list
|
||||
|
||||
assert parallel_config.enable_elastic_ep, "Elastic EP must be enabled"
|
||||
world_size = parallel_config.world_size
|
||||
new_world_size_across_dp = world_size * new_data_parallel_size
|
||||
num_world_groups = 1
|
||||
num_dp_groups = max(1, new_world_size_across_dp // new_data_parallel_size)
|
||||
num_ep_groups = max(
|
||||
1,
|
||||
new_world_size_across_dp
|
||||
// (new_data_parallel_size * parallel_config.tensor_parallel_size),
|
||||
)
|
||||
num_eplb_groups = num_ep_groups
|
||||
total_ports_needed = (
|
||||
num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups
|
||||
) * 3 + 5
|
||||
all_ports = get_open_ports_list(total_ports_needed)
|
||||
new_data_parallel_master_port_list = all_ports[-5:]
|
||||
all_ports = all_ports[:-5]
|
||||
new_stateless_world_group_port_list = [
|
||||
all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)
|
||||
]
|
||||
start_idx = num_world_groups * 3
|
||||
new_stateless_dp_group_port_list = [
|
||||
all_ports[i : i + 3] for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
|
||||
]
|
||||
start_idx += num_dp_groups * 3
|
||||
new_stateless_ep_group_port_list = [
|
||||
all_ports[i : i + 3] for i in range(start_idx, start_idx + num_ep_groups * 3, 3)
|
||||
]
|
||||
start_idx += num_ep_groups * 3
|
||||
new_stateless_eplb_group_port_list = [
|
||||
all_ports[i : i + 3]
|
||||
for i in range(start_idx, start_idx + num_eplb_groups * 3, 3)
|
||||
]
|
||||
|
||||
parallel_config._stateless_world_group_port_list = (
|
||||
new_stateless_world_group_port_list
|
||||
)
|
||||
parallel_config._stateless_dp_group_port_list = new_stateless_dp_group_port_list
|
||||
parallel_config._stateless_ep_group_port_list = new_stateless_ep_group_port_list
|
||||
parallel_config._stateless_eplb_group_port_list = new_stateless_eplb_group_port_list
|
||||
parallel_config.data_parallel_master_port = new_data_parallel_master_port_list.pop()
|
||||
parallel_config._data_parallel_master_port_list = new_data_parallel_master_port_list
|
||||
|
||||
|
||||
class MPClient(EngineCoreClient):
|
||||
"""
|
||||
MPClient: base client for multi-proc EngineCore.
|
||||
@@ -491,32 +550,37 @@ class MPClient(EngineCoreClient):
|
||||
input_address = client_addresses["input_address"]
|
||||
output_address = client_addresses["output_address"]
|
||||
self.stats_update_address = client_addresses.get("stats_update_address")
|
||||
self.input_socket = self.resources.input_socket = make_zmq_socket(
|
||||
self.ctx, input_address, zmq.ROUTER, bind=True
|
||||
)
|
||||
self.resources.output_socket = make_zmq_socket(
|
||||
self.ctx, output_address, zmq.PULL
|
||||
)
|
||||
else:
|
||||
# Engines are managed by this client.
|
||||
with launch_core_engines(vllm_config, executor_class, log_stats) as (
|
||||
engine_manager,
|
||||
coordinator,
|
||||
addresses = get_engine_zmq_addresses(vllm_config)
|
||||
self.input_socket = self.resources.input_socket = make_zmq_socket(
|
||||
self.ctx, addresses.inputs[0], zmq.ROUTER, bind=True
|
||||
)
|
||||
self.resources.output_socket = make_zmq_socket(
|
||||
self.ctx, addresses.outputs[0], zmq.PULL
|
||||
)
|
||||
|
||||
with launch_core_engines(
|
||||
vllm_config,
|
||||
executor_class,
|
||||
log_stats,
|
||||
addresses,
|
||||
):
|
||||
) as (engine_manager, coordinator, addresses):
|
||||
self.resources.coordinator = coordinator
|
||||
self.resources.engine_manager = engine_manager
|
||||
|
||||
(input_address,) = addresses.inputs
|
||||
(output_address,) = addresses.outputs
|
||||
self.stats_update_address = addresses.frontend_stats_publish_address
|
||||
if coordinator is not None:
|
||||
assert self.stats_update_address == (
|
||||
coordinator.get_stats_publish_address()
|
||||
)
|
||||
|
||||
# Create input and output sockets.
|
||||
self.input_socket = self.resources.input_socket = make_zmq_socket(
|
||||
self.ctx, input_address, zmq.ROUTER, bind=True
|
||||
)
|
||||
self.resources.output_socket = make_zmq_socket(
|
||||
self.ctx, output_address, zmq.PULL
|
||||
)
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
dp_size = parallel_config.data_parallel_size
|
||||
dp_rank = parallel_config.data_parallel_index
|
||||
@@ -877,6 +941,10 @@ class AsyncMPClient(MPClient):
|
||||
output_socket = resources.output_socket
|
||||
assert output_socket is not None
|
||||
|
||||
notification_callback_handler: (
|
||||
Callable[[AsyncMPClient, Sequence[Any]], Any] | None
|
||||
) = getattr(self.__class__, "eep_process_engine_core_notification", None)
|
||||
|
||||
async def process_outputs_socket():
|
||||
try:
|
||||
while True:
|
||||
@@ -884,7 +952,26 @@ class AsyncMPClient(MPClient):
|
||||
resources.validate_alive(frames)
|
||||
outputs: EngineCoreOutputs = decoder.decode(frames)
|
||||
if outputs.utility_output:
|
||||
_process_utility_output(outputs.utility_output, utility_results)
|
||||
if (
|
||||
outputs.utility_output.call_id == EEP_NOTIFICATION_CALL_ID
|
||||
and notification_callback_handler is not None
|
||||
):
|
||||
assert _self_ref is not None
|
||||
_self = _self_ref()
|
||||
if not _self:
|
||||
return
|
||||
if outputs.utility_output.result is None:
|
||||
continue
|
||||
notification_data = outputs.utility_output.result.result
|
||||
assert isinstance(notification_data, Sequence)
|
||||
assert len(notification_data) == 2
|
||||
asyncio.create_task(
|
||||
notification_callback_handler(_self, notification_data)
|
||||
)
|
||||
else:
|
||||
_process_utility_output(
|
||||
outputs.utility_output, utility_results
|
||||
)
|
||||
continue
|
||||
|
||||
if output_handler is not None:
|
||||
@@ -1081,6 +1168,8 @@ class DPAsyncMPClient(AsyncMPClient):
|
||||
# Used only by DPLBAsyncMPClient subclass.
|
||||
self.lb_engines: list[list[int]] = [[0, 0] for _ in self.core_engines]
|
||||
|
||||
self.eep_scaling_cache: ElasticScalingCache | None = None
|
||||
|
||||
self.first_req_sock_addr = get_open_zmq_inproc_path()
|
||||
self.first_req_send_socket = self.resources.first_req_send_socket = (
|
||||
make_zmq_socket(self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=True)
|
||||
@@ -1101,12 +1190,6 @@ class DPAsyncMPClient(AsyncMPClient):
|
||||
assert self.stats_update_address is not None
|
||||
stats_addr: str = self.stats_update_address
|
||||
assert len(self.engine_ranks_managed) > 0
|
||||
# NOTE: running and waiting counts are all global from
|
||||
# the Coordinator include all global EngineCores. This
|
||||
# slice includes just the cores managed by this client.
|
||||
count_slice = slice(
|
||||
self.engine_ranks_managed[0], self.engine_ranks_managed[-1] + 1
|
||||
)
|
||||
|
||||
async def run_engine_stats_update_task():
|
||||
with (
|
||||
@@ -1145,6 +1228,29 @@ class DPAsyncMPClient(AsyncMPClient):
|
||||
):
|
||||
# Extract new engine count from the decoded message
|
||||
new_engine_count = decoded[1]
|
||||
# Update engine_ranks_managed and count_slice
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
dp_size = parallel_config.data_parallel_size
|
||||
dp_rank = parallel_config.data_parallel_rank
|
||||
assert dp_rank == 0
|
||||
assert dp_size == new_engine_count
|
||||
assert not (
|
||||
parallel_config.data_parallel_hybrid_lb
|
||||
or parallel_config.data_parallel_external_lb
|
||||
)
|
||||
num_ranks = dp_size
|
||||
self.engine_ranks_managed = list(
|
||||
range(dp_rank, dp_rank + num_ranks)
|
||||
)
|
||||
if len(self.lb_engines) < new_engine_count:
|
||||
self.lb_engines = self.lb_engines + [
|
||||
[0, 0]
|
||||
for _ in range(
|
||||
new_engine_count - len(self.lb_engines)
|
||||
)
|
||||
]
|
||||
else:
|
||||
self.lb_engines = self.lb_engines[:new_engine_count]
|
||||
# Send scale up notification to coordinator
|
||||
scale_msg = msgspec.msgpack.encode(
|
||||
("SCALE_ELASTIC_EP", new_engine_count)
|
||||
@@ -1178,6 +1284,11 @@ class DPAsyncMPClient(AsyncMPClient):
|
||||
self.current_wave = wave
|
||||
self.engines_running = running
|
||||
if counts is not None:
|
||||
# Running and waiting counts are global from the
|
||||
# Coordinator including all EngineCores. Slice to get
|
||||
# just the cores managed by this client.
|
||||
ranks = self.engine_ranks_managed
|
||||
count_slice = slice(ranks[0], ranks[-1] + 1)
|
||||
sliced_counts = counts[count_slice]
|
||||
self.lb_engines = sliced_counts
|
||||
logger.debug(
|
||||
@@ -1287,6 +1398,67 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
|
||||
for req_id in outputs.finished_requests:
|
||||
self.reqs_in_flight.pop(req_id, None)
|
||||
|
||||
@staticmethod
|
||||
async def eep_process_engine_core_notification(
|
||||
self: "DPLBAsyncMPClient", notification_data: tuple[str, int]
|
||||
):
|
||||
cache = self.eep_scaling_cache
|
||||
notification_type_str, dp_rank = notification_data
|
||||
try:
|
||||
notification_type = EEPNotificationType(notification_type_str)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"Unknown EEP notification type: {notification_type_str}"
|
||||
) from e
|
||||
|
||||
if notification_type == EEPNotificationType.RECONFIGURE_FINISHED:
|
||||
from vllm.v1.engine import UtilityResult
|
||||
|
||||
# NOTE(yongji): process a dummy UtilityOutput to resolve the future
|
||||
# awaited in _eep_wait_for_setup_switch_complete(), signaling that
|
||||
# all engine cores have completed reconfiguration.
|
||||
dummy_output = UtilityOutput(
|
||||
call_id=EEP_NOTIFICATION_CALL_ID, result=UtilityResult(None)
|
||||
)
|
||||
_process_utility_output(dummy_output, self.utility_results)
|
||||
return
|
||||
assert cache is not None
|
||||
if notification_type not in cache.pending_notifications:
|
||||
cache.pending_notifications[notification_type] = set()
|
||||
if dp_rank in cache.pending_notifications[notification_type]:
|
||||
raise ValueError(
|
||||
f"Duplicate notification {notification_type} from dp_rank {dp_rank}"
|
||||
)
|
||||
cache.pending_notifications[notification_type].add(dp_rank)
|
||||
if len(cache.pending_notifications[notification_type]) >= abs(
|
||||
cache.num_new_core_engines
|
||||
):
|
||||
if notification_type == EEPNotificationType.SHUTDOWN_COMPLETE:
|
||||
assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
|
||||
assert cache.num_new_core_engines < 0
|
||||
old_dp_size = len(cache.existing_core_engines)
|
||||
new_dp_size = old_dp_size + cache.num_new_core_engines
|
||||
self.resources.engine_manager.scale_down_elastic_ep(
|
||||
old_dp_size, new_dp_size
|
||||
)
|
||||
else:
|
||||
await asyncio.gather(
|
||||
*[
|
||||
self._call_utility_async(
|
||||
"eep_handle_engine_core_notification",
|
||||
notification_type,
|
||||
engine=engine,
|
||||
)
|
||||
for engine in cache.existing_core_engines
|
||||
]
|
||||
)
|
||||
cache.pending_notifications[notification_type] = set()
|
||||
if notification_type in [
|
||||
EEPNotificationType.SHUTDOWN_COMPLETE,
|
||||
EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY,
|
||||
]:
|
||||
self.eep_scaling_cache = None
|
||||
|
||||
async def abort_requests_async(self, request_ids: list[str]) -> None:
|
||||
if not request_ids or self.resources.engine_dead:
|
||||
return
|
||||
@@ -1333,6 +1505,20 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
|
||||
cur_data_parallel_size, new_data_parallel_size
|
||||
)
|
||||
|
||||
async def _eep_wait_for_setup_switch_complete(self) -> None:
|
||||
"""
|
||||
Wait for core engines to switch to the new setup.
|
||||
|
||||
In eep_process_engine_core_notification(), a dummy UtilityOutput with
|
||||
EEP_NOTIFICATION_CALL_ID will be set when RECONFIGURE_FINISHED
|
||||
notification is received from engine 0. We create a future with
|
||||
that call_id and wait for it to be resolved.
|
||||
"""
|
||||
future = asyncio.get_running_loop().create_future()
|
||||
self.utility_results[EEP_NOTIFICATION_CALL_ID] = future
|
||||
self._ensure_output_queue_task()
|
||||
await future
|
||||
|
||||
async def _scale_up_elastic_ep(
|
||||
self, cur_data_parallel_size: int, new_data_parallel_size: int
|
||||
) -> None:
|
||||
@@ -1340,38 +1526,57 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
|
||||
and reconfiguring existing ones."""
|
||||
cur_data_parallel_size = len(self.core_engines)
|
||||
|
||||
# Phase 1: Send reconfigure messages to all existing engines and wait
|
||||
# for them to be sent
|
||||
self.eep_scaling_cache = ElasticScalingCache(
|
||||
existing_core_engines=self.core_engines.copy(),
|
||||
num_new_core_engines=new_data_parallel_size - cur_data_parallel_size,
|
||||
pending_notifications=dict(),
|
||||
)
|
||||
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
allocate_stateless_group_ports(parallel_config, new_data_parallel_size)
|
||||
|
||||
# Phase 1: Send reconfig messages to existing engines
|
||||
reconfig_futures = []
|
||||
self.vllm_config.parallel_config.data_parallel_master_port = get_open_port()
|
||||
for engine in self.core_engines:
|
||||
reconfig_request = ReconfigureDistributedRequest(
|
||||
new_data_parallel_size=new_data_parallel_size,
|
||||
new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
|
||||
new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
|
||||
new_data_parallel_master_ip=self.vllm_config.parallel_config.data_parallel_master_ip,
|
||||
new_data_parallel_master_port=self.vllm_config.parallel_config.data_parallel_master_port,
|
||||
new_data_parallel_master_ip=parallel_config.data_parallel_master_ip,
|
||||
new_data_parallel_master_port=parallel_config.data_parallel_master_port,
|
||||
new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list,
|
||||
new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list,
|
||||
new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list,
|
||||
new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list,
|
||||
new_stateless_eplb_group_port_list=parallel_config._stateless_eplb_group_port_list,
|
||||
)
|
||||
coro = self._call_utility_async(
|
||||
"reinitialize_distributed", reconfig_request, engine=engine
|
||||
)
|
||||
reconfig_futures.append(asyncio.create_task(coro))
|
||||
|
||||
logger.info("All reconfigure messages sent, starting engine creation")
|
||||
|
||||
# Phase 2: Create new engines now that reconfig messages have been sent
|
||||
# self.resources.engine_manager is guaranteed to be
|
||||
# CoreEngineActorManager for RayDPClient
|
||||
# Phase 2: Create new engines
|
||||
assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
|
||||
self.resources.engine_manager.scale_up_elastic_ep(
|
||||
self.vllm_config, new_data_parallel_size
|
||||
parallel_config.eplb_config.num_redundant_experts = 0
|
||||
start_new_worker_future = asyncio.to_thread(
|
||||
self.resources.engine_manager.scale_up_elastic_ep,
|
||||
self.vllm_config,
|
||||
new_data_parallel_size,
|
||||
)
|
||||
wait_future = self._eep_wait_for_setup_switch_complete()
|
||||
|
||||
# Phase 3: Wait for new engines to be created
|
||||
# and reconfig messages to be received
|
||||
await asyncio.gather(start_new_worker_future, *reconfig_futures)
|
||||
logger.info("[Elastic EP] Successfully started new engines")
|
||||
|
||||
# Create new CoreEngine objects for the new engines
|
||||
new_engine_identities = set()
|
||||
for i in range(cur_data_parallel_size, new_data_parallel_size):
|
||||
new_engine = i.to_bytes(2, "little")
|
||||
self.core_engines.append(new_engine)
|
||||
# NOTE(yongji): we don't update lb_engines here,
|
||||
# we let run_engine_stats_update_task to update it.
|
||||
new_engine_identities.add(new_engine)
|
||||
|
||||
# Wait for ready messages from new engines on the input socket
|
||||
@@ -1387,10 +1592,11 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
|
||||
identity, _ = sync_input_socket.recv_multipart()
|
||||
new_engine_identities.discard(identity)
|
||||
|
||||
# Phase 3: Wait for all existing engines to complete reconfiguration
|
||||
logger.info("Waiting for existing engines to complete reconfiguration")
|
||||
await asyncio.gather(*reconfig_futures)
|
||||
|
||||
# NOTE(yongji): Before we schedule any requests on the new workers,
|
||||
# we should wait for them to switch to the new setup.
|
||||
await wait_future
|
||||
# Update the parallel config
|
||||
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
|
||||
# Notify coordinator about scale up through existing
|
||||
# stats_update_task connection
|
||||
self._ensure_stats_update_task()
|
||||
@@ -1399,8 +1605,6 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
|
||||
)
|
||||
await self.first_req_send_socket.send(scale_up_marker)
|
||||
|
||||
# Update the parallel config
|
||||
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
|
||||
logger.info(
|
||||
"[Elastic EP] Scale up completed, new data parallel size: %s",
|
||||
new_data_parallel_size,
|
||||
@@ -1413,7 +1617,14 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
|
||||
reconfiguring existing engine cores."""
|
||||
cur_data_parallel_size = len(self.core_engines)
|
||||
|
||||
self.vllm_config.parallel_config.data_parallel_master_port = get_open_port()
|
||||
self.eep_scaling_cache = ElasticScalingCache(
|
||||
existing_core_engines=self.core_engines.copy(),
|
||||
num_new_core_engines=new_data_parallel_size - cur_data_parallel_size,
|
||||
pending_notifications=dict(),
|
||||
)
|
||||
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
allocate_stateless_group_ports(parallel_config, new_data_parallel_size)
|
||||
|
||||
reconfig_futures = []
|
||||
for cur_dp_rank, engine in enumerate(self.core_engines):
|
||||
@@ -1421,8 +1632,13 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
|
||||
new_data_parallel_size=new_data_parallel_size,
|
||||
new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
|
||||
new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
|
||||
new_data_parallel_master_ip=self.vllm_config.parallel_config.data_parallel_master_ip,
|
||||
new_data_parallel_master_port=self.vllm_config.parallel_config.data_parallel_master_port,
|
||||
new_data_parallel_master_ip=parallel_config.data_parallel_master_ip,
|
||||
new_data_parallel_master_port=parallel_config.data_parallel_master_port,
|
||||
new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list,
|
||||
new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list,
|
||||
new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list,
|
||||
new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list,
|
||||
new_stateless_eplb_group_port_list=parallel_config._stateless_eplb_group_port_list,
|
||||
)
|
||||
if cur_dp_rank >= new_data_parallel_size:
|
||||
reconfig_request.new_data_parallel_rank = (
|
||||
@@ -1433,23 +1649,24 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
|
||||
)
|
||||
reconfig_futures.append(asyncio.create_task(coro))
|
||||
|
||||
for _ in range(new_data_parallel_size, cur_data_parallel_size):
|
||||
self.core_engines.pop()
|
||||
# NOTE(yongji): Immediately stop sending requests to the removing engines.
|
||||
self.core_engines = self.core_engines[:new_data_parallel_size]
|
||||
self.lb_engines = self.lb_engines[:new_data_parallel_size]
|
||||
wait_future = self._eep_wait_for_setup_switch_complete()
|
||||
|
||||
await asyncio.gather(*reconfig_futures)
|
||||
|
||||
assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
|
||||
self.resources.engine_manager.scale_down_elastic_ep(
|
||||
cur_data_parallel_size, new_data_parallel_size
|
||||
)
|
||||
|
||||
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
|
||||
self._ensure_stats_update_task()
|
||||
scale_down_marker = msgspec.msgpack.encode(
|
||||
("SCALE_ELASTIC_EP", new_data_parallel_size)
|
||||
)
|
||||
await self.first_req_send_socket.send(scale_down_marker)
|
||||
|
||||
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
|
||||
# NOTE(yongji): Unlike scaling up,
|
||||
# here we don't actually need to wait for the setup switch to complete.
|
||||
# We may want to remove it in the future.
|
||||
await wait_future
|
||||
logger.info(
|
||||
"[Elastic EP] Scale down completed, new data parallel size: %s",
|
||||
new_data_parallel_size,
|
||||
|
||||
+47
-21
@@ -277,6 +277,8 @@ class CoreEngineActorManager:
|
||||
else:
|
||||
ray.init()
|
||||
|
||||
vllm_config.parallel_config.allocate_elastic_ep_ports()
|
||||
|
||||
if placement_groups is not None:
|
||||
assert local_dp_ranks is not None, (
|
||||
"local_dp_ranks must be provided if placement_groups is provided"
|
||||
@@ -584,6 +586,8 @@ class CoreEngineActorManager:
|
||||
|
||||
node_ip = node.node_ip
|
||||
node_id = node.node_id
|
||||
if device_str not in available_resources[node_id]:
|
||||
continue
|
||||
available_gpus = int(available_resources[node_id][device_str])
|
||||
|
||||
# Get total GPUs on this node from the node's resources
|
||||
@@ -773,11 +777,50 @@ class CoreEngineActorManager:
|
||||
ray.util.remove_placement_group(pg)
|
||||
|
||||
|
||||
def get_engine_zmq_addresses(
|
||||
vllm_config: VllmConfig,
|
||||
num_api_servers: int = 1,
|
||||
) -> EngineZmqAddresses:
|
||||
"""Allocate ZMQ addresses for engine-client communication."""
|
||||
parallel_config = vllm_config.parallel_config
|
||||
local_engine_count = parallel_config.data_parallel_size_local
|
||||
local_start_index = parallel_config.data_parallel_rank_local
|
||||
dp_size = parallel_config.data_parallel_size
|
||||
host = parallel_config.data_parallel_master_ip
|
||||
local_engines_only = parallel_config.local_engines_only
|
||||
|
||||
# In offline mode there is an LLM instance per DP rank and
|
||||
# one core engine per LLM, see
|
||||
# examples/offline_inference/data_parallel.py.
|
||||
offline_mode = local_start_index is not None
|
||||
|
||||
# client_local_only = True for cases where this front-end
|
||||
# sends requests only to colocated engines.
|
||||
client_local_only = (
|
||||
offline_mode or local_engines_only or (local_engine_count == dp_size)
|
||||
)
|
||||
# NOTE(yongji): handling scaling from intra-node to inter-node
|
||||
if parallel_config.enable_elastic_ep:
|
||||
client_local_only = False
|
||||
|
||||
return EngineZmqAddresses(
|
||||
inputs=[
|
||||
get_engine_client_zmq_addr(client_local_only, host)
|
||||
for _ in range(num_api_servers)
|
||||
],
|
||||
outputs=[
|
||||
get_engine_client_zmq_addr(client_local_only, host)
|
||||
for _ in range(num_api_servers)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def launch_core_engines(
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
addresses: EngineZmqAddresses,
|
||||
num_api_servers: int = 1,
|
||||
) -> Iterator[
|
||||
tuple[
|
||||
@@ -796,29 +839,8 @@ def launch_core_engines(
|
||||
host = parallel_config.data_parallel_master_ip
|
||||
local_engines_only = parallel_config.local_engines_only
|
||||
|
||||
# In offline mode there is an LLM instance per DP rank and
|
||||
# one core engine per LLM, see
|
||||
# examples/offline_inference/data_parallel.py.
|
||||
offline_mode = local_start_index is not None
|
||||
|
||||
# client_local_only = True for cases where this front-end
|
||||
# sends requests only to colocated engines.
|
||||
client_local_only = (
|
||||
offline_mode or local_engines_only or (local_engine_count == dp_size)
|
||||
)
|
||||
|
||||
# Set up input and output addresses.
|
||||
addresses = EngineZmqAddresses(
|
||||
inputs=[
|
||||
get_engine_client_zmq_addr(client_local_only, host)
|
||||
for _ in range(num_api_servers)
|
||||
],
|
||||
outputs=[
|
||||
get_engine_client_zmq_addr(client_local_only, host)
|
||||
for _ in range(num_api_servers)
|
||||
],
|
||||
)
|
||||
|
||||
# Run the DP Coordinator process with rank 0 when in online DP mode.
|
||||
# The coordinator is needed for:
|
||||
# 1. Internal/hybrid LB: collecting and publishing queue stats for load balancing
|
||||
@@ -885,6 +907,10 @@ def launch_core_engines(
|
||||
# will be False.
|
||||
handshake_local_only = offline_mode or local_engine_count == dp_size
|
||||
|
||||
# NOTE(yongji): handling scaling from intra-node to inter-node
|
||||
if parallel_config.enable_elastic_ep:
|
||||
handshake_local_only = False
|
||||
|
||||
handshake_address = get_engine_client_zmq_addr(
|
||||
handshake_local_only, host, parallel_config.data_parallel_rpc_port
|
||||
)
|
||||
|
||||
@@ -38,6 +38,7 @@ from vllm.distributed.parallel_state import (
|
||||
get_pcp_group,
|
||||
get_pp_group,
|
||||
get_tp_group,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
from vllm.envs import enable_envs_cache
|
||||
from vllm.logger import init_logger
|
||||
@@ -580,17 +581,20 @@ class WorkerProc:
|
||||
)
|
||||
self.async_output_copy_thread.start()
|
||||
|
||||
# Initialize device
|
||||
self.worker.init_device()
|
||||
|
||||
# Set process title and log prefix
|
||||
self.setup_proc_title_and_log_prefix(
|
||||
enable_ep=vllm_config.parallel_config.enable_expert_parallel
|
||||
)
|
||||
|
||||
# Load model
|
||||
self._init_message_queues(input_shm_handle, vllm_config)
|
||||
self.worker.load_model()
|
||||
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
|
||||
if not is_eep_new_worker:
|
||||
self.worker.init_device()
|
||||
# Update process title now that parallel groups are initialized
|
||||
self.setup_proc_title_and_log_prefix(
|
||||
enable_ep=vllm_config.parallel_config.enable_expert_parallel
|
||||
)
|
||||
self.worker.load_model()
|
||||
|
||||
# Enable environment variable cache (e.g. assume no more
|
||||
# environment variable overrides after this point)
|
||||
@@ -885,6 +889,13 @@ class WorkerProc:
|
||||
|
||||
@staticmethod
|
||||
def setup_proc_title_and_log_prefix(enable_ep: bool) -> None:
|
||||
# Check if parallel groups are initialized first
|
||||
if not model_parallel_is_initialized():
|
||||
# Parallel groups not yet initialized, use default process name
|
||||
set_process_title(name="Worker")
|
||||
decorate_logs("Worker")
|
||||
return
|
||||
|
||||
dp_size = get_dp_group().world_size
|
||||
dp_rank = get_dp_group().rank_in_group
|
||||
pp_size = get_pp_group().world_size
|
||||
|
||||
@@ -382,8 +382,10 @@ class RayDistributedExecutor(Executor):
|
||||
all_kwargs.append(kwargs)
|
||||
self.collective_rpc("init_worker", args=(all_kwargs,))
|
||||
|
||||
self.collective_rpc("init_device")
|
||||
self.collective_rpc("load_model")
|
||||
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
|
||||
if not is_eep_new_worker:
|
||||
self.collective_rpc("init_device")
|
||||
self.collective_rpc("load_model")
|
||||
|
||||
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
|
||||
self.pp_tp_workers.append([])
|
||||
|
||||
@@ -14,7 +14,6 @@ import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
|
||||
from vllm.v1.serial_utils import run_method
|
||||
@@ -43,9 +42,11 @@ class UniProcExecutor(Executor):
|
||||
max_workers=1, thread_name_prefix="WorkerAsyncOutput"
|
||||
)
|
||||
|
||||
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
|
||||
self.driver_worker.init_worker(all_kwargs=[kwargs])
|
||||
self.driver_worker.init_device()
|
||||
self.driver_worker.load_model()
|
||||
if not is_eep_new_worker:
|
||||
self.driver_worker.init_device()
|
||||
self.driver_worker.load_model()
|
||||
|
||||
def _distributed_args(self) -> tuple[str, int, int]:
|
||||
"""Return (distributed_init_method, rank, local_rank)."""
|
||||
@@ -122,16 +123,6 @@ class UniProcExecutor(Executor):
|
||||
# it's running.
|
||||
return
|
||||
|
||||
def reinitialize_distributed(
|
||||
self, reconfig_request: ReconfigureDistributedRequest
|
||||
) -> None:
|
||||
self.driver_worker.reinitialize_distributed(reconfig_request)
|
||||
if (
|
||||
reconfig_request.new_data_parallel_rank
|
||||
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
|
||||
):
|
||||
self.shutdown()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
if worker := self.driver_worker:
|
||||
worker.shutdown()
|
||||
|
||||
@@ -53,7 +53,12 @@ class CPUModelRunner(GPUModelRunner):
|
||||
v.gpu = v.cpu
|
||||
|
||||
@instrument(span_name="Loading (CPU)")
|
||||
def load_model(self, eep_scale_up: bool = False) -> None:
|
||||
def load_model(self, load_dummy_weights: bool = False) -> None:
|
||||
if load_dummy_weights:
|
||||
raise ValueError(
|
||||
"Loading dummy weights (needed for elastic EP scale-up) "
|
||||
"Is not supported by the CPU Model Runner."
|
||||
)
|
||||
logger.info("Starting to load model %s...", self.model_config.model)
|
||||
self.model = get_model(vllm_config=self.vllm_config)
|
||||
|
||||
|
||||
@@ -461,6 +461,8 @@ class GPUModelRunner(
|
||||
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
|
||||
|
||||
self.eplb_state: EplbState | None = None
|
||||
# NOTE(yongji): flag to temporarily disable EPLB during scaling up/down
|
||||
self.eep_eplb_suppressed = False
|
||||
"""
|
||||
State of the expert parallelism load balancer.
|
||||
|
||||
@@ -2702,7 +2704,7 @@ class GPUModelRunner(
|
||||
"""
|
||||
Step for the EPLB (Expert Parallelism Load Balancing) state.
|
||||
"""
|
||||
if not self.parallel_config.enable_eplb:
|
||||
if not self.parallel_config.enable_eplb or self.eep_eplb_suppressed:
|
||||
return
|
||||
|
||||
assert self.eplb_state is not None
|
||||
@@ -2714,6 +2716,23 @@ class GPUModelRunner(
|
||||
log_stats=self.parallel_config.eplb_config.log_balancedness,
|
||||
)
|
||||
|
||||
def setup_eplb_from_mapping(
|
||||
self,
|
||||
expanded_physical_to_logical: torch.Tensor,
|
||||
old_num_physical_experts: int,
|
||||
) -> None:
|
||||
model = self.get_model()
|
||||
assert is_mixture_of_experts(model)
|
||||
|
||||
self.eplb_state = EplbState.from_mapping(
|
||||
model=model,
|
||||
model_config=self.model_config,
|
||||
device=self.device,
|
||||
parallel_config=self.parallel_config,
|
||||
expanded_physical_to_logical=expanded_physical_to_logical,
|
||||
num_valid_physical_experts=old_num_physical_experts,
|
||||
)
|
||||
|
||||
def _pool(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -4175,21 +4194,16 @@ class GPUModelRunner(
|
||||
setattr(self, config_name, new_config)
|
||||
|
||||
@instrument(span_name="Loading (GPU)")
|
||||
def load_model(self, eep_scale_up: bool = False) -> None:
|
||||
def load_model(self, load_dummy_weights: bool = False) -> None:
|
||||
"""
|
||||
Args:
|
||||
eep_scale_up: the model loading is for elastic EP scale up.
|
||||
load_dummy_weights: load dummy weights instead of real weights.
|
||||
"""
|
||||
logger.info_once(
|
||||
"Starting to load model %s...",
|
||||
self.model_config.model,
|
||||
scope="global",
|
||||
)
|
||||
global_expert_loads, old_global_expert_indices_per_model, rank_mapping = (
|
||||
EplbState.get_eep_state(self.parallel_config)
|
||||
if eep_scale_up
|
||||
else (None, None, None)
|
||||
)
|
||||
|
||||
if self.parallel_config.enable_eplb:
|
||||
self.eplb_state = EplbState(self.parallel_config, self.device)
|
||||
@@ -4198,6 +4212,8 @@ class GPUModelRunner(
|
||||
try:
|
||||
with DeviceMemoryProfiler() as m:
|
||||
time_before_load = time.perf_counter()
|
||||
if load_dummy_weights:
|
||||
self.load_config.load_format = "dummy"
|
||||
model_loader = get_model_loader(self.load_config)
|
||||
self.model = model_loader.load_model(
|
||||
vllm_config=self.vllm_config, model_config=self.model_config
|
||||
@@ -4214,6 +4230,9 @@ class GPUModelRunner(
|
||||
and is_mixture_of_experts(self.drafter.model)
|
||||
and self.parallel_config.enable_eplb
|
||||
):
|
||||
assert not self.parallel_config.enable_elastic_ep, (
|
||||
"Elastic EP is not supported with drafter model."
|
||||
)
|
||||
spec_config = self.vllm_config.speculative_config
|
||||
assert spec_config is not None
|
||||
assert spec_config.draft_model_config is not None
|
||||
@@ -4221,17 +4240,6 @@ class GPUModelRunner(
|
||||
"EPLB is enabled for drafter model %s.",
|
||||
spec_config.draft_model_config.model,
|
||||
)
|
||||
|
||||
global_expert_load = (
|
||||
global_expert_loads[eplb_models]
|
||||
if global_expert_loads
|
||||
else None
|
||||
)
|
||||
old_global_expert_indices = (
|
||||
old_global_expert_indices_per_model[eplb_models]
|
||||
if old_global_expert_indices_per_model
|
||||
else None
|
||||
)
|
||||
if self.eplb_state is None:
|
||||
self.eplb_state = EplbState(
|
||||
self.parallel_config, self.device
|
||||
@@ -4239,9 +4247,6 @@ class GPUModelRunner(
|
||||
self.eplb_state.add_model(
|
||||
self.drafter.model,
|
||||
spec_config.draft_model_config,
|
||||
global_expert_load,
|
||||
old_global_expert_indices,
|
||||
rank_mapping,
|
||||
)
|
||||
eplb_models += 1
|
||||
|
||||
@@ -4283,11 +4288,12 @@ class GPUModelRunner(
|
||||
time_after_load - time_before_load,
|
||||
scope="local",
|
||||
)
|
||||
prepare_communication_buffer_for_model(self.model)
|
||||
if (drafter := getattr(self, "drafter", None)) and (
|
||||
drafter_model := getattr(drafter, "model", None)
|
||||
):
|
||||
prepare_communication_buffer_for_model(drafter_model)
|
||||
if not load_dummy_weights:
|
||||
prepare_communication_buffer_for_model(self.model)
|
||||
if (drafter := getattr(self, "drafter", None)) and (
|
||||
drafter_model := getattr(drafter, "model", None)
|
||||
):
|
||||
prepare_communication_buffer_for_model(drafter_model)
|
||||
mm_config = self.model_config.multimodal_config
|
||||
self.is_multimodal_pruning_enabled = (
|
||||
supports_multimodal_pruning(self.get_model())
|
||||
@@ -4295,26 +4301,19 @@ class GPUModelRunner(
|
||||
and mm_config.is_multimodal_pruning_enabled()
|
||||
)
|
||||
|
||||
if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb:
|
||||
if (
|
||||
is_mixture_of_experts(self.model)
|
||||
and self.parallel_config.enable_eplb
|
||||
and not load_dummy_weights
|
||||
):
|
||||
logger.info_once("EPLB is enabled for model %s.", self.model_config.model)
|
||||
global_expert_load = (
|
||||
global_expert_loads[eplb_models] if global_expert_loads else None
|
||||
)
|
||||
old_global_expert_indices = (
|
||||
old_global_expert_indices_per_model[eplb_models]
|
||||
if old_global_expert_indices_per_model
|
||||
else None
|
||||
)
|
||||
assert self.eplb_state is not None
|
||||
self.eplb_state.add_model(
|
||||
self.model,
|
||||
self.model_config,
|
||||
global_expert_load,
|
||||
old_global_expert_indices,
|
||||
rank_mapping,
|
||||
)
|
||||
if self.eplb_state.is_async:
|
||||
self.eplb_state.start_async_loop(rank_mapping=rank_mapping)
|
||||
self.eplb_state.start_async_loop()
|
||||
|
||||
if (
|
||||
self.vllm_config.compilation_config.mode
|
||||
|
||||
+27
-228
@@ -7,11 +7,10 @@ import os
|
||||
from collections.abc import Callable
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from types import NoneType
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm.envs as envs
|
||||
@@ -32,14 +31,12 @@ from vllm.distributed.kv_transfer import (
|
||||
)
|
||||
from vllm.distributed.parallel_state import (
|
||||
Handle,
|
||||
get_pcp_group,
|
||||
get_pp_group,
|
||||
get_tp_group,
|
||||
)
|
||||
from vllm.distributed.weight_transfer import WeightTransferEngineFactory
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.models.interfaces import is_mixture_of_experts
|
||||
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
|
||||
@@ -49,7 +46,6 @@ from vllm.tracing import instrument
|
||||
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import (
|
||||
AsyncModelRunnerOutput,
|
||||
@@ -124,6 +120,10 @@ class Worker(WorkerBase):
|
||||
precision = envs.VLLM_FLOAT32_MATMUL_PRECISION
|
||||
torch.set_float32_matmul_precision(precision)
|
||||
|
||||
from vllm.distributed.elastic_ep.elastic_execute import ElasticEPScalingExecutor
|
||||
|
||||
self.elastic_ep_executor = ElasticEPScalingExecutor(self)
|
||||
|
||||
# Buffers saved before sleep
|
||||
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
|
||||
|
||||
@@ -317,12 +317,29 @@ class Worker(WorkerBase):
|
||||
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
|
||||
# to hijack tensor allocation.
|
||||
def load_model(self) -> None:
|
||||
eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
|
||||
dummy_weights = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
|
||||
if dummy_weights:
|
||||
(
|
||||
expanded_physical_to_logical,
|
||||
num_logical_experts,
|
||||
old_num_physical_experts,
|
||||
) = self.elastic_ep_executor.receive_expert_mapping()
|
||||
num_physical_experts = expanded_physical_to_logical.shape[1]
|
||||
self.parallel_config.eplb_config.num_redundant_experts = (
|
||||
num_physical_experts - num_logical_experts
|
||||
)
|
||||
|
||||
with (
|
||||
self._maybe_get_memory_pool_context(tag="weights"),
|
||||
set_current_vllm_config(self.vllm_config),
|
||||
):
|
||||
self.model_runner.load_model(eep_scale_up=eep_scale_up)
|
||||
self.model_runner.load_model(load_dummy_weights=dummy_weights)
|
||||
|
||||
if dummy_weights:
|
||||
self.model_runner.setup_eplb_from_mapping(
|
||||
expanded_physical_to_logical, old_num_physical_experts
|
||||
)
|
||||
self.model_runner.eep_eplb_suppressed = True
|
||||
|
||||
def update_config(self, overrides: dict[str, Any]) -> None:
|
||||
self.model_runner.update_config(overrides)
|
||||
@@ -801,227 +818,6 @@ class Worker(WorkerBase):
|
||||
# worker will always be healthy as long as it's running.
|
||||
return
|
||||
|
||||
def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info(
|
||||
"[Elastic EP] Starting expert resharding before scaling down..."
|
||||
)
|
||||
rank_mapping = {
|
||||
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
|
||||
for old_ep_rank in range(old_ep_size)
|
||||
}
|
||||
assert self.model_runner.eplb_state is not None
|
||||
self.model_runner.eplb_state.rearrange(
|
||||
execute_shuffle=True,
|
||||
global_expert_loads=None,
|
||||
rank_mapping=rank_mapping,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info("[Elastic EP] Expert resharding completed!")
|
||||
|
||||
def _eplb_after_scale_up(
|
||||
self,
|
||||
old_ep_size: int,
|
||||
new_ep_size: int,
|
||||
global_expert_loads: list[torch.Tensor] | None,
|
||||
) -> None:
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info("[Elastic EP] Starting expert resharding after scaling up...")
|
||||
rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
|
||||
assert self.model_runner.eplb_state is not None
|
||||
self.model_runner.eplb_state.rearrange(
|
||||
execute_shuffle=True,
|
||||
global_expert_loads=global_expert_loads,
|
||||
rank_mapping=rank_mapping,
|
||||
)
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info("[Elastic EP] Expert resharding completed!")
|
||||
|
||||
def _reconfigure_parallel_config(
|
||||
self, reconfig_request: ReconfigureDistributedRequest
|
||||
) -> None:
|
||||
"""
|
||||
Update parallel config with provided reconfig_request
|
||||
"""
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
|
||||
if (
|
||||
reconfig_request.new_data_parallel_rank
|
||||
!= ReconfigureRankType.KEEP_CURRENT_RANK
|
||||
):
|
||||
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
|
||||
if (
|
||||
reconfig_request.new_data_parallel_rank_local
|
||||
!= ReconfigureRankType.KEEP_CURRENT_RANK
|
||||
):
|
||||
parallel_config.data_parallel_rank_local = (
|
||||
reconfig_request.new_data_parallel_rank_local
|
||||
)
|
||||
parallel_config.data_parallel_master_ip = (
|
||||
reconfig_request.new_data_parallel_master_ip
|
||||
)
|
||||
parallel_config.data_parallel_master_port = (
|
||||
reconfig_request.new_data_parallel_master_port
|
||||
)
|
||||
|
||||
def _reconfigure_moe(
|
||||
self, old_ep_size: int, new_ep_size: int
|
||||
) -> list[torch.Tensor] | None:
|
||||
"""
|
||||
Reconfigure MoE modules with provided reconfig_request
|
||||
|
||||
Return the global expert load if new_ep_size > old_ep_size,
|
||||
otherwise None
|
||||
"""
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_dp_group,
|
||||
get_ep_group,
|
||||
prepare_communication_buffer_for_model,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE,
|
||||
FusedMoEParallelConfig,
|
||||
)
|
||||
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
|
||||
def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]:
|
||||
return [
|
||||
module
|
||||
for module in model.modules()
|
||||
if (
|
||||
module.__class__.__name__ == "FusedMoE"
|
||||
or module.__class__.__name__ == "SharedFusedMoE"
|
||||
)
|
||||
]
|
||||
|
||||
def update_moe_modules(moe_modules: list[FusedMoE], num_local_experts: int):
|
||||
assert all(
|
||||
module.moe_config.num_local_experts == num_local_experts
|
||||
for module in moe_modules
|
||||
), "All MoE modules must have the same number of experts"
|
||||
for module in moe_modules:
|
||||
module.moe_config.num_experts = num_local_experts * new_ep_size
|
||||
module.global_num_experts = module.moe_config.num_experts
|
||||
tp_size = get_tp_group().world_size
|
||||
is_sequence_parallel = parallel_config.use_sequence_parallel_moe
|
||||
sp_size = tp_size if is_sequence_parallel else 1
|
||||
module.moe_parallel_config = FusedMoEParallelConfig.make(
|
||||
tp_size_=tp_size,
|
||||
pcp_size_=get_pcp_group().world_size,
|
||||
dp_size_=get_dp_group().world_size,
|
||||
sp_size_=sp_size,
|
||||
vllm_parallel_config=parallel_config,
|
||||
)
|
||||
module.moe_config.moe_parallel_config = module.moe_parallel_config
|
||||
return moe_modules
|
||||
|
||||
model_moe_modules = get_moe_modules(self.model_runner.model)
|
||||
num_local_experts = model_moe_modules[0].moe_config.num_local_experts
|
||||
|
||||
update_moe_modules(model_moe_modules, num_local_experts)
|
||||
drafter_model = None
|
||||
if hasattr(self.model_runner, "drafter") and hasattr(
|
||||
self.model_runner.drafter, "model"
|
||||
):
|
||||
drafter_model = self.model_runner.drafter.model
|
||||
if drafter_model is not None and is_mixture_of_experts(drafter_model):
|
||||
drafter_moe_modules = get_moe_modules(drafter_model)
|
||||
# Check if drafter and model have matching configs
|
||||
assert (
|
||||
drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts
|
||||
), "Drafter and model configs should be the same"
|
||||
update_moe_modules(drafter_moe_modules, num_local_experts)
|
||||
|
||||
if new_ep_size < old_ep_size:
|
||||
num_local_physical_experts = num_local_experts
|
||||
assert self.model_runner.eplb_state is not None
|
||||
new_physical_experts = (
|
||||
self.model_runner.eplb_state.physical_to_logical_map.shape[1] # type: ignore[attr-defined]
|
||||
)
|
||||
parallel_config.eplb_config.num_redundant_experts = (
|
||||
new_physical_experts
|
||||
- self.model_runner.eplb_state.logical_replica_count.shape[1] # type: ignore[attr-defined]
|
||||
)
|
||||
global_expert_loads = None
|
||||
else:
|
||||
num_local_physical_experts_tensor = torch.tensor(
|
||||
[num_local_experts], dtype=torch.int32, device="cpu"
|
||||
)
|
||||
torch.distributed.broadcast(
|
||||
num_local_physical_experts_tensor,
|
||||
group=get_ep_group().cpu_group,
|
||||
group_src=0,
|
||||
)
|
||||
num_local_physical_experts = int(num_local_physical_experts_tensor.item())
|
||||
new_physical_experts = num_local_physical_experts * new_ep_size
|
||||
assert self.model_runner.eplb_state is not None
|
||||
global_expert_loads_any = self.model_runner.eplb_state.rearrange(
|
||||
execute_shuffle=False
|
||||
)
|
||||
global_expert_loads = cast(list[torch.Tensor], global_expert_loads_any)
|
||||
parallel_config.eplb_config.num_redundant_experts = (
|
||||
new_physical_experts - global_expert_loads[0].shape[1]
|
||||
)
|
||||
prepare_communication_buffer_for_model(self.model_runner.model)
|
||||
if drafter_model is not None:
|
||||
prepare_communication_buffer_for_model(drafter_model)
|
||||
self.model_runner.model.update_physical_experts_metadata(
|
||||
num_physical_experts=new_physical_experts,
|
||||
num_local_physical_experts=num_local_physical_experts,
|
||||
)
|
||||
return global_expert_loads
|
||||
|
||||
def reinitialize_distributed(
|
||||
self, reconfig_request: ReconfigureDistributedRequest
|
||||
) -> None:
|
||||
from vllm.config import set_current_vllm_config
|
||||
from vllm.distributed.parallel_state import (
|
||||
cleanup_dist_env_and_memory,
|
||||
get_ep_group,
|
||||
)
|
||||
|
||||
old_ep_size = get_ep_group().world_size
|
||||
old_ep_rank = get_ep_group().rank
|
||||
new_ep_size = (
|
||||
reconfig_request.new_data_parallel_size
|
||||
* get_tp_group().world_size
|
||||
* get_pp_group().world_size
|
||||
)
|
||||
if new_ep_size < old_ep_size:
|
||||
self._eplb_before_scale_down(old_ep_size, new_ep_size)
|
||||
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
if (
|
||||
reconfig_request.new_data_parallel_rank
|
||||
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
|
||||
):
|
||||
assert old_ep_rank >= new_ep_size
|
||||
# shutdown
|
||||
return
|
||||
|
||||
self._reconfigure_parallel_config(reconfig_request)
|
||||
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
init_worker_distributed_environment(
|
||||
self.vllm_config,
|
||||
self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank,
|
||||
)
|
||||
|
||||
global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)
|
||||
|
||||
if new_ep_size > old_ep_size:
|
||||
assert global_expert_loads is not None
|
||||
self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
path: str,
|
||||
@@ -1118,6 +914,9 @@ class Worker(WorkerBase):
|
||||
if weight_transfer_engine := getattr(self, "weight_transfer_engine", None):
|
||||
weight_transfer_engine.shutdown()
|
||||
|
||||
def elastic_ep_execute(self, execute_method: str, *args, **kwargs):
|
||||
return self.elastic_ep_executor.execute(execute_method, *args, **kwargs)
|
||||
|
||||
|
||||
def init_worker_distributed_environment(
|
||||
vllm_config: VllmConfig,
|
||||
|
||||
@@ -66,6 +66,23 @@ class WorkspaceManager:
|
||||
],
|
||||
)
|
||||
|
||||
def unlock(self) -> None:
|
||||
"""Unlock the workspace to allow growth.
|
||||
|
||||
This is used during elastic EP scaling when the workspace size
|
||||
needs to grow due to changes in the number of experts.
|
||||
"""
|
||||
self._locked = False
|
||||
if envs.VLLM_DEBUG_WORKSPACE:
|
||||
logger.info(
|
||||
"[WORKSPACE DEBUG] Workspace unlocked. Current sizes: %s",
|
||||
[
|
||||
self._workspace_size_bytes(ws) / _MB
|
||||
for ws in self._current_workspaces
|
||||
if ws is not None
|
||||
],
|
||||
)
|
||||
|
||||
def is_locked(self) -> bool:
|
||||
"""Check if workspace is locked."""
|
||||
return self._locked
|
||||
@@ -242,6 +259,17 @@ def lock_workspace() -> None:
|
||||
current_workspace_manager().lock()
|
||||
|
||||
|
||||
def unlock_workspace() -> None:
|
||||
"""Unlock the workspace to allow growth.
|
||||
|
||||
This is used during elastic EP scaling when the workspace size
|
||||
needs to grow due to changes in the number of experts.
|
||||
After scaling operations complete, lock_workspace() should be
|
||||
called again to prevent unexpected allocations.
|
||||
"""
|
||||
current_workspace_manager().unlock()
|
||||
|
||||
|
||||
def reset_workspace_manager() -> None:
|
||||
"""Reset the workspace manager to uninitialized state.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user