diff --git a/.github/workflows/blossom-ci.yml b/.github/workflows/blossom-ci.yml
index 6b166503f4..be774ec1cc 100644
--- a/.github/workflows/blossom-ci.yml
+++ b/.github/workflows/blossom-ci.yml
@@ -156,6 +156,7 @@ jobs:
             "kaiyux",
             "kanghui0204",
             "karljang",
+            "karthikvetrivel",
             "katec846",
             "Kefeng-Duan",
             "KingsleyLiu-NV",
@@ -193,6 +194,7 @@ jobs:
             "mlefeb01",
             "moraxu",
             "MrGeva",
+            "mzweilz",
             "Naveassaf",
             "nekorobov",
             "netanel-haber",
diff --git a/tensorrt_llm/_common.py b/tensorrt_llm/_common.py
index c0d64abb81..871120aabf 100644
--- a/tensorrt_llm/_common.py
+++ b/tensorrt_llm/_common.py
@@ -16,6 +16,7 @@ import contextlib
 import ctypes
 import os
 import platform
+import threading
 import time
 from functools import wraps
 from pathlib import Path
@@ -34,7 +35,7 @@ if TYPE_CHECKING:
 else:
     Network = None

-from ._utils import str_dtype_to_trt
+from ._utils import print_all_stacks, str_dtype_to_trt
 from .bindings import MpiComm
 from .logger import logger
 from .plugin import _load_plugin_lib
@@ -82,6 +83,19 @@ def _init(log_level: object = None) -> None:
     MpiComm.local_init()

+    def _print_stacks():
+        counter = 0
+        while True:
+            time.sleep(print_stacks_period)
+            counter += 1
+            logger.error(f"Printing stacks {counter} times")
+            print_all_stacks()
+
+    print_stacks_period = int(os.getenv("TRTLLM_PRINT_STACKS_PERIOD", "-1"))
+    if print_stacks_period > 0:
+        print_stacks_thread = threading.Thread(target=_print_stacks, daemon=True)
+        print_stacks_thread.start()
+
     logger.info("TensorRT LLM inited.")
diff --git a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
index 161282e4c4..cb42186520 100644
--- a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
+++ b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
@@ -14,6 +14,7 @@
 from tensorrt_llm._utils import mpi_disabled, nvtx_range
 from tensorrt_llm.mapping import CpType

 from ..distributed import Distributed
+from .hang_detector import HangDetector
 from .llm_request import (ExecutorRequest, LlmRequest,
                           executor_request_to_llm_request)
@@ -47,10 +48,17 @@ class RequestQueueItem:
 class ExecutorRequestQueue:
     """Handles fetching and processing of new requests from the request queue."""

-    def __init__(self, dist: Distributed, enable_attention_dp: bool,
-                 max_batch_size: int, max_beam_width: int,
-                 max_num_active_requests: int, enable_iter_perf_stats: bool,
-                 batch_wait_timeout_ms: float):
+    def __init__(
+        self,
+        dist: Distributed,
+        enable_attention_dp: bool,
+        max_batch_size: int,
+        max_beam_width: int,
+        max_num_active_requests: int,
+        enable_iter_perf_stats: bool,
+        batch_wait_timeout_ms: float,
+        hang_detector: Optional[HangDetector] = None,
+    ):
         self.dist = dist
         self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue()
         self.waiting_queue: deque[RequestQueueItem] = deque()
@@ -66,6 +74,7 @@ class ExecutorRequestQueue:
         self.active = True
         self.batch_wait_timeout_ms = batch_wait_timeout_ms
         self.send_requests_handler = None
+        self.hang_detector = hang_detector or HangDetector()

         # State tracking
         self.num_fetch_requests = 0
@@ -303,7 +312,8 @@ class ExecutorRequestQueue:
                 self.request_accumulated.clear()
                 # Reset timeout to 0 to avoid hanging when no new requests are available
                 timeout = datetime.timedelta(0)
-                new_requests.extend(self._get_from_request_queue(timeout))
+                with self.hang_detector.pause():
+                    new_requests.extend(self._get_from_request_queue(timeout))

         # Broadcast requests and handle Python objects
         new_requests, py_request_objects = self._handle_request_broadcasting(
@@ -477,8 +487,9 @@ class ExecutorRequestQueue:
             # Preserve original `new_requests` on rank 0
             _ = self._broadcast_new_requests(new_requests, py_request_objects)
         else:
-            new_requests, py_request_objects = self._broadcast_new_requests(
-                new_requests, py_request_objects)
+            with self.hang_detector.pause():
+                new_requests, py_request_objects = self._broadcast_new_requests(
+                    new_requests, py_request_objects)

         return new_requests, py_request_objects
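Usage note (not part of the patch): the stack-dump thread added to `_init()` in `tensorrt_llm/_common.py` is opt-in and only runs when `TRTLLM_PRINT_STACKS_PERIOD` is set to a positive number of seconds. A minimal sketch of enabling it from Python, assuming `_init()` executes during `import tensorrt_llm`:

```python
# Sketch: enable periodic stack dumps in the driver process.
# Assumption: the variable must be set before `tensorrt_llm` is imported,
# since _init() reads it once at initialization time.
import os

os.environ["TRTLLM_PRINT_STACKS_PERIOD"] = "60"  # dump all thread stacks every 60 s

import tensorrt_llm  # noqa: E402  (_init() starts the daemon dump thread)

print(tensorrt_llm.__version__)
```

In practice the variable would simply be exported in the launch environment; the worker-side counterpart `TRTLLM_WORKER_PRINT_STACKS_PERIOD` (see `tensorrt_llm/executor/worker.py` below) behaves the same way.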
diff --git a/tensorrt_llm/_torch/pyexecutor/hang_detector.py b/tensorrt_llm/_torch/pyexecutor/hang_detector.py
new file mode 100644
index 0000000000..ae23b11fba
--- /dev/null
+++ b/tensorrt_llm/_torch/pyexecutor/hang_detector.py
@@ -0,0 +1,96 @@
+import asyncio
+import threading
+from contextlib import contextmanager
+from typing import Callable, Optional
+
+from tensorrt_llm._utils import print_all_stacks
+from tensorrt_llm.logger import logger
+
+
+class HangDetector:
+    def __init__(
+        self, timeout: Optional[int] = None, on_detected: Optional[Callable[[], None]] = None
+    ):
+        self.timeout = timeout if timeout is not None else 300
+        assert self.timeout > 0, "timeout must be greater than 0"
+        self.on_detected = on_detected or (lambda: None)
+        self.task = None
+        self.loop = None
+        self.loop_thread = None
+        self.lock = threading.Lock()
+        self.active = False
+        self._detected = False
+
+    def start(self):
+        """Enable hang detection."""
+
+        def run_loop():
+            asyncio.set_event_loop(self.loop)
+            self.loop.run_forever()
+
+        self.active = True
+        self.loop = asyncio.new_event_loop()
+        self.loop_thread = threading.Thread(target=run_loop, daemon=True, name="hang_detector_loop")
+        self.loop_thread.start()
+
+    async def _detect_hang(self):
+        await asyncio.sleep(self.timeout)
+        with self.lock:
+            self._detected = True
+        logger.error(f"Hang detected after {self.timeout} seconds.")
+        print_all_stacks()
+        self.on_detected()
+
+    def detected(self):
+        """Return True if hang is detected."""
+        with self.lock:
+            return self._detected
+
+    def checkpoint(self):
+        """Reset hang detection timer."""
+        self.cancel_task()
+        if self.active:
+            self.task = asyncio.run_coroutine_threadsafe(self._detect_hang(), self.loop)
+
+    def cancel_task(self):
+        """Cancel the hang detection task."""
+        if self.task is not None and not self.task.done():
+            self.task.cancel()
+        self.task = None
+
+    @contextmanager
+    def pause(self):
+        """Pause hang detection in scope."""
+        try:
+            self.cancel_task()
+            yield
+        finally:
+            self.checkpoint()
+
+    def stop(self):
+        """Stop hang detection."""
+        self.active = False
+        self.cancel_task()
+        if self.loop is not None:
+            # Cancel all pending tasks before stopping the loop
+            def cancel_all_tasks():
+                for task in asyncio.all_tasks(self.loop):
+                    if not task.done():
+                        task.cancel()
+                self.loop.call_soon(self.loop.stop)
+
+            self.loop.call_soon_threadsafe(cancel_all_tasks)
+
+            if self.loop_thread is not None and self.loop_thread.is_alive():
+                self.loop_thread.join()
+
+        self.loop = None
+        self.loop_thread = None
+
+    def __enter__(self):
+        self.start()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.stop()
+        return False
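Usage sketch (hypothetical driver, not taken from the patch): as wired up in `py_executor.py` below, the intended pattern is to re-arm the detector once per executor iteration and to suspend it around blocking waits so that an idle request queue is not reported as a hang.

```python
# Minimal exercise of the HangDetector API defined above; the per-step work
# and the blocking wait are stand-ins, not real PyExecutor internals.
import time

from tensorrt_llm._torch.pyexecutor.hang_detector import HangDetector


def on_hang():
    # Runs on the detector's event-loop thread once `timeout` seconds pass
    # without a checkpoint; PyExecutor uses this hook to report the error
    # and trigger shutdown.
    print("hang detected, requesting shutdown")


detector = HangDetector(timeout=5, on_detected=on_hang)

with detector:                  # __enter__ -> start(), __exit__ -> stop()
    for _ in range(3):
        detector.checkpoint()   # reset the timer at the top of each iteration
        time.sleep(0.1)         # stand-in for one executor step
        with detector.pause():  # blocking waits must not count as hangs
            time.sleep(0.1)     # stand-in for a queue fetch or broadcast

if detector.detected():
    print("skipping joins that may never return")  # as PyExecutor.shutdown() does
```

The timer lives on a dedicated asyncio loop in a daemon thread, so `checkpoint()` called from the worker thread only has to cancel the pending `_detect_hang` coroutine and schedule a new one; a hang is reported only when a full iteration stalls for longer than `timeout` (default 300 s).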
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 7e6aa747f4..4129973363 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -46,6 +46,7 @@ from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
 from .guided_decoder import GuidedDecoder
 from .handle_additional_outputs import HandleAdditionalOutputs
 from .handle_logits import HandleLogits
+from .hang_detector import HangDetector
 from .kv_cache_connector import KvCacheConnectorManager
 from .kv_cache_transceiver import KvCacheTransceiver
 from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
@@ -137,6 +138,7 @@ class PyExecutor:
                  max_seq_len: Optional[int] = None,
                  peft_cache_config: Optional[PeftCacheConfig] = None,
                  virtual_memory_pools: Optional[dict] = None,
+                 hang_detection_timeout: Optional[int] = None,
                  execution_stream: Optional[torch.cuda.Stream] = None):
         super(PyExecutor, self).__init__()
         self.device_id = torch.cuda.current_device()
@@ -280,6 +282,15 @@ class PyExecutor:
         self.adp_ctx_batching_wait_iters_count = 0
         self.batch_wait_iters_count = 0

+        def on_detected():
+            self._handle_errors(
+                f"Hang detected on rank {self.global_rank} in PyExecutor.")
+            self.shutdown_event.set()
+            self.is_shutdown = True
+
+        self.hang_detector = HangDetector(timeout=hang_detection_timeout,
+                                          on_detected=on_detected)
+
         # request fetcher initialization
         self._set_global_steady_clock_offset()
         self.executor_request_queue = ExecutorRequestQueue(
@@ -290,6 +301,7 @@ class PyExecutor:
             max_num_active_requests=self.max_num_active_requests,
             enable_iter_perf_stats=self.enable_iter_perf_stats,
             batch_wait_timeout_ms=self.batch_wait_timeout_ms,
+            hang_detector=self.hang_detector,
         )
         self.executor_request_queue.set_exclude_last_generation_logits(
             self.disable_overlap_scheduler, self.dist.pp_size)
@@ -476,6 +488,14 @@ class PyExecutor:
         """
         self.executor_request_queue.enqueue_shutdown_request()
         self.shutdown_event.wait()
+        if self.hang_detector.detected():
+            # Early return here to avoid waiting for hanging threads.
+            # Since `on_detected` has already sent the error message as the response,
+            # this worker will be asked to shut down immediately.
+            # Since the whole process shuts down after this `shutdown` call,
+            # all threads and memory pools will be freed properly.
+            logger.error("Hang detected, shutting down immediately.")
+            return
         self.worker_thread.join()
         self.worker_started = False
         for manager in self.resource_manager.resource_managers.values():
@@ -960,10 +980,11 @@ class PyExecutor:
         # ensure the context is created, otherwise, some MPI calls will fail.
         CUASSERT(cudart.cudaSetDevice(self.device_id))
         microbatch_id = 0
-        with self._profiler() as profile_step:
+        with self._profiler() as profile_step, self.hang_detector:
             iter_start_time = time.time()
             iter_stats = None
             while True:
+                self.hang_detector.checkpoint()
                 profile_step()
                 if self.enable_iter_perf_stats:
                     iter_start_time = time.time()
@@ -1349,11 +1370,12 @@ class PyExecutor:
         torch.cuda.set_device(self.device_id)
         # ensure the context is created, otherwise, some MPI calls will fail.
         CUASSERT(cudart.cudaSetDevice(self.device_id))
-        with self._profiler() as profile_step:
+        with self._profiler() as profile_step, self.hang_detector:
             sample_state = None
             iter_start_time = time.time()
             iter_stats = None
             while True:
+                self.hang_detector.checkpoint()
                 profile_step()
                 if self.enable_iter_perf_stats:
                     iter_start_time = time.time()
@@ -1551,13 +1573,14 @@ class PyExecutor:
         torch.cuda.set_device(self.device_id)
         # ensure the context is created, otherwise, some MPI calls will fail.
         CUASSERT(cudart.cudaSetDevice(self.device_id))
-        with self._profiler() as profile_step:
+        with self._profiler() as profile_step, self.hang_detector:
             iter_start_time = time.time()
             iter_stats = None
             target_inputs = None
             previous_tensors_device = None
             can_forward = False if self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver else True
             while True:
+                self.hang_detector.checkpoint()
                 profile_step()
                 if self.enable_iter_perf_stats:
                     iter_start_time = time.time()
diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py
index 86ebaef371..bfb4d32b42 100644
--- a/tensorrt_llm/_utils.py
+++ b/tensorrt_llm/_utils.py
@@ -21,8 +21,10 @@ import math
 import os
 import socket
 import struct
+import sys
 import tempfile
 import trace
+import traceback
 import weakref
 from contextlib import contextmanager
 from enum import EnumMeta
@@ -761,6 +763,13 @@ def is_sm_100f(sm_version=None):
     return sm_version == 100 or sm_version == 103


+def print_all_stacks():
+    """Print stack traces for all threads."""
+    for thread_id, frame in sys._current_frames().items():
+        logger.error(f"Thread {thread_id} stack trace:\n" +
+                     "".join(traceback.format_stack(frame)))
+
+
 def is_trace_enabled(env_var: str):
     value = os.environ.get(env_var, "-1")
     if value == "ALL":
diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py
index c4917a86a5..8e13f8c636 100644
--- a/tensorrt_llm/executor/worker.py
+++ b/tensorrt_llm/executor/worker.py
@@ -1,5 +1,7 @@
 import gc
 import os
+import threading
+import time
 import traceback
 from concurrent.futures import ProcessPoolExecutor
 from pathlib import Path
@@ -9,7 +11,7 @@ import zmq

 from tensorrt_llm.logger import logger

-from .._utils import mpi_comm, mpi_rank
+from .._utils import mpi_comm, mpi_rank, print_all_stacks
 from ..bindings import executor as tllm
 from ..builder import Engine
 from ..llmapi.llm_args import BaseLlmArgs
@@ -153,6 +155,21 @@ def worker_main(
         hmac_key: Optional[bytes] = None,
 ) -> None:

+    def _print_stacks():
+        counter = 0
+        while True:
+            time.sleep(print_stacks_period)
+            counter += 1
+            logger.error(f"Printing stacks {counter} times")
+            print_all_stacks()
+
+    print_stacks_period = int(
+        os.getenv("TRTLLM_WORKER_PRINT_STACKS_PERIOD", "-1"))
+    if print_stacks_period > 0:
+        print_stacks_thread = threading.Thread(target=_print_stacks,
+                                               daemon=True)
+        print_stacks_thread.start()
+
     mpi_comm().barrier()

     if llm_args is not None and llm_args.env_overrides:
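Follow-up sketch (assumes the patch is applied and `tensorrt_llm` is importable): a quick way to see what the new `print_all_stacks()` helper emits, which is the same output produced by the hang detector and by the periodic dump threads gated by `TRTLLM_PRINT_STACKS_PERIOD` and `TRTLLM_WORKER_PRINT_STACKS_PERIOD`.

```python
# Spawn a sleeping thread, then dump every live thread's stack through the
# TensorRT-LLM logger (one "Thread <id> stack trace: ..." block per thread).
import threading
import time

from tensorrt_llm._utils import print_all_stacks


def busy_wait():
    time.sleep(30)  # keeps the thread alive so its stack shows up in the dump


threading.Thread(target=busy_wait, daemon=True, name="demo-worker").start()
time.sleep(0.1)  # give the thread a moment to start
print_all_stacks()
```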