diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index b74dbef8429..8bd6e92157a 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -2055,6 +2055,10 @@ def destroy_distributed_environment(): def cleanup_dist_env_and_memory(shutdown_ray: bool = False): + logger.debug( + "[shutdown] Distributed: cleanup start shutdown_ray=%s", + shutdown_ray, + ) # Reset environment variable cache envs.disable_envs_cache() @@ -2089,6 +2093,8 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False): "torch._C._host_emptyCache() only available in Pytorch >=2.5" ) + logger.debug_once("[shutdown] Distributed: cleanup complete") + def in_the_same_node_as( pg: ProcessGroup | StatelessProcessGroup, source_rank: int = 0 diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index a560db87ea2..08a3ab58c78 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -95,6 +95,9 @@ async def serve_http( shutdown_event = asyncio.Event() def signal_handler() -> None: + if shutdown_event.is_set(): + return + logger.info_once("[shutdown] API server: shutdown triggered") shutdown_event.set() async def dummy_shutdown() -> None: @@ -108,12 +111,21 @@ async def serve_http( engine_client = app.state.engine_client timeout = engine_client.vllm_config.shutdown_timeout + mode = "abort" if timeout == 0 else "drain" + + logger.info( + "[shutdown] API server: stopping engine client mode=%s timeout=%ss", + mode, + timeout, + ) await loop.run_in_executor( None, partial(engine_client.shutdown, timeout=timeout) ) + logger.info_once("[shutdown] API server: engine client stopped") server.should_exit = True + logger.info_once("[shutdown] API server: signalling HTTP server shutdown") server_task.cancel() watchdog_task.cancel() if ssl_cert_refresher: @@ -134,7 +146,7 @@ async def serve_http( process, " ".join(process.cmdline()), ) - logger.info("Shutting down FastAPI HTTP server.") + logger.info_once("[shutdown] API server: shutting down FastAPI HTTP server") return server.shutdown() finally: shutdown_task.cancel() diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 897d063e830..8384def4664 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -2056,13 +2056,17 @@ class Scheduler(SchedulerInterface): return spec_decoding_stats def shutdown(self) -> None: + logger.debug_once("[shutdown] Scheduler: start") if self.kv_event_publisher: self.kv_event_publisher.shutdown() if self.connector is not None: self.connector.shutdown() + if self.ec_connector is not None: self.ec_connector.shutdown() + logger.debug_once("[shutdown] Scheduler: complete") + ######################################################################## # KV Connector Related Methods ######################################################################## diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index b12aa9d0505..66b49d0099c 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -608,6 +608,7 @@ class EngineCore: self.abort_requests(request_ids) def shutdown(self): + logger.debug_once("[shutdown] EngineCore: tearing down local resources") self.structured_output_manager.clear_backend() if self.model_executor: self.model_executor.shutdown() @@ -622,6 +623,7 @@ class EngineCore: # Tear down distributed state initialized in this EngineCore process # before it exits and release cached memory. cleanup_dist_env_and_memory() + logger.debug_once("[shutdown] EngineCore: local resource teardown complete") def profile(self, is_start: bool = True, profile_prefix: str | None = None): self.model_executor.profile(is_start, profile_prefix) @@ -1172,6 +1174,11 @@ class EngineCoreProc(EngineCore): signal_callback = SignalCallback(wakeup_engine) def signal_handler(signum, frame): + signal_name = signal.Signals(signum).name + logger.info( + "[shutdown] EngineCore: trigger received signal=%s", + signal_name, + ) engine_core.shutdown_state = EngineShutdownState.REQUESTED signal_callback.trigger() @@ -1181,7 +1188,7 @@ class EngineCoreProc(EngineCore): engine_core.run_busy_loop() except SystemExit: - logger.debug("EngineCore exiting.") + logger.info_once("[shutdown] EngineCore: exiting busy loop") raise except Exception as e: if engine_core is None: @@ -1285,13 +1292,21 @@ class EngineCoreProc(EngineCore): if self.shutdown_state == EngineShutdownState.REQUESTED: shutdown_timeout = self.vllm_config.shutdown_timeout + mode = "abort" if shutdown_timeout == 0 else "drain" - logger.info("Shutdown initiated (timeout=%d)", shutdown_timeout) + logger.info( + "[shutdown] EngineCore: start mode=%s timeout=%ds", + mode, + shutdown_timeout, + ) if shutdown_timeout == 0: num_requests = self.scheduler.get_num_unfinished_requests() if num_requests > 0: - logger.info("Aborting %d requests", num_requests) + logger.info( + "[shutdown] EngineCore: aborting in-flight requests count=%d", + num_requests, + ) aborted_reqs = self.scheduler.finish_requests( None, RequestStatus.FINISHED_ABORTED ) @@ -1300,7 +1315,8 @@ class EngineCoreProc(EngineCore): num_requests = self.scheduler.get_num_unfinished_requests() if num_requests > 0: logger.info( - "Draining %d in-flight requests (timeout=%ds)", + "[shutdown] EngineCore: draining in-flight requests " + "count=%d timeout=%ds", num_requests, shutdown_timeout, ) @@ -1309,7 +1325,10 @@ class EngineCoreProc(EngineCore): # Exit when no work remaining if not self.has_work(): - logger.info("Shutdown complete") + logger.info( + "[shutdown] EngineCore: request processing complete; " + "starting resource teardown" + ) return False return True @@ -1353,7 +1372,10 @@ class EngineCoreProc(EngineCore): if self.shutdown_state == EngineShutdownState.RUNNING: return False - logger.info("Rejecting request %s (server shutting down)", request.request_id) + logger.debug( + "[shutdown] EngineCore: rejecting new request request_id=%s", + request.request_id, + ) self._send_abort_outputs_to_client([request.request_id], request.client_index) return True @@ -1363,7 +1385,10 @@ class EngineCoreProc(EngineCore): if self.shutdown_state == EngineShutdownState.RUNNING: return False - logger.warning("Rejecting utility call %s (server shutting down)", method_name) + logger.warning( + "[shutdown] EngineCore: rejecting utility call method=%s", + method_name, + ) output = UtilityOutput(call_id, failure_message="Server shutting down") self.output_queue.put_nowait( (client_idx, EngineCoreOutputs(utility_output=output)) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 14257b020ee..32f2d091eb3 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -391,6 +391,7 @@ class BackgroundResources: def __call__(self): """Clean up background resources.""" + logger.debug_once("[shutdown] MPClient: background resource cleanup start") self.engine_dead = True if self.engine_manager is not None: self.engine_manager.shutdown() @@ -445,6 +446,8 @@ class BackgroundResources: # Send shutdown signal. shutdown_sender.send(b"") + logger.debug_once("[shutdown] MPClient: background resource cleanup complete") + def validate_alive(self, frames: Sequence[zmq.Frame]): if len(frames) == 1 and (frames[0].buffer == EngineCoreProc.ENGINE_CORE_DEAD): self.engine_dead = True @@ -645,9 +648,15 @@ class MPClient(EngineCoreClient): def shutdown(self, timeout: float | None = None) -> None: """Shutdown engine manager under timeout and clean up resources.""" if self._finalizer.detach() is not None: + timeout_str = "default" if timeout is None else f"{timeout}s" + logger.info("[shutdown] MPClient: start timeout=%s", timeout_str) if self.resources.engine_manager is not None: + logger.info_once("[shutdown] MPClient: stopping engine manager") self.resources.engine_manager.shutdown(timeout=timeout) + logger.info_once("[shutdown] MPClient: engine manager stopped") + logger.info_once("[shutdown] MPClient: cleaning up background resources") self.resources() + logger.info_once("[shutdown] MPClient: complete") def _format_exception(self, e: Exception) -> Exception: """If errored, use EngineDeadError so root cause is clear.""" @@ -687,6 +696,9 @@ class MPClient(EngineCoreClient): if not _self or not _self._finalizer.alive or _self.resources.engine_dead: return _self.resources.engine_dead = True + logger.warning_once( + "[shutdown] MPClient: engine core exited unexpectedly; starting cleanup" + ) _self.shutdown() # Note: For MPClient, we don't have a failure callback mechanism # like MultiprocExecutor, but we set engine_dead flag which will diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index c5766c923c8..66564bebdb6 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -422,27 +422,45 @@ class MultiprocExecutor(Executor): return False active_procs = lambda: [proc for proc in worker_procs if proc.is_alive()] + initial_count = len(active_procs()) + # Give processes time to clean themselves up properly first - logger.debug("Worker Termination: allow workers to gracefully shutdown") + logger.info( + "[shutdown] Executor: waiting for worker exit count=%d", + initial_count, + ) if wait_for_termination(active_procs(), 4): + logger.info_once("[shutdown] Executor: all workers exited gracefully") return # Send SIGTERM if still running - logger.debug("Worker Termination: workers still running sending SIGTERM") - for p in active_procs(): + remaining = active_procs() + logger.warning( + "[shutdown] Executor: workers still running after grace period; " + "sending SIGTERM count=%d", + len(remaining), + ) + for p in remaining: p.terminate() if not wait_for_termination(active_procs(), 4): # Send SIGKILL if still running - logger.debug( - "Worker Termination: resorting to SIGKILL to take down workers" + remaining = active_procs() + logger.warning( + "[shutdown] Executor: workers still running after SIGTERM; " + "sending SIGKILL count=%d", + len(remaining), ) - for p in active_procs(): + for p in remaining: p.kill() def shutdown(self): """Properly shut down the executor and its workers""" if not getattr(self, "shutting_down", False): - logger.debug("Triggering shutdown of workers") + worker_count = len(getattr(self, "workers", None) or []) + logger.debug( + "[shutdown] Executor: start worker_count=%d", + worker_count, + ) self.shutting_down = True # Make sure all the worker processes are terminated first. @@ -468,6 +486,8 @@ class MultiprocExecutor(Executor): mq.shutdown() self.response_mqs = [] + logger.debug_once("[shutdown] Executor: complete") + def check_health(self) -> None: self.collective_rpc("check_health", timeout=10) return @@ -867,7 +887,9 @@ class WorkerProc: if ready_writer is not None: logger.exception("WorkerProc failed to start.") elif shutdown_requested.is_set(): - logger.info("WorkerProc shutting down.") + logger.debug_once( + "[shutdown] WorkerProc: exiting after shutdown request" + ) else: logger.exception("WorkerProc failed.") @@ -879,7 +901,12 @@ class WorkerProc: except SystemExit as e: # SystemExit is raised on SIGTERM or SIGKILL, which usually indicates that # the graceful shutdown process did not succeed - logger.warning("WorkerProc was terminated") + if shutdown_requested.is_set(): + logger.debug_once( + "[shutdown] WorkerProc: terminated by shutdown signal" + ) + else: + logger.warning("WorkerProc was terminated") # SystemExit must never be ignored raise e diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index f11c92a805d..ffdf43f54c3 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -444,6 +444,12 @@ def _shutdown_subprocesses( timeout = 0.0 timeout = max(timeout, 5.0) + logger.debug( + "[shutdown] Subprocess manager: start process_count=%d timeout=%ss", + len(procs), + timeout, + ) + for proc in procs: if proc.is_alive(): proc.terminate() @@ -456,9 +462,18 @@ def _shutdown_subprocesses( if proc.is_alive(): proc.join(remaining) - for proc in procs: - if proc.is_alive() and (pid := proc.pid) is not None: - kill_process_tree(pid) + remaining_pids = [ + proc.pid for proc in procs if proc.is_alive() and proc.pid is not None + ] + if remaining_pids: + logger.warning( + "[shutdown] Subprocess manager: force killing remaining processes count=%d", + len(remaining_pids), + ) + for pid in remaining_pids: + kill_process_tree(pid) + + logger.debug_once("[shutdown] Subprocess manager: complete") def run_api_server_worker_proc( @@ -565,6 +580,12 @@ def shutdown(procs: list[BaseProcess], timeout: float | None = None) -> None: # have a user-configured shutdown timeout. timeout = 5.0 + logger.debug( + "[shutdown] Process manager: start process_count=%d timeout=%ss", + len(procs), + timeout, + ) + # Shutdown the process. for proc in procs: if proc.is_alive(): @@ -579,9 +600,18 @@ def shutdown(procs: list[BaseProcess], timeout: float | None = None) -> None: if proc.is_alive(): proc.join(remaining) - for proc in procs: - if proc.is_alive() and (pid := proc.pid) is not None: - kill_process_tree(pid) + remaining_pids = [ + proc.pid for proc in procs if proc.is_alive() and proc.pid is not None + ] + if remaining_pids: + logger.warning( + "[shutdown] Process manager: force killing remaining processes count=%d", + len(remaining_pids), + ) + for pid in remaining_pids: + kill_process_tree(pid) + + logger.debug_once("[shutdown] Process manager: complete") def copy_slice(