mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Core] Remove busy loop from idle buffer readers (#28053)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com> Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com> Signed-off-by: Nick Hill <nickhill123@gmail.com> Co-authored-by: Travis Johnson <tsjohnso@us.ibm.com> Co-authored-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -124,8 +124,6 @@ def test_models(
|
||||
[
|
||||
("facebook/opt-125m", "ray", "", "L4", {}),
|
||||
("facebook/opt-125m", "mp", "", "L4", {}),
|
||||
("facebook/opt-125m", "ray", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
|
||||
("facebook/opt-125m", "mp", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
|
||||
("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}),
|
||||
("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}),
|
||||
("facebook/opt-125m", "ray", "", "A100", {}),
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import multiprocessing
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from unittest import mock
|
||||
|
||||
import multiprocess as mp
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
|
||||
@@ -22,7 +25,14 @@ def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]:
|
||||
return [np.random.randint(1, 100, i) for i in sizes]
|
||||
|
||||
|
||||
def distributed_run(fn, world_size):
|
||||
def distributed_run(fn, world_size, timeout=60):
|
||||
"""Run a function in multiple processes with proper error handling.
|
||||
|
||||
Args:
|
||||
fn: Function to run in each process
|
||||
world_size: Number of processes to spawn
|
||||
timeout: Maximum time in seconds to wait for processes (default: 60)
|
||||
"""
|
||||
number_of_processes = world_size
|
||||
processes = []
|
||||
for i in range(number_of_processes):
|
||||
@@ -33,19 +43,45 @@ def distributed_run(fn, world_size):
|
||||
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
|
||||
env["MASTER_ADDR"] = "localhost"
|
||||
env["MASTER_PORT"] = "12345"
|
||||
p = multiprocessing.Process(target=fn, args=(env,))
|
||||
p = mp.Process(target=fn, args=(env,))
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
for p in processes:
|
||||
p.join()
|
||||
# Monitor processes and fail fast if any process fails
|
||||
start_time = time.time()
|
||||
failed_processes = []
|
||||
|
||||
for p in processes:
|
||||
assert p.exitcode == 0
|
||||
# Wait for all processes, checking for failures
|
||||
while time.time() - start_time < timeout:
|
||||
all_done = True
|
||||
for i, p in enumerate(processes):
|
||||
if p.is_alive():
|
||||
all_done = False
|
||||
elif p.exitcode != 0:
|
||||
# Process failed
|
||||
failed_processes.append((i, p.exitcode))
|
||||
break
|
||||
|
||||
if failed_processes or all_done:
|
||||
break
|
||||
time.sleep(0.1) # Check every 100ms
|
||||
|
||||
# Check for timeout if no failures detected yet
|
||||
for i, p in enumerate(processes):
|
||||
if p.is_alive():
|
||||
p.kill()
|
||||
p.join()
|
||||
|
||||
# Report failures
|
||||
if failed_processes:
|
||||
error_msg = "Distributed test failed:\n"
|
||||
for rank, status in failed_processes:
|
||||
error_msg += f" Rank {rank}: Exit code {status}\n"
|
||||
raise AssertionError(error_msg)
|
||||
|
||||
|
||||
def worker_fn_wrapper(fn):
|
||||
# `multiprocessing.Process` cannot accept environment variables directly
|
||||
# `mp.Process` cannot accept environment variables directly
|
||||
# so we need to pass the environment variables as arguments
|
||||
# and update the environment variables in the function
|
||||
def wrapped_fn(env):
|
||||
@@ -115,3 +151,244 @@ def worker_fn():
|
||||
|
||||
def test_shm_broadcast():
|
||||
distributed_run(worker_fn, 4)
|
||||
|
||||
|
||||
@worker_fn_wrapper
|
||||
def worker_fn_test_shutdown_busy():
|
||||
rank = dist.get_rank()
|
||||
writer_rank = 2
|
||||
message_queue = MessageQueue.create_from_process_group(
|
||||
dist.group.WORLD, 40 * 1024, 2, writer_rank
|
||||
)
|
||||
|
||||
if not message_queue._is_writer:
|
||||
# Put into busy mode
|
||||
message_queue._spin_condition.busy_loop_s = 9999
|
||||
|
||||
shutdown_event = threading.Event()
|
||||
|
||||
def shutdown_thread(mq, shutdown_event):
|
||||
shutdown_event.wait()
|
||||
mq.shutdown()
|
||||
|
||||
threading.Thread(
|
||||
target=shutdown_thread, args=(message_queue, shutdown_event)
|
||||
).start()
|
||||
|
||||
with pytest.raises(TimeoutError):
|
||||
message_queue.dequeue(timeout=0.01)
|
||||
|
||||
shutdown_event.set()
|
||||
|
||||
with pytest.raises(RuntimeError, match="cancelled"):
|
||||
message_queue.dequeue(timeout=1)
|
||||
|
||||
assert message_queue.shutting_down
|
||||
|
||||
print(f"torch distributed passed the test! Rank {rank}")
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def test_message_queue_shutdown_busy(caplog_vllm):
|
||||
distributed_run(worker_fn_test_shutdown_busy, 4)
|
||||
print(caplog_vllm.text)
|
||||
|
||||
|
||||
@worker_fn_wrapper
|
||||
def worker_fn_test_shutdown_idle():
|
||||
rank = dist.get_rank()
|
||||
writer_rank = 2
|
||||
message_queue = MessageQueue.create_from_process_group(
|
||||
dist.group.WORLD, 40 * 1024, 2, writer_rank
|
||||
)
|
||||
|
||||
if not message_queue._is_writer:
|
||||
# Put into idle mode
|
||||
message_queue._spin_condition.last_read = 0
|
||||
|
||||
shutdown_event = threading.Event()
|
||||
|
||||
def shutdown_thread(mq, shutdown_event):
|
||||
shutdown_event.wait()
|
||||
mq.shutdown()
|
||||
|
||||
threading.Thread(
|
||||
target=shutdown_thread, args=(message_queue, shutdown_event)
|
||||
).start()
|
||||
|
||||
with pytest.raises(TimeoutError):
|
||||
message_queue.dequeue(timeout=0.01)
|
||||
|
||||
shutdown_event.set()
|
||||
|
||||
with pytest.raises(RuntimeError, match="cancelled"):
|
||||
message_queue.dequeue(timeout=1)
|
||||
|
||||
assert message_queue.shutting_down
|
||||
|
||||
print(f"torch distributed passed the test! Rank {rank}")
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def test_message_queue_shutdown_idle():
|
||||
distributed_run(worker_fn_test_shutdown_idle, 4)
|
||||
|
||||
|
||||
@worker_fn_wrapper
|
||||
def worker_fn_test_idle_to_busy():
|
||||
rank = dist.get_rank()
|
||||
writer_rank = 2
|
||||
message_queue = MessageQueue.create_from_process_group(
|
||||
dist.group.WORLD, 40 * 1024, 2, writer_rank
|
||||
)
|
||||
|
||||
message1 = "hello world"
|
||||
message2 = np.random.randint(1, 100, 100)
|
||||
with mock.patch.object(
|
||||
message_queue._spin_condition, "wait", wraps=message_queue._spin_condition.wait
|
||||
) as wrapped_wait:
|
||||
if not message_queue._is_writer:
|
||||
# Put into idle mode
|
||||
message_queue._spin_condition.last_read = 0
|
||||
|
||||
# no messages, so expect a TimeoutError
|
||||
with pytest.raises(TimeoutError):
|
||||
message_queue.dequeue(timeout=0.01)
|
||||
# wait should only be called once while idle
|
||||
assert wrapped_wait.call_count == 1
|
||||
|
||||
# sync with the writer and wait for message1
|
||||
dist.barrier()
|
||||
recv_message = message_queue.dequeue(timeout=5)
|
||||
assert recv_message == message1
|
||||
# second call to wait, with a message read, this puts in a busy spin
|
||||
assert wrapped_wait.call_count == 2
|
||||
|
||||
# sync with the writer and wait for message2
|
||||
dist.barrier()
|
||||
recv_message = message_queue.dequeue(timeout=1)
|
||||
assert np.array_equal(recv_message, message2)
|
||||
# in busy mode, we expect wait to have been called multiple times
|
||||
assert wrapped_wait.call_count > 3
|
||||
else:
|
||||
# writer writes two messages in sync with the reader
|
||||
dist.barrier()
|
||||
# sleep delays the send to ensure reader enters the read loop
|
||||
time.sleep(0.1)
|
||||
message_queue.enqueue(message1)
|
||||
|
||||
dist.barrier()
|
||||
time.sleep(0.1)
|
||||
message_queue.enqueue(message2)
|
||||
|
||||
message_queue.shutdown()
|
||||
assert message_queue.shutting_down
|
||||
print(f"torch distributed passed the test! Rank {rank}")
|
||||
|
||||
|
||||
def test_message_queue_idle_wake():
|
||||
distributed_run(worker_fn_test_idle_to_busy, 4)
|
||||
|
||||
|
||||
@worker_fn_wrapper
|
||||
def worker_fn_test_busy_to_idle():
|
||||
rank = dist.get_rank()
|
||||
writer_rank = 2
|
||||
message_queue = MessageQueue.create_from_process_group(
|
||||
dist.group.WORLD, 40 * 1024, 2, writer_rank
|
||||
)
|
||||
|
||||
message1 = 12345
|
||||
message2 = list(range(3))
|
||||
with mock.patch.object(
|
||||
message_queue._spin_condition, "wait", wraps=message_queue._spin_condition.wait
|
||||
) as wrapped_wait:
|
||||
if not message_queue._is_writer:
|
||||
# Put into busy mode
|
||||
message_queue._spin_condition.busy_loop_s = 9999
|
||||
|
||||
# sync with the writer and wait for message1
|
||||
dist.barrier()
|
||||
recv_message = message_queue.dequeue(timeout=1)
|
||||
assert recv_message == message1
|
||||
# in busy mode, we expect wait to have been called many times
|
||||
assert wrapped_wait.call_count > 1
|
||||
|
||||
# simulate busy loop ending
|
||||
message_queue._spin_condition.busy_loop_s = 0
|
||||
# ensure we enter idle mode, then record call count
|
||||
with pytest.raises(TimeoutError):
|
||||
message_queue.dequeue(timeout=0.01)
|
||||
call_count = wrapped_wait.call_count
|
||||
|
||||
# sync with the writer and wait for message2
|
||||
dist.barrier()
|
||||
recv_message = message_queue.dequeue(timeout=1)
|
||||
assert recv_message == message2
|
||||
|
||||
# call to wait after idle should only happen once
|
||||
assert wrapped_wait.call_count == call_count + 1
|
||||
else:
|
||||
# writer writes two messages in sync with the reader
|
||||
dist.barrier()
|
||||
# sleep delays the send to ensure reader enters the read loop
|
||||
time.sleep(0.1)
|
||||
message_queue.enqueue(message1)
|
||||
|
||||
dist.barrier()
|
||||
time.sleep(0.1)
|
||||
message_queue.enqueue(message2)
|
||||
|
||||
message_queue.shutdown()
|
||||
assert message_queue.shutting_down
|
||||
print(f"torch distributed passed the test! Rank {rank}")
|
||||
|
||||
|
||||
def test_message_queue_busy_to_idle():
|
||||
distributed_run(worker_fn_test_busy_to_idle, 4)
|
||||
|
||||
|
||||
def test_warning_logs(caplog_vllm):
|
||||
"""
|
||||
Test that warning logs are emitted at VLLM_RINGBUFFER_WARNING_INTERVAL intervals
|
||||
when indefinite=False, and are not emitted when indefinite=True.
|
||||
"""
|
||||
|
||||
# Patch the warning log interval to every 1 ms during reads
|
||||
with mock.patch(
|
||||
"vllm.distributed.device_communicators.shm_broadcast.VLLM_RINGBUFFER_WARNING_INTERVAL",
|
||||
new=0.001, # 1 ms
|
||||
):
|
||||
writer = MessageQueue(
|
||||
n_reader=1,
|
||||
n_local_reader=1,
|
||||
max_chunk_bytes=1024 * 1024, # 1MB chunks
|
||||
max_chunks=10,
|
||||
)
|
||||
reader = MessageQueue.create_from_handle(writer.export_handle(), rank=0)
|
||||
writer.wait_until_ready()
|
||||
reader.wait_until_ready()
|
||||
|
||||
# We should have at least one warning log here
|
||||
# "0 seconds" expected due to rounding of 1ms test interval
|
||||
with pytest.raises(TimeoutError):
|
||||
reader.dequeue(timeout=0.01, indefinite=False)
|
||||
assert any(
|
||||
"No available shared memory broadcast block found in 0 seconds"
|
||||
in record.message
|
||||
for record in caplog_vllm.records
|
||||
)
|
||||
caplog_vllm.clear()
|
||||
|
||||
# We should have no warnings this time
|
||||
with pytest.raises(TimeoutError):
|
||||
reader.dequeue(timeout=0.01, indefinite=True)
|
||||
assert all(
|
||||
"No available shared memory broadcast block found in 0 seconds"
|
||||
not in record.message
|
||||
for record in caplog_vllm.records
|
||||
)
|
||||
|
||||
# Clean up when done
|
||||
writer.shutdown()
|
||||
reader.shutdown()
|
||||
|
||||
@@ -2,13 +2,13 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
import pickle
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from multiprocessing import shared_memory
|
||||
from pickle import PickleBuffer
|
||||
from threading import Event
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -18,6 +18,7 @@ import zmq
|
||||
from torch.distributed import ProcessGroup
|
||||
from zmq import ( # type: ignore
|
||||
IPV6, # type: ignore
|
||||
PUB,
|
||||
SUB,
|
||||
SUBSCRIBE,
|
||||
XPUB,
|
||||
@@ -32,6 +33,7 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils.network_utils import (
|
||||
get_ip,
|
||||
get_open_port,
|
||||
get_open_zmq_inproc_path,
|
||||
get_open_zmq_ipc_path,
|
||||
is_valid_ipv6_address,
|
||||
)
|
||||
@@ -78,50 +80,125 @@ def to_bytes_big(value: int, size: int) -> bytes:
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def long_wait_time_msg(threshold: int) -> str:
|
||||
return (
|
||||
"No available shared memory broadcast block found "
|
||||
f"in {threshold} seconds. This typically happens "
|
||||
"when some processes are hanging or doing some "
|
||||
"time-consuming work (e.g. compilation, "
|
||||
"weight/kv cache quantization)."
|
||||
)
|
||||
LONG_WAIT_TIME_LOG_MSG = (
|
||||
"No available shared memory broadcast block found "
|
||||
"in %d seconds. This typically happens "
|
||||
"when some processes are hanging or doing some "
|
||||
"time-consuming work (e.g. compilation, "
|
||||
"weight/kv cache quantization)."
|
||||
)
|
||||
|
||||
|
||||
class SpinTimer:
|
||||
def record_activity(self):
|
||||
pass
|
||||
|
||||
def spin(self):
|
||||
sched_yield()
|
||||
|
||||
|
||||
class SpinSleepTimer(SpinTimer):
|
||||
class SpinCondition:
|
||||
"""
|
||||
In setups which have long inactivity periods it is desirable to reduce
|
||||
system power consumption when vllm does nothing. This would lead to more
|
||||
CPU thermal headroom when a request eventually comes, especially when
|
||||
multiple GPUs are connected as each GPU would otherwise pin one thread at
|
||||
100% CPU usage.
|
||||
This class implements an interface similar to a threading.Condition. It
|
||||
allows a writer to notify readers to wake up and read from the shared memory
|
||||
buffer. This notification is done over a zmq socket.
|
||||
|
||||
The simplest solution is to reduce polling frequency when there is no
|
||||
activity for a certain period of time.
|
||||
For optimal performance under load we don't want the readers to need to poll
|
||||
the zmq socket for every read. So the `wait` method here will return
|
||||
immediately when reads are frequent, and will only enter "idle mode" and
|
||||
await a notification on the zmq socket after a period of inactivity. This
|
||||
allows the readers to spin quickly, hence "SpinCondition".
|
||||
|
||||
To support clean shutdown, a separate thread in the reader's process must be
|
||||
able to wake the reader so that it can exit. A separate cancel() method is
|
||||
implemented with an in-process socket to allow this interruption.
|
||||
"""
|
||||
|
||||
def __init__(self, busy_loop_s: float = 3.0, wait_sleep_s: float = 0.1):
|
||||
self.last_activity = time.monotonic()
|
||||
self.busy_loop_s = busy_loop_s
|
||||
self.wait_sleep_s = wait_sleep_s
|
||||
def __init__(
|
||||
self,
|
||||
is_reader: bool,
|
||||
context: zmq.Context,
|
||||
notify_address: str,
|
||||
busy_loop_s: float = 1,
|
||||
):
|
||||
self.is_reader = is_reader
|
||||
|
||||
def record_activity(self):
|
||||
self.last_activity = time.monotonic()
|
||||
if is_reader:
|
||||
# Time of last shm buffer read
|
||||
self.last_read = time.monotonic()
|
||||
|
||||
def spin(self):
|
||||
curr_time = time.monotonic()
|
||||
if curr_time >= self.last_activity + self.busy_loop_s:
|
||||
time.sleep(self.wait_sleep_s)
|
||||
# Time to keep busy-looping on the shm buffer before going idle
|
||||
self.busy_loop_s = busy_loop_s
|
||||
|
||||
# Readers subscribe to write notifications
|
||||
self.local_notify_socket: zmq.Socket = context.socket(SUB)
|
||||
# Set zmq.CONFLATE to only keep the last message that the socket
|
||||
# receives. This prevents us from piling up notification messages
|
||||
# under high load when we aren't polling the socket.
|
||||
self.local_notify_socket.setsockopt(zmq.CONFLATE, 1)
|
||||
# Subscribe to all messages on the socket
|
||||
self.local_notify_socket.setsockopt_string(SUBSCRIBE, "")
|
||||
self.local_notify_socket.connect(notify_address)
|
||||
|
||||
# Readers require a process-local socket to poll for cancellation
|
||||
cancel_path = get_open_zmq_inproc_path()
|
||||
self.write_cancel_socket: zmq.Socket = context.socket(zmq.PAIR)
|
||||
self.write_cancel_socket.bind(cancel_path)
|
||||
self.read_cancel_socket: zmq.Socket = context.socket(zmq.PAIR)
|
||||
self.read_cancel_socket.connect(cancel_path)
|
||||
|
||||
# Poller allows waiting on either `.notify()` or `.cancel()`
|
||||
self.poller = zmq.Poller()
|
||||
self.poller.register(self.read_cancel_socket, zmq.POLLIN)
|
||||
self.poller.register(self.local_notify_socket, zmq.POLLIN)
|
||||
else:
|
||||
# Writer side publishes write notifications
|
||||
self.local_notify_socket: zmq.Socket = context.socket(PUB) # type: ignore
|
||||
# Set high water mark to 1 - we don't need to send a massive amount of
|
||||
# pings during busy operation. PUB sockets will silently drop subsequent
|
||||
# messages after the high water mark is reached.
|
||||
self.local_notify_socket.setsockopt(zmq.SNDHWM, 1)
|
||||
self.local_notify_socket.bind(notify_address)
|
||||
|
||||
self.last_read = 0
|
||||
self.busy_loop_s = 0
|
||||
self.read_cancel_socket = None
|
||||
self.write_cancel_socket = None
|
||||
self.poller = None
|
||||
|
||||
def record_read(self):
|
||||
self.last_read = time.monotonic()
|
||||
|
||||
def cancel(self):
|
||||
# Sends cancellation ping that will cause the reader to wake up.
|
||||
# This is done from a monitor thread in the same process as the reader.
|
||||
if self.is_reader:
|
||||
logger.debug("Canceling waiting reads on SHM Buffer")
|
||||
self.write_cancel_socket.send(b"\x00")
|
||||
|
||||
def wait(self, timeout_ms: int | None = None) -> None:
|
||||
"""Wait for data on the shared memory buffer.
|
||||
|
||||
Yields the scheduler then returns immediately if it has been less than
|
||||
self.busy_loop_s since the last read.
|
||||
|
||||
Otherwise, enters idle mode and awaits a socket ping for at most
|
||||
`timeout_ms` milliseconds, or indefinitely if timeout_ms is None.
|
||||
"""
|
||||
assert self.is_reader, "Only readers can wait"
|
||||
|
||||
current_time = time.monotonic()
|
||||
if current_time <= self.last_read + self.busy_loop_s:
|
||||
sched_yield()
|
||||
else:
|
||||
events = dict(self.poller.poll(timeout=timeout_ms))
|
||||
|
||||
if self.read_cancel_socket in events:
|
||||
logger.debug("Poller received cancel event")
|
||||
elif self.local_notify_socket in events:
|
||||
logger.debug("Poller received notify event")
|
||||
# Since zmq.CONFLATE is set, there will only be one notification
|
||||
# to read from the socket
|
||||
self.local_notify_socket.recv(flags=zmq.NOBLOCK, copy=False)
|
||||
else:
|
||||
logger.debug("Poller timed out")
|
||||
|
||||
def notify(self):
|
||||
"""Notifies all readers to wake up"""
|
||||
assert not self.is_reader, "Only writers can notify"
|
||||
self.local_notify_socket.send(b"\x00")
|
||||
|
||||
|
||||
class ShmRingBuffer:
|
||||
@@ -265,6 +342,7 @@ class Handle:
|
||||
|
||||
buffer_handle: tuple[int, int, int, str] | None = None
|
||||
local_subscribe_addr: str | None = None
|
||||
local_notify_addr: str | None = None
|
||||
remote_subscribe_addr: str | None = None
|
||||
remote_addr_ipv6: bool = False
|
||||
|
||||
@@ -288,7 +366,7 @@ class MessageQueue:
|
||||
self.n_local_reader = n_local_reader
|
||||
n_remote_reader = n_reader - n_local_reader
|
||||
self.n_remote_reader = n_remote_reader
|
||||
|
||||
self.shutting_down = False
|
||||
context = Context()
|
||||
|
||||
if n_local_reader > 0:
|
||||
@@ -310,11 +388,19 @@ class MessageQueue:
|
||||
self.local_socket.bind(local_subscribe_addr)
|
||||
|
||||
self.current_idx = 0
|
||||
|
||||
# Create the notification side of the SpinCondition
|
||||
local_notify_addr = get_open_zmq_ipc_path()
|
||||
self._spin_condition = SpinCondition(
|
||||
is_reader=False, context=context, notify_address=local_notify_addr
|
||||
)
|
||||
else:
|
||||
self.buffer = None # type: ignore
|
||||
local_subscribe_addr = None
|
||||
self.local_socket = None
|
||||
self.current_idx = -1
|
||||
local_notify_addr = None
|
||||
self._spin_condition = None # type: ignore
|
||||
|
||||
remote_addr_ipv6 = False
|
||||
if n_remote_reader > 0:
|
||||
@@ -341,12 +427,12 @@ class MessageQueue:
|
||||
self.local_reader_rank = -1
|
||||
# rank does not matter for remote readers
|
||||
self._is_remote_reader = False
|
||||
self._read_spin_timer = SpinTimer()
|
||||
|
||||
self.handle = Handle(
|
||||
local_reader_ranks=local_reader_ranks,
|
||||
buffer_handle=self.buffer.handle() if self.buffer is not None else None,
|
||||
local_subscribe_addr=local_subscribe_addr,
|
||||
local_notify_addr=local_notify_addr,
|
||||
remote_subscribe_addr=remote_subscribe_addr,
|
||||
remote_addr_ipv6=remote_addr_ipv6,
|
||||
)
|
||||
@@ -379,9 +465,9 @@ class MessageQueue:
|
||||
self.local_socket.connect(socket_addr)
|
||||
|
||||
self.remote_socket = None
|
||||
|
||||
self._read_spin_timer = (
|
||||
SpinSleepTimer() if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer()
|
||||
assert isinstance(handle.local_notify_addr, str)
|
||||
self._spin_condition = SpinCondition(
|
||||
is_reader=True, context=context, notify_address=handle.local_notify_addr
|
||||
)
|
||||
else:
|
||||
self.buffer = None # type: ignore
|
||||
@@ -399,7 +485,9 @@ class MessageQueue:
|
||||
socket_addr = handle.remote_subscribe_addr
|
||||
logger.debug("Connecting to %s", socket_addr)
|
||||
self.remote_socket.connect(socket_addr)
|
||||
self._spin_condition = None # type: ignore
|
||||
|
||||
self.shutting_down = False
|
||||
return self
|
||||
|
||||
def wait_until_ready(self):
|
||||
@@ -435,6 +523,13 @@ class MessageQueue:
|
||||
recv = self.remote_socket.recv()
|
||||
assert recv == b"READY"
|
||||
|
||||
def shutdown(self):
|
||||
"""If this is an idle reader, wakes it up so it can clean up and shut
|
||||
down"""
|
||||
self.shutting_down = True
|
||||
if self._spin_condition is not None:
|
||||
self._spin_condition.cancel()
|
||||
|
||||
@contextmanager
|
||||
def acquire_write(self, timeout: float | None = None):
|
||||
assert self._is_writer, "Only writers can acquire write"
|
||||
@@ -465,7 +560,7 @@ class MessageQueue:
|
||||
# if we wait for a long time, log a message
|
||||
if elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:
|
||||
logger.info(
|
||||
long_wait_time_msg(VLLM_RINGBUFFER_WARNING_INTERVAL)
|
||||
LONG_WAIT_TIME_LOG_MSG, VLLM_RINGBUFFER_WARNING_INTERVAL
|
||||
)
|
||||
n_warning += 1
|
||||
|
||||
@@ -503,16 +598,60 @@ class MessageQueue:
|
||||
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
|
||||
break
|
||||
|
||||
class ReadTimeoutWithWarnings:
|
||||
def __init__(self, timeout: float | None, should_warn: bool) -> None:
|
||||
self.started = time.monotonic()
|
||||
self.deadline = sys.maxsize if timeout is None else self.started + timeout
|
||||
|
||||
# if should_warn, we need to wake up periodically to log
|
||||
self.warning_wait_time_ms: int | None = (
|
||||
VLLM_RINGBUFFER_WARNING_INTERVAL * 1000 if should_warn else None
|
||||
)
|
||||
|
||||
self._should_warn = should_warn
|
||||
self.n_warning = 1
|
||||
self.timeout = timeout
|
||||
|
||||
def timeout_ms(self) -> int | None:
|
||||
"""Returns a timeout that is:
|
||||
- min(time to deadline, time to next warning) if we're logging warnings
|
||||
- time to deadline, if we're not logging warnings
|
||||
- None if the timeout is None and we're not logging warnings
|
||||
- raise TimeoutError if we are past the deadline
|
||||
"""
|
||||
warning_wait_time = self.warning_wait_time_ms
|
||||
if self.timeout is None:
|
||||
return warning_wait_time
|
||||
|
||||
time_left_ms = int((self.deadline - time.monotonic()) * 1000)
|
||||
if time_left_ms <= 0:
|
||||
raise TimeoutError
|
||||
|
||||
if warning_wait_time and warning_wait_time < time_left_ms:
|
||||
return warning_wait_time
|
||||
|
||||
return time_left_ms
|
||||
|
||||
def should_warn(self) -> bool:
|
||||
"""Returns true if it's time to log a warning for a timeout that is not
|
||||
indefinite"""
|
||||
if self._should_warn:
|
||||
elapsed = time.monotonic() - self.started
|
||||
if elapsed >= VLLM_RINGBUFFER_WARNING_INTERVAL * self.n_warning:
|
||||
self.n_warning += 1
|
||||
return True
|
||||
return False
|
||||
|
||||
@contextmanager
|
||||
def acquire_read(
|
||||
self,
|
||||
timeout: float | None = None,
|
||||
cancel: Event | None = None,
|
||||
indefinite: bool = False,
|
||||
):
|
||||
assert self._is_local_reader, "Only readers can acquire read"
|
||||
start_time = time.monotonic()
|
||||
n_warning = 1
|
||||
read_timeout = self.ReadTimeoutWithWarnings(
|
||||
timeout=timeout, should_warn=not indefinite
|
||||
)
|
||||
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
|
||||
while True:
|
||||
# Memory fence ensures we see the latest writes from the writer.
|
||||
@@ -529,26 +668,16 @@ class MessageQueue:
|
||||
# for readers, `self.current_idx` is the next block to read
|
||||
# if this block is not ready,
|
||||
# we need to wait until it is written
|
||||
self._spin_condition.wait(timeout_ms=read_timeout.timeout_ms())
|
||||
|
||||
# Release the processor to other threads
|
||||
self._read_spin_timer.spin()
|
||||
|
||||
if cancel is not None and cancel.is_set():
|
||||
if self.shutting_down:
|
||||
raise RuntimeError("cancelled")
|
||||
|
||||
# if we time out, raise an exception
|
||||
elapsed = time.monotonic() - start_time
|
||||
if timeout is not None and elapsed > timeout:
|
||||
raise TimeoutError
|
||||
|
||||
# if we wait for a long time, log a message
|
||||
if not indefinite and (
|
||||
elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning
|
||||
):
|
||||
if read_timeout.should_warn():
|
||||
logger.info(
|
||||
long_wait_time_msg(VLLM_RINGBUFFER_WARNING_INTERVAL)
|
||||
LONG_WAIT_TIME_LOG_MSG, VLLM_RINGBUFFER_WARNING_INTERVAL
|
||||
)
|
||||
n_warning += 1
|
||||
|
||||
continue
|
||||
# found a block that is not read by this reader
|
||||
@@ -565,7 +694,7 @@ class MessageQueue:
|
||||
memory_fence()
|
||||
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
|
||||
|
||||
self._read_spin_timer.record_activity()
|
||||
self._spin_condition.record_read()
|
||||
break
|
||||
|
||||
def enqueue(self, obj, timeout: float | None = None):
|
||||
@@ -608,18 +737,19 @@ class MessageQueue:
|
||||
buf[offset:buf_offset] = to_bytes_big(buf_len, 4)
|
||||
buf[buf_offset : (offset := buf_offset + buf_len)] = buffer
|
||||
|
||||
self._spin_condition.notify()
|
||||
|
||||
if self.n_remote_reader > 0:
|
||||
self.remote_socket.send_multipart(all_buffers, copy=False)
|
||||
|
||||
def dequeue(
|
||||
self,
|
||||
timeout: float | None = None,
|
||||
cancel: Event | None = None,
|
||||
indefinite: bool = False,
|
||||
):
|
||||
"""Read from message queue with optional timeout (in seconds)"""
|
||||
if self._is_local_reader:
|
||||
with self.acquire_read(timeout, cancel, indefinite) as buf:
|
||||
with self.acquire_read(timeout, indefinite) as buf:
|
||||
overflow = buf[0] == 1
|
||||
if not overflow:
|
||||
offset = 3
|
||||
|
||||
@@ -179,7 +179,6 @@ if TYPE_CHECKING:
|
||||
VLLM_MOONCAKE_BOOTSTRAP_PORT: int = 8998
|
||||
VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
|
||||
VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
|
||||
VLLM_SLEEP_WHEN_IDLE: bool = False
|
||||
VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
|
||||
VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300
|
||||
VLLM_KV_CACHE_LAYOUT: Literal["NHD", "HND"] | None = None
|
||||
@@ -1338,9 +1337,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS": lambda: int(
|
||||
os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1")
|
||||
),
|
||||
# Reduce CPU usage when vLLM is idle. Enabling this will incur small
|
||||
# latency penalty when a request eventually comes.
|
||||
"VLLM_SLEEP_WHEN_IDLE": lambda: bool(int(os.getenv("VLLM_SLEEP_WHEN_IDLE", "0"))),
|
||||
# Control the max chunk bytes (in MB) for the rpc message queue.
|
||||
# Object larger than this threshold will be broadcast to worker
|
||||
# processes via zmq.
|
||||
@@ -1751,7 +1747,6 @@ def compile_factors() -> dict[str, object]:
|
||||
"VLLM_HTTP_TIMEOUT_KEEP_ALIVE",
|
||||
"VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS",
|
||||
"VLLM_KEEP_ALIVE_ON_ENGINE_DEATH",
|
||||
"VLLM_SLEEP_WHEN_IDLE",
|
||||
"VLLM_IMAGE_FETCH_TIMEOUT",
|
||||
"VLLM_VIDEO_FETCH_TIMEOUT",
|
||||
"VLLM_AUDIO_FETCH_TIMEOUT",
|
||||
|
||||
@@ -104,7 +104,6 @@ class MultiprocExecutor(Executor):
|
||||
# and ensure workers will be terminated.
|
||||
self._finalizer = weakref.finalize(self, self.shutdown)
|
||||
self.is_failed = False
|
||||
self.shutdown_event = threading.Event()
|
||||
self.failure_callback: FailureCallback | None = None
|
||||
|
||||
tp_size, pp_size, pcp_size = self._get_parallel_sizes()
|
||||
@@ -158,20 +157,31 @@ class MultiprocExecutor(Executor):
|
||||
global_start_rank = (
|
||||
self.local_world_size * self.parallel_config.node_rank_within_dp
|
||||
)
|
||||
# Keep track of socket file descriptors that are inherited by the
|
||||
# worker when using fork, so that we can close them in subsequent
|
||||
# workers
|
||||
inherited_fds: list[int] = []
|
||||
for local_rank in range(self.local_world_size):
|
||||
global_rank = global_start_rank + local_rank
|
||||
is_driver_worker = self._is_driver_worker(global_rank)
|
||||
unready_workers.append(
|
||||
WorkerProc.make_worker_process(
|
||||
vllm_config=self.vllm_config,
|
||||
local_rank=local_rank,
|
||||
rank=global_rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
input_shm_handle=scheduler_output_handle,
|
||||
shared_worker_lock=shared_worker_lock,
|
||||
is_driver_worker=is_driver_worker,
|
||||
)
|
||||
unready_worker_handle = WorkerProc.make_worker_process(
|
||||
vllm_config=self.vllm_config,
|
||||
local_rank=local_rank,
|
||||
rank=global_rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
input_shm_handle=scheduler_output_handle,
|
||||
shared_worker_lock=shared_worker_lock,
|
||||
is_driver_worker=is_driver_worker,
|
||||
inherited_fds=inherited_fds,
|
||||
)
|
||||
unready_workers.append(unready_worker_handle)
|
||||
if context.get_start_method() == "fork":
|
||||
inherited_fds.extend(
|
||||
[
|
||||
unready_worker_handle.death_writer.fileno(),
|
||||
unready_worker_handle.ready_pipe.fileno(),
|
||||
]
|
||||
)
|
||||
|
||||
# Workers must be created before wait_for_ready to avoid
|
||||
# deadlock, since worker.init_device() does a device sync.
|
||||
@@ -220,6 +230,7 @@ class MultiprocExecutor(Executor):
|
||||
for uw in unready_workers:
|
||||
if uw.death_writer is not None:
|
||||
uw.death_writer.close()
|
||||
uw.death_writer = None
|
||||
self._ensure_worker_termination([uw.proc for uw in unready_workers])
|
||||
|
||||
self.output_rank = self._get_output_rank()
|
||||
@@ -255,6 +266,7 @@ class MultiprocExecutor(Executor):
|
||||
died = multiprocessing.connection.wait(sentinels)
|
||||
_self = self_ref()
|
||||
if not _self or getattr(_self, "shutting_down", False):
|
||||
logger.debug("MultiprocWorkerMonitor: shutdown already initiated")
|
||||
return
|
||||
_self.is_failed = True
|
||||
proc_name = next(h.proc.name for h in workers if h.proc.sentinel == died[0])
|
||||
@@ -354,8 +366,6 @@ class MultiprocExecutor(Executor):
|
||||
if output_rank is not None:
|
||||
response_mqs = (response_mqs[output_rank],)
|
||||
|
||||
shutdown_event = self.shutdown_event
|
||||
|
||||
def get_response():
|
||||
responses = []
|
||||
for mq in response_mqs:
|
||||
@@ -363,9 +373,7 @@ class MultiprocExecutor(Executor):
|
||||
None if deadline is None else (deadline - time.monotonic())
|
||||
)
|
||||
try:
|
||||
status, result = mq.dequeue(
|
||||
timeout=dequeue_timeout, cancel=shutdown_event
|
||||
)
|
||||
status, result = mq.dequeue(timeout=dequeue_timeout)
|
||||
except TimeoutError as e:
|
||||
raise TimeoutError(f"RPC call to {method} timed out.") from e
|
||||
if status != WorkerProc.ResponseStatus.SUCCESS:
|
||||
@@ -408,20 +416,26 @@ class MultiprocExecutor(Executor):
|
||||
|
||||
active_procs = lambda: [proc for proc in worker_procs if proc.is_alive()]
|
||||
# Give processes time to clean themselves up properly first
|
||||
logger.debug("Worker Termination: allow workers to gracefully shutdown")
|
||||
if wait_for_termination(active_procs(), 4):
|
||||
return
|
||||
|
||||
# Send SIGTERM if still running
|
||||
logger.debug("Worker Termination: workers still running sending SIGTERM")
|
||||
for p in active_procs():
|
||||
p.terminate()
|
||||
if not wait_for_termination(active_procs(), 4):
|
||||
# Send SIGKILL if still running
|
||||
logger.debug(
|
||||
"Worker Termination: resorting to SIGKILL to take down workers"
|
||||
)
|
||||
for p in active_procs():
|
||||
p.kill()
|
||||
|
||||
def shutdown(self):
|
||||
"""Properly shut down the executor and its workers"""
|
||||
if not getattr(self, "shutting_down", False):
|
||||
logger.debug("Triggering shutdown of workers")
|
||||
self.shutting_down = True
|
||||
|
||||
# Make sure all the worker processes are terminated first.
|
||||
@@ -431,12 +445,20 @@ class MultiprocExecutor(Executor):
|
||||
if w.death_writer is not None:
|
||||
w.death_writer.close()
|
||||
w.death_writer = None
|
||||
w.worker_response_mq = None
|
||||
self._ensure_worker_termination([w.proc for w in workers])
|
||||
|
||||
self.shutdown_event.set()
|
||||
for w in workers:
|
||||
# Shutdown response queues
|
||||
if w.worker_response_mq is not None:
|
||||
w.worker_response_mq.shutdown()
|
||||
w.worker_response_mq = None
|
||||
|
||||
self.rpc_broadcast_mq = None
|
||||
if self.rpc_broadcast_mq is not None:
|
||||
self.rpc_broadcast_mq.shutdown()
|
||||
self.rpc_broadcast_mq = None
|
||||
for mq in self.response_mqs:
|
||||
mq.shutdown()
|
||||
self.response_mqs = []
|
||||
|
||||
def check_health(self) -> None:
|
||||
self.collective_rpc("check_health", timeout=10)
|
||||
@@ -609,24 +631,26 @@ class WorkerProc:
|
||||
input_shm_handle, # Receive SchedulerOutput
|
||||
shared_worker_lock: LockType,
|
||||
is_driver_worker: bool,
|
||||
inherited_fds: list[int],
|
||||
) -> UnreadyWorkerProcHandle:
|
||||
context = get_mp_context()
|
||||
# (reader, writer)
|
||||
reader, writer = context.Pipe(duplex=False)
|
||||
|
||||
# Create death pipe to detect parent process exit
|
||||
# Ready pipe to communicate readiness from child to parent
|
||||
ready_reader, ready_writer = context.Pipe(duplex=False)
|
||||
# Death pipe to let child detect parent process exit
|
||||
death_reader, death_writer = context.Pipe(duplex=False)
|
||||
|
||||
process_kwargs = {
|
||||
"vllm_config": vllm_config,
|
||||
"local_rank": local_rank,
|
||||
"rank": rank,
|
||||
"distributed_init_method": distributed_init_method,
|
||||
"input_shm_handle": input_shm_handle,
|
||||
"ready_pipe": (reader, writer),
|
||||
"ready_pipe": ready_writer,
|
||||
"death_pipe": death_reader,
|
||||
"shared_worker_lock": shared_worker_lock,
|
||||
"is_driver_worker": is_driver_worker,
|
||||
# Have the worker close parent end of this worker's pipes too
|
||||
"inherited_fds": inherited_fds
|
||||
+ [ready_reader.fileno(), death_writer.fileno()],
|
||||
}
|
||||
# Run EngineCore busy loop in background process.
|
||||
proc = context.Process(
|
||||
@@ -637,10 +661,12 @@ class WorkerProc:
|
||||
)
|
||||
|
||||
proc.start()
|
||||
writer.close()
|
||||
# Close child ends of pipes here in the parent
|
||||
ready_writer.close()
|
||||
death_reader.close()
|
||||
# Keep death_writer open in parent - when parent exits,
|
||||
# death_reader in child will get EOFError
|
||||
return UnreadyWorkerProcHandle(proc, rank, reader, death_writer)
|
||||
return UnreadyWorkerProcHandle(proc, rank, ready_reader, death_writer)
|
||||
|
||||
@staticmethod
|
||||
def wait_for_response_handle_ready(
|
||||
@@ -703,12 +729,41 @@ class WorkerProc:
|
||||
return cast(list[WorkerProcHandle], ready_proc_handles)
|
||||
|
||||
def shutdown(self):
|
||||
if self.rpc_broadcast_mq is not None:
|
||||
self.rpc_broadcast_mq.shutdown()
|
||||
if self.worker_response_mq is not None:
|
||||
self.worker_response_mq.shutdown()
|
||||
self.worker.shutdown()
|
||||
self.rpc_broadcast_mq = None
|
||||
self.worker_response_mq = None
|
||||
destroy_model_parallel()
|
||||
destroy_distributed_environment()
|
||||
|
||||
def monitor_death_pipe(self, death_pipe, shutdown_requested: threading.Event):
|
||||
if death_pipe is None:
|
||||
return
|
||||
|
||||
def death_pipe_monitor(queues_to_shutdown: list[MessageQueue]):
|
||||
try:
|
||||
# This will block until parent process exits (pipe closes)
|
||||
death_pipe.recv()
|
||||
except EOFError:
|
||||
logger.info_once("Parent process exited, terminating worker queues")
|
||||
shutdown_requested.set()
|
||||
for mq in queues_to_shutdown:
|
||||
if mq is not None:
|
||||
mq.shutdown()
|
||||
except Exception as e:
|
||||
logger.warning("Death monitoring error: %s", e)
|
||||
|
||||
# Pass queue references directly to avoid gc issues if passing self
|
||||
Thread(
|
||||
target=death_pipe_monitor,
|
||||
args=([self.rpc_broadcast_mq, self.worker_response_mq],),
|
||||
daemon=True,
|
||||
name="DeathPipeMonitor",
|
||||
).start()
|
||||
|
||||
@staticmethod
|
||||
def worker_main(*args, **kwargs):
|
||||
"""Worker initialization and execution loops.
|
||||
@@ -717,12 +772,12 @@ class WorkerProc:
|
||||
# Signal handler used for graceful termination.
|
||||
# SystemExit exception is only raised once to allow this and worker
|
||||
# processes to terminate without error
|
||||
shutdown_requested = False
|
||||
shutdown_requested = threading.Event()
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
nonlocal shutdown_requested
|
||||
if not shutdown_requested:
|
||||
shutdown_requested = True
|
||||
if not shutdown_requested.is_set():
|
||||
shutdown_requested.set()
|
||||
logger.debug(
|
||||
"WorkerProc handling signal %d, raising SystemExit", signum
|
||||
)
|
||||
@@ -733,33 +788,20 @@ class WorkerProc:
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
worker = None
|
||||
# tuple[Connection, Connection]
|
||||
reader, ready_writer = kwargs.pop("ready_pipe")
|
||||
death_pipe: Connection | None = kwargs.pop("death_pipe", None)
|
||||
shutdown_event = threading.Event()
|
||||
# Start death monitoring thread if death_pipe is provided
|
||||
if death_pipe is not None:
|
||||
ready_writer = kwargs.pop("ready_pipe")
|
||||
death_pipe = kwargs.pop("death_pipe", None)
|
||||
|
||||
def monitor_parent_death():
|
||||
try:
|
||||
# This will block until parent process exits (pipe closes)
|
||||
death_pipe.recv()
|
||||
except EOFError:
|
||||
# Parent process has exited, terminate this worker
|
||||
logger.info_once("Parent process exited, terminating worker")
|
||||
# Send signal to self to trigger clean shutdown
|
||||
shutdown_event.set()
|
||||
except Exception as e:
|
||||
logger.warning("Death monitoring error: %s", e)
|
||||
|
||||
death_monitor = Thread(
|
||||
target=monitor_parent_death, daemon=True, name="WorkerDeathMonitor"
|
||||
)
|
||||
death_monitor.start()
|
||||
# Close inherited pipes from parent (incl. other worker pipes)
|
||||
# Explicitly passing in existing pipes and closing them makes the pipe
|
||||
# behave when using fork. Otherwise, a hidden reference to the pipes
|
||||
# exist in the child process and prevents EOF closure.
|
||||
for fd in kwargs.pop("inherited_fds", []):
|
||||
try:
|
||||
os.close(fd)
|
||||
except Exception as e:
|
||||
logger.warning("Exception closing inherited connection: %s", e)
|
||||
|
||||
try:
|
||||
reader.close()
|
||||
|
||||
# Initialize tracer
|
||||
rank = kwargs.get("rank", 0)
|
||||
maybe_init_worker_tracer(
|
||||
@@ -771,6 +813,8 @@ class WorkerProc:
|
||||
worker = WorkerProc(*args, **kwargs)
|
||||
assert worker.worker_response_mq is not None
|
||||
|
||||
worker.monitor_death_pipe(death_pipe, shutdown_requested)
|
||||
|
||||
# Send READY once we know everything is loaded
|
||||
ready_writer.send(
|
||||
{
|
||||
@@ -788,7 +832,7 @@ class WorkerProc:
|
||||
ready_writer.close()
|
||||
ready_writer = None
|
||||
|
||||
worker.worker_busy_loop(cancel=shutdown_event)
|
||||
worker.worker_busy_loop()
|
||||
|
||||
except Exception:
|
||||
# NOTE: if an Exception arises in busy_loop, we send
|
||||
@@ -798,7 +842,7 @@ class WorkerProc:
|
||||
|
||||
if ready_writer is not None:
|
||||
logger.exception("WorkerProc failed to start.")
|
||||
elif shutdown_event.is_set():
|
||||
elif shutdown_requested.is_set():
|
||||
logger.info("WorkerProc shutting down.")
|
||||
else:
|
||||
logger.exception("WorkerProc failed.")
|
||||
@@ -806,7 +850,7 @@ class WorkerProc:
|
||||
# The parent sends a SIGTERM to all worker processes if
|
||||
# any worker dies. Set this value so we don't re-throw
|
||||
# SystemExit() to avoid zmq exceptions in __del__.
|
||||
shutdown_requested = True
|
||||
shutdown_requested.set()
|
||||
|
||||
except SystemExit as e:
|
||||
# SystemExit is raised on SIGTERM or SIGKILL, which usually indicates that
|
||||
@@ -859,12 +903,12 @@ class WorkerProc:
|
||||
output = self.async_output_queue.get()
|
||||
self.enqueue_output(output)
|
||||
|
||||
def worker_busy_loop(self, cancel: threading.Event | None = None):
|
||||
def worker_busy_loop(self):
|
||||
"""Main busy loop for Multiprocessing Workers"""
|
||||
assert self.rpc_broadcast_mq is not None
|
||||
while True:
|
||||
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue(
|
||||
cancel=cancel, indefinite=True
|
||||
indefinite=True
|
||||
)
|
||||
try:
|
||||
if isinstance(method, str):
|
||||
|
||||
Reference in New Issue
Block a user