[Core] Remove busy loop from idle buffer readers (#28053)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Joe Runde
2026-03-04 00:44:20 -07:00
committed by GitHub
parent 5d199ac8f2
commit 6f0dd93801
5 changed files with 579 additions and 135 deletions
@@ -124,8 +124,6 @@ def test_models(
[
("facebook/opt-125m", "ray", "", "L4", {}),
("facebook/opt-125m", "mp", "", "L4", {}),
("facebook/opt-125m", "ray", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
("facebook/opt-125m", "mp", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}),
("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}),
("facebook/opt-125m", "ray", "", "A100", {}),
+285 -8
View File
@@ -1,11 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing
import random
import threading
import time
from unittest import mock
import multiprocess as mp
import numpy as np
import pytest
import torch.distributed as dist
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
@@ -22,7 +25,14 @@ def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]:
return [np.random.randint(1, 100, i) for i in sizes]
def distributed_run(fn, world_size):
def distributed_run(fn, world_size, timeout=60):
"""Run a function in multiple processes with proper error handling.
Args:
fn: Function to run in each process
world_size: Number of processes to spawn
timeout: Maximum time in seconds to wait for processes (default: 60)
"""
number_of_processes = world_size
processes = []
for i in range(number_of_processes):
@@ -33,19 +43,45 @@ def distributed_run(fn, world_size):
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
env["MASTER_ADDR"] = "localhost"
env["MASTER_PORT"] = "12345"
p = multiprocessing.Process(target=fn, args=(env,))
p = mp.Process(target=fn, args=(env,))
processes.append(p)
p.start()
# Monitor processes and fail fast if any process fails.
start_time = time.time()
failed_processes = []

# Poll until every process has exited, one of them has failed, or `timeout`
# seconds have elapsed.
while time.time() - start_time < timeout:
    failed_processes = [
        (i, p.exitcode)
        for i, p in enumerate(processes)
        if not p.is_alive() and p.exitcode != 0
    ]
    if failed_processes or not any(p.is_alive() for p in processes):
        break
    time.sleep(0.1)  # Check every 100ms

# Anything still alive here either ran into the timeout or is blocked on a
# peer that already failed. Kill it so the test run cannot hang.
timed_out_ranks = []
for i, p in enumerate(processes):
    if p.is_alive():
        p.kill()
        p.join()
        if not failed_processes:
            # No earlier failure explains this straggler: it is a real
            # timeout and must fail the test instead of passing silently.
            timed_out_ranks.append(i)

# Pick up processes that exited non-zero after the last poll of the loop.
if not failed_processes:
    failed_processes = [
        (i, p.exitcode)
        for i, p in enumerate(processes)
        if i not in timed_out_ranks and p.exitcode != 0
    ]

# Report failures
if failed_processes or timed_out_ranks:
    error_msg = "Distributed test failed:\n"
    for rank, status in failed_processes:
        error_msg += f" Rank {rank}: Exit code {status}\n"
    for rank in timed_out_ranks:
        error_msg += f" Rank {rank}: Timed out after {timeout} seconds\n"
    raise AssertionError(error_msg)
def worker_fn_wrapper(fn):
# `multiprocessing.Process` cannot accept environment variables directly
# `mp.Process` cannot accept environment variables directly
# so we need to pass the environment variables as arguments
# and update the environment variables in the function
def wrapped_fn(env):
@@ -115,3 +151,244 @@ def worker_fn():
def test_shm_broadcast():
distributed_run(worker_fn, 4)
# Per-rank body for test_message_queue_shutdown_busy. Reader ranks check that
# MessageQueue.shutdown(), called from a helper thread, makes dequeue() on a
# reader in busy mode raise RuntimeError("cancelled").
@worker_fn_wrapper
def worker_fn_test_shutdown_busy():
rank = dist.get_rank()
writer_rank = 2
message_queue = MessageQueue.create_from_process_group(
dist.group.WORLD, 40 * 1024, 2, writer_rank
)
if not message_queue._is_writer:
# Put into busy mode
# (a 9999s busy window keeps SpinCondition.wait() on its sched_yield path)
message_queue._spin_condition.busy_loop_s = 9999
shutdown_event = threading.Event()
# Helper thread body: block until the test signals, then shut the queue down
def shutdown_thread(mq, shutdown_event):
shutdown_event.wait()
mq.shutdown()
threading.Thread(
target=shutdown_thread, args=(message_queue, shutdown_event)
).start()
# Nothing is enqueued in this test, so a short dequeue is expected to time out
with pytest.raises(TimeoutError):
message_queue.dequeue(timeout=0.01)
# Let the helper thread call mq.shutdown()
shutdown_event.set()
# Once shutting_down is set, the read loop raises RuntimeError("cancelled")
with pytest.raises(RuntimeError, match="cancelled"):
message_queue.dequeue(timeout=1)
assert message_queue.shutting_down
print(f"torch distributed passed the test! Rank {rank}")
dist.barrier()
def test_message_queue_shutdown_busy(caplog_vllm):
    """Run the busy-reader shutdown scenario on 4 ranks, then print captured logs."""
    distributed_run(fn=worker_fn_test_shutdown_busy, world_size=4)
    print(caplog_vllm.text)
# Per-rank body for test_message_queue_shutdown_idle. Same flow as the busy
# variant, but reader ranks reset last_read so that the reader is expected to
# be in idle mode when MessageQueue.shutdown() is called from the helper thread.
@worker_fn_wrapper
def worker_fn_test_shutdown_idle():
rank = dist.get_rank()
writer_rank = 2
message_queue = MessageQueue.create_from_process_group(
dist.group.WORLD, 40 * 1024, 2, writer_rank
)
if not message_queue._is_writer:
# Put into idle mode
# (with last_read at 0, SpinCondition.wait() is expected to take its
# zmq poll path instead of sched_yield)
message_queue._spin_condition.last_read = 0
shutdown_event = threading.Event()
# Helper thread body: block until the test signals, then shut the queue down
def shutdown_thread(mq, shutdown_event):
shutdown_event.wait()
mq.shutdown()
threading.Thread(
target=shutdown_thread, args=(message_queue, shutdown_event)
).start()
# Nothing is enqueued in this test, so a short dequeue is expected to time out
with pytest.raises(TimeoutError):
message_queue.dequeue(timeout=0.01)
# Let the helper thread call mq.shutdown()
shutdown_event.set()
# Once shutting_down is set, the read loop raises RuntimeError("cancelled")
with pytest.raises(RuntimeError, match="cancelled"):
message_queue.dequeue(timeout=1)
assert message_queue.shutting_down
print(f"torch distributed passed the test! Rank {rank}")
dist.barrier()
def test_message_queue_shutdown_idle():
    """Run the idle-reader shutdown scenario on 4 ranks."""
    distributed_run(fn=worker_fn_test_shutdown_idle, world_size=4)
# Per-rank body for test_message_queue_idle_wake. Reader ranks start in idle
# mode and count calls to SpinCondition.wait() while the writer rank sends two
# messages, checking the call counts before and after the first read.
@worker_fn_wrapper
def worker_fn_test_idle_to_busy():
rank = dist.get_rank()
writer_rank = 2
message_queue = MessageQueue.create_from_process_group(
dist.group.WORLD, 40 * 1024, 2, writer_rank
)
message1 = "hello world"
# NOTE(review): message2 comes from the unseeded global NumPy RNG in every
# process, yet readers compare the received array with their own copy. This
# presumably relies on all ranks sharing RNG state (e.g. forked workers) -
# TODO confirm.
message2 = np.random.randint(1, 100, 100)
# `wraps=` keeps the real wait() behavior while recording every call
with mock.patch.object(
message_queue._spin_condition, "wait", wraps=message_queue._spin_condition.wait
) as wrapped_wait:
if not message_queue._is_writer:
# Put into idle mode
message_queue._spin_condition.last_read = 0
# no messages, so expect a TimeoutError
with pytest.raises(TimeoutError):
message_queue.dequeue(timeout=0.01)
# wait should only be called once while idle
assert wrapped_wait.call_count == 1
# sync with the writer and wait for message1
dist.barrier()
recv_message = message_queue.dequeue(timeout=5)
assert recv_message == message1
# second call to wait, with a message read, this puts in a busy spin
assert wrapped_wait.call_count == 2
# sync with the writer and wait for message2
dist.barrier()
recv_message = message_queue.dequeue(timeout=1)
assert np.array_equal(recv_message, message2)
# in busy mode, we expect wait to have been called multiple times
assert wrapped_wait.call_count > 3
else:
# writer writes two messages in sync with the reader
dist.barrier()
# sleep delays the send to ensure reader enters the read loop
time.sleep(0.1)
message_queue.enqueue(message1)
dist.barrier()
time.sleep(0.1)
message_queue.enqueue(message2)
message_queue.shutdown()
assert message_queue.shutting_down
print(f"torch distributed passed the test! Rank {rank}")
def test_message_queue_idle_wake():
    """Run the idle-to-busy reader scenario on 4 ranks."""
    distributed_run(fn=worker_fn_test_idle_to_busy, world_size=4)
# Per-rank body for test_message_queue_busy_to_idle. Reader ranks start in
# busy mode, read one message, are then switched to a zero-length busy window,
# and check that reading the second message costs exactly one more call to
# SpinCondition.wait().
@worker_fn_wrapper
def worker_fn_test_busy_to_idle():
rank = dist.get_rank()
writer_rank = 2
message_queue = MessageQueue.create_from_process_group(
dist.group.WORLD, 40 * 1024, 2, writer_rank
)
message1 = 12345
message2 = list(range(3))
# `wraps=` keeps the real wait() behavior while recording every call
with mock.patch.object(
message_queue._spin_condition, "wait", wraps=message_queue._spin_condition.wait
) as wrapped_wait:
if not message_queue._is_writer:
# Put into busy mode
message_queue._spin_condition.busy_loop_s = 9999
# sync with the writer and wait for message1
dist.barrier()
recv_message = message_queue.dequeue(timeout=1)
assert recv_message == message1
# in busy mode, we expect wait to have been called many times
assert wrapped_wait.call_count > 1
# simulate busy loop ending
message_queue._spin_condition.busy_loop_s = 0
# ensure we enter idle mode, then record call count
with pytest.raises(TimeoutError):
message_queue.dequeue(timeout=0.01)
call_count = wrapped_wait.call_count
# sync with the writer and wait for message2
dist.barrier()
recv_message = message_queue.dequeue(timeout=1)
assert recv_message == message2
# call to wait after idle should only happen once
assert wrapped_wait.call_count == call_count + 1
else:
# writer writes two messages in sync with the reader
dist.barrier()
# sleep delays the send to ensure reader enters the read loop
time.sleep(0.1)
message_queue.enqueue(message1)
dist.barrier()
time.sleep(0.1)
message_queue.enqueue(message2)
message_queue.shutdown()
assert message_queue.shutting_down
print(f"torch distributed passed the test! Rank {rank}")
def test_message_queue_busy_to_idle():
    """Run the busy-to-idle reader scenario on 4 ranks."""
    distributed_run(fn=worker_fn_test_busy_to_idle, world_size=4)
# Single-process test: a writer MessageQueue and a reader built from its
# exported handle; the reader's dequeue() always times out because nothing is
# ever enqueued, and only the captured log records are inspected.
def test_warning_logs(caplog_vllm):
"""
Test that warning logs are emitted at VLLM_RINGBUFFER_WARNING_INTERVAL intervals
when indefinite=False, and are not emitted when indefinite=True.
"""
# Patch the warning log interval to every 1 ms during reads
with mock.patch(
"vllm.distributed.device_communicators.shm_broadcast.VLLM_RINGBUFFER_WARNING_INTERVAL",
new=0.001, # 1 ms
):
writer = MessageQueue(
n_reader=1,
n_local_reader=1,
max_chunk_bytes=1024 * 1024, # 1MB chunks
max_chunks=10,
)
# The reader attaches to the writer through the exported handle as rank 0
reader = MessageQueue.create_from_handle(writer.export_handle(), rank=0)
writer.wait_until_ready()
reader.wait_until_ready()
# We should have at least one warning log here
# "0 seconds" expected due to rounding of 1ms test interval
with pytest.raises(TimeoutError):
reader.dequeue(timeout=0.01, indefinite=False)
assert any(
"No available shared memory broadcast block found in 0 seconds"
in record.message
for record in caplog_vllm.records
)
# Drop the captured records so the next check only sees new ones
caplog_vllm.clear()
# We should have no warnings this time
with pytest.raises(TimeoutError):
reader.dequeue(timeout=0.01, indefinite=True)
assert all(
"No available shared memory broadcast block found in 0 seconds"
not in record.message
for record in caplog_vllm.records
)
# Clean up when done
writer.shutdown()
reader.shutdown()
@@ -2,13 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import pickle
import sys
import threading
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from multiprocessing import shared_memory
from pickle import PickleBuffer
from threading import Event
from typing import TYPE_CHECKING, Any, cast
from unittest.mock import patch
@@ -18,6 +18,7 @@ import zmq
from torch.distributed import ProcessGroup
from zmq import ( # type: ignore
IPV6, # type: ignore
PUB,
SUB,
SUBSCRIBE,
XPUB,
@@ -32,6 +33,7 @@ from vllm.platforms import current_platform
from vllm.utils.network_utils import (
get_ip,
get_open_port,
get_open_zmq_inproc_path,
get_open_zmq_ipc_path,
is_valid_ipv6_address,
)
@@ -78,50 +80,125 @@ def to_bytes_big(value: int, size: int) -> bytes:
logger = init_logger(__name__)
def long_wait_time_msg(threshold: int) -> str:
    """Return the log text for a long wait on a shared memory broadcast block.

    Args:
        threshold: Number of seconds interpolated into the message.
    """
    parts = (
        "No available shared memory broadcast block found ",
        f"in {threshold} seconds. This typically happens ",
        "when some processes are hanging or doing some ",
        "time-consuming work (e.g. compilation, ",
        "weight/kv cache quantization).",
    )
    return "".join(parts)
LONG_WAIT_TIME_LOG_MSG = (
"No available shared memory broadcast block found "
"in %d seconds. This typically happens "
"when some processes are hanging or doing some "
"time-consuming work (e.g. compilation, "
"weight/kv cache quantization)."
)
class SpinTimer:
def record_activity(self):
pass
def spin(self):
sched_yield()
class SpinSleepTimer(SpinTimer):
class SpinCondition:
"""
In setups which have long inactivity periods it is desirable to reduce
system power consumption when vllm does nothing. This would lead to more
CPU thermal headroom when a request eventually comes, especially when
multiple GPUs are connected as each GPU would otherwise pin one thread at
100% CPU usage.
This class implements an interface similar to a threading.Condition. It
allows a writer to notify readers to wake up and read from the shared memory
buffer. This notification is done over a zmq socket.
The simplest solution is to reduce polling frequency when there is no
activity for a certain period of time.
For optimal performance under load we don't want the readers to need to poll
the zmq socket for every read. So the `wait` method here will return
immediately when reads are frequent, and will only enter "idle mode" and
await a notification on the zmq socket after a period of inactivity. This
allows the readers to spin quickly, hence "SpinCondition".
To support clean shutdown, a separate thread in the reader's process must be
able to wake the reader so that it can exit. A separate cancel() method is
implemented with an in-process socket to allow this interruption.
"""
def __init__(self, busy_loop_s: float = 3.0, wait_sleep_s: float = 0.1):
self.last_activity = time.monotonic()
self.busy_loop_s = busy_loop_s
self.wait_sleep_s = wait_sleep_s
def __init__(
self,
is_reader: bool,
context: zmq.Context,
notify_address: str,
busy_loop_s: float = 1,
):
self.is_reader = is_reader
def record_activity(self):
self.last_activity = time.monotonic()
if is_reader:
# Time of last shm buffer read
self.last_read = time.monotonic()
def spin(self):
curr_time = time.monotonic()
if curr_time >= self.last_activity + self.busy_loop_s:
time.sleep(self.wait_sleep_s)
# Time to keep busy-looping on the shm buffer before going idle
self.busy_loop_s = busy_loop_s
# Readers subscribe to write notifications
self.local_notify_socket: zmq.Socket = context.socket(SUB)
# Set zmq.CONFLATE to only keep the last message that the socket
# receives. This prevents us from piling up notification messages
# under high load when we aren't polling the socket.
self.local_notify_socket.setsockopt(zmq.CONFLATE, 1)
# Subscribe to all messages on the socket
self.local_notify_socket.setsockopt_string(SUBSCRIBE, "")
self.local_notify_socket.connect(notify_address)
# Readers require a process-local socket to poll for cancellation
cancel_path = get_open_zmq_inproc_path()
self.write_cancel_socket: zmq.Socket = context.socket(zmq.PAIR)
self.write_cancel_socket.bind(cancel_path)
self.read_cancel_socket: zmq.Socket = context.socket(zmq.PAIR)
self.read_cancel_socket.connect(cancel_path)
# Poller allows waiting on either `.notify()` or `.cancel()`
self.poller = zmq.Poller()
self.poller.register(self.read_cancel_socket, zmq.POLLIN)
self.poller.register(self.local_notify_socket, zmq.POLLIN)
else:
# Writer side publishes write notifications
self.local_notify_socket: zmq.Socket = context.socket(PUB) # type: ignore
# Set high water mark to 1 - we don't need to send a massive amount of
# pings during busy operation. PUB sockets will silently drop subsequent
# messages after the high water mark is reached.
self.local_notify_socket.setsockopt(zmq.SNDHWM, 1)
self.local_notify_socket.bind(notify_address)
self.last_read = 0
self.busy_loop_s = 0
self.read_cancel_socket = None
self.write_cancel_socket = None
self.poller = None
def record_read(self):
    """Stamp the current monotonic time as the most recent read.

    `wait` compares against this stamp to choose between yielding and polling.
    """
    now = time.monotonic()
    self.last_read = now
def cancel(self):
    """Send a cancellation ping that wakes a reader blocked in `wait`.

    Meant to be called from a monitor thread in the same process as the
    reader. Does nothing when this instance is not a reader.
    """
    if not self.is_reader:
        return
    logger.debug("Canceling waiting reads on SHM Buffer")
    self.write_cancel_socket.send(b"\x00")
def wait(self, timeout_ms: int | None = None) -> None:
    """Wait for data on the shared memory buffer.

    If at most `self.busy_loop_s` seconds have passed since the last read,
    yields the scheduler and returns immediately.

    Otherwise polls the cancel and notify sockets for at most `timeout_ms`
    milliseconds, or indefinitely if `timeout_ms` is None. A pending notify
    message is consumed; a cancel event is only logged.
    """
    assert self.is_reader, "Only readers can wait"

    now = time.monotonic()
    if now <= self.last_read + self.busy_loop_s:
        # Recently active: stay in the spin loop and just give up the CPU.
        sched_yield()
        return

    # Idle: block on the poller until a ping arrives or the timeout expires.
    events = dict(self.poller.poll(timeout=timeout_ms))
    if self.read_cancel_socket in events:
        logger.debug("Poller received cancel event")
        return
    if self.local_notify_socket not in events:
        logger.debug("Poller timed out")
        return

    logger.debug("Poller received notify event")
    # Since zmq.CONFLATE is set, there will only be one notification
    # to read from the socket
    self.local_notify_socket.recv(flags=zmq.NOBLOCK, copy=False)
def notify(self):
    """Send a one-byte ping on the notify socket to wake readers.

    Only valid on the writer side; asserts otherwise.
    """
    assert not self.is_reader, "Only writers can notify"
    ping = b"\x00"
    self.local_notify_socket.send(ping)
class ShmRingBuffer:
@@ -265,6 +342,7 @@ class Handle:
buffer_handle: tuple[int, int, int, str] | None = None
local_subscribe_addr: str | None = None
local_notify_addr: str | None = None
remote_subscribe_addr: str | None = None
remote_addr_ipv6: bool = False
@@ -288,7 +366,7 @@ class MessageQueue:
self.n_local_reader = n_local_reader
n_remote_reader = n_reader - n_local_reader
self.n_remote_reader = n_remote_reader
self.shutting_down = False
context = Context()
if n_local_reader > 0:
@@ -310,11 +388,19 @@ class MessageQueue:
self.local_socket.bind(local_subscribe_addr)
self.current_idx = 0
# Create the notification side of the SpinCondition
local_notify_addr = get_open_zmq_ipc_path()
self._spin_condition = SpinCondition(
is_reader=False, context=context, notify_address=local_notify_addr
)
else:
self.buffer = None # type: ignore
local_subscribe_addr = None
self.local_socket = None
self.current_idx = -1
local_notify_addr = None
self._spin_condition = None # type: ignore
remote_addr_ipv6 = False
if n_remote_reader > 0:
@@ -341,12 +427,12 @@ class MessageQueue:
self.local_reader_rank = -1
# rank does not matter for remote readers
self._is_remote_reader = False
self._read_spin_timer = SpinTimer()
self.handle = Handle(
local_reader_ranks=local_reader_ranks,
buffer_handle=self.buffer.handle() if self.buffer is not None else None,
local_subscribe_addr=local_subscribe_addr,
local_notify_addr=local_notify_addr,
remote_subscribe_addr=remote_subscribe_addr,
remote_addr_ipv6=remote_addr_ipv6,
)
@@ -379,9 +465,9 @@ class MessageQueue:
self.local_socket.connect(socket_addr)
self.remote_socket = None
self._read_spin_timer = (
SpinSleepTimer() if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer()
assert isinstance(handle.local_notify_addr, str)
self._spin_condition = SpinCondition(
is_reader=True, context=context, notify_address=handle.local_notify_addr
)
else:
self.buffer = None # type: ignore
@@ -399,7 +485,9 @@ class MessageQueue:
socket_addr = handle.remote_subscribe_addr
logger.debug("Connecting to %s", socket_addr)
self.remote_socket.connect(socket_addr)
self._spin_condition = None # type: ignore
self.shutting_down = False
return self
def wait_until_ready(self):
@@ -435,6 +523,13 @@ class MessageQueue:
recv = self.remote_socket.recv()
assert recv == b"READY"
def shutdown(self):
    """Mark the queue as shutting down and cancel any waiting read.

    Sets `shutting_down`, then calls `cancel()` on the spin condition when
    one exists so that an idle reader wakes up and can exit.
    """
    self.shutting_down = True
    spin_condition = self._spin_condition
    if spin_condition is None:
        return
    spin_condition.cancel()
@contextmanager
def acquire_write(self, timeout: float | None = None):
assert self._is_writer, "Only writers can acquire write"
@@ -465,7 +560,7 @@ class MessageQueue:
# if we wait for a long time, log a message
if elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:
logger.info(
long_wait_time_msg(VLLM_RINGBUFFER_WARNING_INTERVAL)
LONG_WAIT_TIME_LOG_MSG, VLLM_RINGBUFFER_WARNING_INTERVAL
)
n_warning += 1
@@ -503,16 +598,60 @@ class MessageQueue:
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
break
class ReadTimeoutWithWarnings:
def __init__(self, timeout: float | None, should_warn: bool) -> None:
self.started = time.monotonic()
self.deadline = sys.maxsize if timeout is None else self.started + timeout
# if should_warn, we need to wake up periodically to log
self.warning_wait_time_ms: int | None = (
VLLM_RINGBUFFER_WARNING_INTERVAL * 1000 if should_warn else None
)
self._should_warn = should_warn
self.n_warning = 1
self.timeout = timeout
def timeout_ms(self) -> int | None:
"""Returns a timeout that is:
- min(time to deadline, time to next warning) if we're logging warnings
- time to deadline, if we're not logging warnings
- None if the timeout is None and we're not logging warnings
- raise TimeoutError if we are past the deadline
"""
warning_wait_time = self.warning_wait_time_ms
if self.timeout is None:
return warning_wait_time
time_left_ms = int((self.deadline - time.monotonic()) * 1000)
if time_left_ms <= 0:
raise TimeoutError
if warning_wait_time and warning_wait_time < time_left_ms:
return warning_wait_time
return time_left_ms
def should_warn(self) -> bool:
"""Returns true if it's time to log a warning for a timeout that is not
indefinite"""
if self._should_warn:
elapsed = time.monotonic() - self.started
if elapsed >= VLLM_RINGBUFFER_WARNING_INTERVAL * self.n_warning:
self.n_warning += 1
return True
return False
@contextmanager
def acquire_read(
self,
timeout: float | None = None,
cancel: Event | None = None,
indefinite: bool = False,
):
assert self._is_local_reader, "Only readers can acquire read"
start_time = time.monotonic()
n_warning = 1
read_timeout = self.ReadTimeoutWithWarnings(
timeout=timeout, should_warn=not indefinite
)
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
while True:
# Memory fence ensures we see the latest writes from the writer.
@@ -529,26 +668,16 @@ class MessageQueue:
# for readers, `self.current_idx` is the next block to read
# if this block is not ready,
# we need to wait until it is written
self._spin_condition.wait(timeout_ms=read_timeout.timeout_ms())
# Release the processor to other threads
self._read_spin_timer.spin()
if cancel is not None and cancel.is_set():
if self.shutting_down:
raise RuntimeError("cancelled")
# if we time out, raise an exception
elapsed = time.monotonic() - start_time
if timeout is not None and elapsed > timeout:
raise TimeoutError
# if we wait for a long time, log a message
if not indefinite and (
elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning
):
if read_timeout.should_warn():
logger.info(
long_wait_time_msg(VLLM_RINGBUFFER_WARNING_INTERVAL)
LONG_WAIT_TIME_LOG_MSG, VLLM_RINGBUFFER_WARNING_INTERVAL
)
n_warning += 1
continue
# found a block that is not read by this reader
@@ -565,7 +694,7 @@ class MessageQueue:
memory_fence()
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
self._read_spin_timer.record_activity()
self._spin_condition.record_read()
break
def enqueue(self, obj, timeout: float | None = None):
@@ -608,18 +737,19 @@ class MessageQueue:
buf[offset:buf_offset] = to_bytes_big(buf_len, 4)
buf[buf_offset : (offset := buf_offset + buf_len)] = buffer
self._spin_condition.notify()
if self.n_remote_reader > 0:
self.remote_socket.send_multipart(all_buffers, copy=False)
def dequeue(
self,
timeout: float | None = None,
cancel: Event | None = None,
indefinite: bool = False,
):
"""Read from message queue with optional timeout (in seconds)"""
if self._is_local_reader:
with self.acquire_read(timeout, cancel, indefinite) as buf:
with self.acquire_read(timeout, indefinite) as buf:
overflow = buf[0] == 1
if not overflow:
offset = 3
-5
View File
@@ -179,7 +179,6 @@ if TYPE_CHECKING:
VLLM_MOONCAKE_BOOTSTRAP_PORT: int = 8998
VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
VLLM_SLEEP_WHEN_IDLE: bool = False
VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300
VLLM_KV_CACHE_LAYOUT: Literal["NHD", "HND"] | None = None
@@ -1338,9 +1337,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS": lambda: int(
os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1")
),
# Reduce CPU usage when vLLM is idle. Enabling this will incur small
# latency penalty when a request eventually comes.
"VLLM_SLEEP_WHEN_IDLE": lambda: bool(int(os.getenv("VLLM_SLEEP_WHEN_IDLE", "0"))),
# Control the max chunk bytes (in MB) for the rpc message queue.
# Object larger than this threshold will be broadcast to worker
# processes via zmq.
@@ -1751,7 +1747,6 @@ def compile_factors() -> dict[str, object]:
"VLLM_HTTP_TIMEOUT_KEEP_ALIVE",
"VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS",
"VLLM_KEEP_ALIVE_ON_ENGINE_DEATH",
"VLLM_SLEEP_WHEN_IDLE",
"VLLM_IMAGE_FETCH_TIMEOUT",
"VLLM_VIDEO_FETCH_TIMEOUT",
"VLLM_AUDIO_FETCH_TIMEOUT",
+103 -59
View File
@@ -104,7 +104,6 @@ class MultiprocExecutor(Executor):
# and ensure workers will be terminated.
self._finalizer = weakref.finalize(self, self.shutdown)
self.is_failed = False
self.shutdown_event = threading.Event()
self.failure_callback: FailureCallback | None = None
tp_size, pp_size, pcp_size = self._get_parallel_sizes()
@@ -158,20 +157,31 @@ class MultiprocExecutor(Executor):
global_start_rank = (
self.local_world_size * self.parallel_config.node_rank_within_dp
)
# Keep track of socket file descriptors that are inherited by the
# worker when using fork, so that we can close them in subsequent
# workers
inherited_fds: list[int] = []
for local_rank in range(self.local_world_size):
global_rank = global_start_rank + local_rank
is_driver_worker = self._is_driver_worker(global_rank)
unready_workers.append(
WorkerProc.make_worker_process(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=global_rank,
distributed_init_method=distributed_init_method,
input_shm_handle=scheduler_output_handle,
shared_worker_lock=shared_worker_lock,
is_driver_worker=is_driver_worker,
)
unready_worker_handle = WorkerProc.make_worker_process(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=global_rank,
distributed_init_method=distributed_init_method,
input_shm_handle=scheduler_output_handle,
shared_worker_lock=shared_worker_lock,
is_driver_worker=is_driver_worker,
inherited_fds=inherited_fds,
)
unready_workers.append(unready_worker_handle)
if context.get_start_method() == "fork":
inherited_fds.extend(
[
unready_worker_handle.death_writer.fileno(),
unready_worker_handle.ready_pipe.fileno(),
]
)
# Workers must be created before wait_for_ready to avoid
# deadlock, since worker.init_device() does a device sync.
@@ -220,6 +230,7 @@ class MultiprocExecutor(Executor):
for uw in unready_workers:
if uw.death_writer is not None:
uw.death_writer.close()
uw.death_writer = None
self._ensure_worker_termination([uw.proc for uw in unready_workers])
self.output_rank = self._get_output_rank()
@@ -255,6 +266,7 @@ class MultiprocExecutor(Executor):
died = multiprocessing.connection.wait(sentinels)
_self = self_ref()
if not _self or getattr(_self, "shutting_down", False):
logger.debug("MultiprocWorkerMonitor: shutdown already initiated")
return
_self.is_failed = True
proc_name = next(h.proc.name for h in workers if h.proc.sentinel == died[0])
@@ -354,8 +366,6 @@ class MultiprocExecutor(Executor):
if output_rank is not None:
response_mqs = (response_mqs[output_rank],)
shutdown_event = self.shutdown_event
def get_response():
responses = []
for mq in response_mqs:
@@ -363,9 +373,7 @@ class MultiprocExecutor(Executor):
None if deadline is None else (deadline - time.monotonic())
)
try:
status, result = mq.dequeue(
timeout=dequeue_timeout, cancel=shutdown_event
)
status, result = mq.dequeue(timeout=dequeue_timeout)
except TimeoutError as e:
raise TimeoutError(f"RPC call to {method} timed out.") from e
if status != WorkerProc.ResponseStatus.SUCCESS:
@@ -408,20 +416,26 @@ class MultiprocExecutor(Executor):
active_procs = lambda: [proc for proc in worker_procs if proc.is_alive()]
# Give processes time to clean themselves up properly first
logger.debug("Worker Termination: allow workers to gracefully shutdown")
if wait_for_termination(active_procs(), 4):
return
# Send SIGTERM if still running
logger.debug("Worker Termination: workers still running sending SIGTERM")
for p in active_procs():
p.terminate()
if not wait_for_termination(active_procs(), 4):
# Send SIGKILL if still running
logger.debug(
"Worker Termination: resorting to SIGKILL to take down workers"
)
for p in active_procs():
p.kill()
def shutdown(self):
"""Properly shut down the executor and its workers"""
if not getattr(self, "shutting_down", False):
logger.debug("Triggering shutdown of workers")
self.shutting_down = True
# Make sure all the worker processes are terminated first.
@@ -431,12 +445,20 @@ class MultiprocExecutor(Executor):
if w.death_writer is not None:
w.death_writer.close()
w.death_writer = None
w.worker_response_mq = None
self._ensure_worker_termination([w.proc for w in workers])
self.shutdown_event.set()
for w in workers:
# Shutdown response queues
if w.worker_response_mq is not None:
w.worker_response_mq.shutdown()
w.worker_response_mq = None
self.rpc_broadcast_mq = None
if self.rpc_broadcast_mq is not None:
self.rpc_broadcast_mq.shutdown()
self.rpc_broadcast_mq = None
for mq in self.response_mqs:
mq.shutdown()
self.response_mqs = []
def check_health(self) -> None:
self.collective_rpc("check_health", timeout=10)
@@ -609,24 +631,26 @@ class WorkerProc:
input_shm_handle, # Receive SchedulerOutput
shared_worker_lock: LockType,
is_driver_worker: bool,
inherited_fds: list[int],
) -> UnreadyWorkerProcHandle:
context = get_mp_context()
# (reader, writer)
reader, writer = context.Pipe(duplex=False)
# Create death pipe to detect parent process exit
# Ready pipe to communicate readiness from child to parent
ready_reader, ready_writer = context.Pipe(duplex=False)
# Death pipe to let child detect parent process exit
death_reader, death_writer = context.Pipe(duplex=False)
process_kwargs = {
"vllm_config": vllm_config,
"local_rank": local_rank,
"rank": rank,
"distributed_init_method": distributed_init_method,
"input_shm_handle": input_shm_handle,
"ready_pipe": (reader, writer),
"ready_pipe": ready_writer,
"death_pipe": death_reader,
"shared_worker_lock": shared_worker_lock,
"is_driver_worker": is_driver_worker,
# Have the worker close parent end of this worker's pipes too
"inherited_fds": inherited_fds
+ [ready_reader.fileno(), death_writer.fileno()],
}
# Run EngineCore busy loop in background process.
proc = context.Process(
@@ -637,10 +661,12 @@ class WorkerProc:
)
proc.start()
writer.close()
# Close child ends of pipes here in the parent
ready_writer.close()
death_reader.close()
# Keep death_writer open in parent - when parent exits,
# death_reader in child will get EOFError
return UnreadyWorkerProcHandle(proc, rank, reader, death_writer)
return UnreadyWorkerProcHandle(proc, rank, ready_reader, death_writer)
@staticmethod
def wait_for_response_handle_ready(
@@ -703,12 +729,41 @@ class WorkerProc:
return cast(list[WorkerProcHandle], ready_proc_handles)
def shutdown(self):
    """Shut down this worker's message queues, the worker, and distributed state.

    Both queues (when present) are shut down before the worker itself; the
    queue references are then dropped and the model-parallel and distributed
    environments are destroyed.
    """
    for mq in (self.rpc_broadcast_mq, self.worker_response_mq):
        if mq is not None:
            mq.shutdown()

    self.worker.shutdown()
    self.rpc_broadcast_mq = None
    self.worker_response_mq = None

    destroy_model_parallel()
    destroy_distributed_environment()
def monitor_death_pipe(self, death_pipe, shutdown_requested: threading.Event):
    """Start a daemon thread that reacts to the death pipe closing.

    When `death_pipe.recv()` raises EOFError, the thread sets
    `shutdown_requested` and shuts down this worker's message queues. Any
    other exception from the pipe is logged as a warning. Does nothing when
    `death_pipe` is None.
    """
    if death_pipe is None:
        return

    def death_pipe_monitor(queues_to_shutdown: list[MessageQueue]):
        try:
            # This will block until parent process exits (pipe closes)
            death_pipe.recv()
        except EOFError:
            logger.info_once("Parent process exited, terminating worker queues")
            shutdown_requested.set()
            for mq in queues_to_shutdown:
                if mq is None:
                    continue
                mq.shutdown()
        except Exception as e:
            logger.warning("Death monitoring error: %s", e)

    # Pass queue references directly to avoid gc issues if passing self
    queues = [self.rpc_broadcast_mq, self.worker_response_mq]
    monitor_thread = Thread(
        target=death_pipe_monitor,
        args=(queues,),
        daemon=True,
        name="DeathPipeMonitor",
    )
    monitor_thread.start()
@staticmethod
def worker_main(*args, **kwargs):
"""Worker initialization and execution loops.
@@ -717,12 +772,12 @@ class WorkerProc:
# Signal handler used for graceful termination.
# SystemExit exception is only raised once to allow this and worker
# processes to terminate without error
shutdown_requested = False
shutdown_requested = threading.Event()
def signal_handler(signum, frame):
nonlocal shutdown_requested
if not shutdown_requested:
shutdown_requested = True
if not shutdown_requested.is_set():
shutdown_requested.set()
logger.debug(
"WorkerProc handling signal %d, raising SystemExit", signum
)
@@ -733,33 +788,20 @@ class WorkerProc:
signal.signal(signal.SIGINT, signal_handler)
worker = None
# tuple[Connection, Connection]
reader, ready_writer = kwargs.pop("ready_pipe")
death_pipe: Connection | None = kwargs.pop("death_pipe", None)
shutdown_event = threading.Event()
# Start death monitoring thread if death_pipe is provided
if death_pipe is not None:
ready_writer = kwargs.pop("ready_pipe")
death_pipe = kwargs.pop("death_pipe", None)
def monitor_parent_death():
try:
# This will block until parent process exits (pipe closes)
death_pipe.recv()
except EOFError:
# Parent process has exited, terminate this worker
logger.info_once("Parent process exited, terminating worker")
# Send signal to self to trigger clean shutdown
shutdown_event.set()
except Exception as e:
logger.warning("Death monitoring error: %s", e)
death_monitor = Thread(
target=monitor_parent_death, daemon=True, name="WorkerDeathMonitor"
)
death_monitor.start()
# Close inherited pipes from parent (incl. other worker pipes)
# Explicitly passing in existing pipes and closing them makes the pipe
# behave when using fork. Otherwise, a hidden reference to the pipes
# exist in the child process and prevents EOF closure.
for fd in kwargs.pop("inherited_fds", []):
try:
os.close(fd)
except Exception as e:
logger.warning("Exception closing inherited connection: %s", e)
try:
reader.close()
# Initialize tracer
rank = kwargs.get("rank", 0)
maybe_init_worker_tracer(
@@ -771,6 +813,8 @@ class WorkerProc:
worker = WorkerProc(*args, **kwargs)
assert worker.worker_response_mq is not None
worker.monitor_death_pipe(death_pipe, shutdown_requested)
# Send READY once we know everything is loaded
ready_writer.send(
{
@@ -788,7 +832,7 @@ class WorkerProc:
ready_writer.close()
ready_writer = None
worker.worker_busy_loop(cancel=shutdown_event)
worker.worker_busy_loop()
except Exception:
# NOTE: if an Exception arises in busy_loop, we send
@@ -798,7 +842,7 @@ class WorkerProc:
if ready_writer is not None:
logger.exception("WorkerProc failed to start.")
elif shutdown_event.is_set():
elif shutdown_requested.is_set():
logger.info("WorkerProc shutting down.")
else:
logger.exception("WorkerProc failed.")
@@ -806,7 +850,7 @@ class WorkerProc:
# The parent sends a SIGTERM to all worker processes if
# any worker dies. Set this value so we don't re-throw
# SystemExit() to avoid zmq exceptions in __del__.
shutdown_requested = True
shutdown_requested.set()
except SystemExit as e:
# SystemExit is raised on SIGTERM or SIGINT, which usually indicates that
@@ -859,12 +903,12 @@ class WorkerProc:
output = self.async_output_queue.get()
self.enqueue_output(output)
def worker_busy_loop(self, cancel: threading.Event | None = None):
def worker_busy_loop(self):
"""Main busy loop for Multiprocessing Workers"""
assert self.rpc_broadcast_mq is not None
while True:
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue(
cancel=cancel, indefinite=True
indefinite=True
)
try:
if isinstance(method, str):