Merge branch 'main' into spark-weekly-newcases

Larry Xu 2026-01-13 16:39:38 +08:00 committed by GitHub
commit db09dafbc9
GPG Key ID: B5690EEEBB952194
8 changed files with 184 additions and 13 deletions

View File

@@ -156,6 +156,7 @@ jobs:
"kaiyux",
"kanghui0204",
"karljang",
"karthikvetrivel",
"katec846",
"Kefeng-Duan",
"KingsleyLiu-NV",
@@ -193,6 +194,7 @@ jobs:
"mlefeb01",
"moraxu",
"MrGeva",
"mzweilz",
"Naveassaf",
"nekorobov",
"netanel-haber",

View File

@@ -874,7 +874,6 @@ def getMountListForSlurmTest(SlurmCluster cluster, boolean useSbatch = false)
        }
        mounts += [
            "${cluster.scratchPath}:/scratch.trt_llm_data:ro",
            "/home/svc_tensorrt/.cache:/root/.cache",
        ]
    } else {
        throw new Exception("Unsupported container runtime: ${cluster.containerRuntime}")

View File

@@ -16,6 +16,7 @@ import contextlib
import ctypes
import os
import platform
import threading
import time
from functools import wraps
from pathlib import Path
@@ -34,7 +35,7 @@ if TYPE_CHECKING:
else:
    Network = None

from ._utils import str_dtype_to_trt
from ._utils import print_all_stacks, str_dtype_to_trt
from .bindings import MpiComm
from .logger import logger
from .plugin import _load_plugin_lib
@@ -82,6 +83,19 @@ def _init(log_level: object = None) -> None:
    MpiComm.local_init()

    def _print_stacks():
        counter = 0
        while True:
            time.sleep(print_stacks_period)
            counter += 1
            logger.error(f"Printing stacks {counter} times")
            print_all_stacks()

    print_stacks_period = int(os.getenv("TRTLLM_PRINT_STACKS_PERIOD", "-1"))
    if print_stacks_period > 0:
        print_stacks_thread = threading.Thread(target=_print_stacks, daemon=True)
        print_stacks_thread.start()

    logger.info("TensorRT LLM inited.")
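
Note (reviewer sketch, not part of the diff): the stack-dump watchdog above is opt-in via TRTLLM_PRINT_STACKS_PERIOD, which gives the dump period in seconds and defaults to "-1" (disabled). Assuming _init() runs during package initialization, as the log line above suggests, the variable has to be set before the library is imported (or exported in the shell before launch); the 60-second period below is only an example value.

import os

# Ask the runtime to dump every thread's stack once per minute.
# Any value <= 0 (the default "-1") keeps the watchdog thread disabled.
os.environ["TRTLLM_PRINT_STACKS_PERIOD"] = "60"

import tensorrt_llm  # noqa: E402  # _init() reads the variable and starts the daemon thread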

View File

@@ -14,6 +14,7 @@ from tensorrt_llm._utils import mpi_disabled, nvtx_range
from tensorrt_llm.mapping import CpType

from ..distributed import Distributed
from .hang_detector import HangDetector
from .llm_request import (ExecutorRequest, LlmRequest,
                          executor_request_to_llm_request)
@@ -47,10 +48,17 @@ class RequestQueueItem:
class ExecutorRequestQueue:
    """Handles fetching and processing of new requests from the request queue."""

    def __init__(self, dist: Distributed, enable_attention_dp: bool,
                 max_batch_size: int, max_beam_width: int,
                 max_num_active_requests: int, enable_iter_perf_stats: bool,
                 batch_wait_timeout_ms: float):
    def __init__(
        self,
        dist: Distributed,
        enable_attention_dp: bool,
        max_batch_size: int,
        max_beam_width: int,
        max_num_active_requests: int,
        enable_iter_perf_stats: bool,
        batch_wait_timeout_ms: float,
        hang_detector: Optional[HangDetector] = None,
    ):
        self.dist = dist
        self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue()
        self.waiting_queue: deque[RequestQueueItem] = deque()
@@ -66,6 +74,7 @@ class ExecutorRequestQueue:
        self.active = True
        self.batch_wait_timeout_ms = batch_wait_timeout_ms
        self.send_requests_handler = None
        self.hang_detector = hang_detector or HangDetector()

        # State tracking
        self.num_fetch_requests = 0
@@ -303,7 +312,8 @@ class ExecutorRequestQueue:
                self.request_accumulated.clear()
                # Reset timeout to 0 to avoid hanging when no new requests are available
                timeout = datetime.timedelta(0)
        new_requests.extend(self._get_from_request_queue(timeout))
        with self.hang_detector.pause():
            new_requests.extend(self._get_from_request_queue(timeout))

        # Broadcast requests and handle Python objects
        new_requests, py_request_objects = self._handle_request_broadcasting(
@@ -477,8 +487,9 @@ class ExecutorRequestQueue:
            # Preserve original `new_requests` on rank 0
            _ = self._broadcast_new_requests(new_requests, py_request_objects)
        else:
            new_requests, py_request_objects = self._broadcast_new_requests(
                new_requests, py_request_objects)
            with self.hang_detector.pause():
                new_requests, py_request_objects = self._broadcast_new_requests(
                    new_requests, py_request_objects)

        return new_requests, py_request_objects
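
Note (reviewer sketch, not part of the diff): pause() is wrapped around calls that may legitimately block for a long time, such as waiting on an empty request queue or on the broadcast above, so an idle rank is not reported as hung. A compact illustration of the pattern (the import path and the queue are assumptions for the sketch; the HangDetector class itself is the new file below):

import queue

from tensorrt_llm._torch.pyexecutor.hang_detector import HangDetector  # path assumed

request_queue: "queue.Queue[object]" = queue.Queue()
detector = HangDetector(timeout=30)
detector.start()

# Without pause(), 30 idle seconds on an empty queue would be flagged as a hang.
with detector.pause():
    try:
        item = request_queue.get(timeout=5.0)  # may block with nothing wrong
    except queue.Empty:
        item = None
# Leaving the context manager calls checkpoint(), which re-arms the timer.

detector.stop()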

View File

@@ -0,0 +1,96 @@
import asyncio
import threading
from contextlib import contextmanager
from typing import Callable, Optional

from tensorrt_llm._utils import print_all_stacks
from tensorrt_llm.logger import logger


class HangDetector:

    def __init__(
        self, timeout: Optional[int] = None, on_detected: Optional[Callable[[], None]] = None
    ):
        self.timeout = timeout if timeout is not None else 300
        assert self.timeout > 0, "timeout must be greater than 0"
        self.on_detected = on_detected or (lambda: None)
        self.task = None
        self.loop = None
        self.loop_thread = None
        self.lock = threading.Lock()
        self.active = False
        self._detected = False

    def start(self):
        """Enable hang detection."""

        def run_loop():
            asyncio.set_event_loop(self.loop)
            self.loop.run_forever()

        self.active = True
        self.loop = asyncio.new_event_loop()
        self.loop_thread = threading.Thread(target=run_loop, daemon=True, name="hang_detector_loop")
        self.loop_thread.start()

    async def _detect_hang(self):
        await asyncio.sleep(self.timeout)
        with self.lock:
            self._detected = True
        logger.error(f"Hang detected after {self.timeout} seconds.")
        print_all_stacks()
        self.on_detected()

    def detected(self):
        """Return True if hang is detected."""
        with self.lock:
            return self._detected

    def checkpoint(self):
        """Reset hang detection timer."""
        self.cancel_task()
        if self.active:
            self.task = asyncio.run_coroutine_threadsafe(self._detect_hang(), self.loop)

    def cancel_task(self):
        """Cancel the hang detection task."""
        if self.task is not None and not self.task.done():
            self.task.cancel()
        self.task = None

    @contextmanager
    def pause(self):
        """Pause hang detection in scope."""
        try:
            self.cancel_task()
            yield
        finally:
            self.checkpoint()

    def stop(self):
        """Stop hang detection."""
        self.active = False
        self.cancel_task()
        if self.loop is not None:
            # Cancel all pending tasks before stopping the loop
            def cancel_all_tasks():
                for task in asyncio.all_tasks(self.loop):
                    if not task.done():
                        task.cancel()
                self.loop.call_soon(self.loop.stop)

            self.loop.call_soon_threadsafe(cancel_all_tasks)
            if self.loop_thread is not None and self.loop_thread.is_alive():
                self.loop_thread.join()
            self.loop = None
            self.loop_thread = None

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.stop()
        return False
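
Note (reviewer sketch, not part of the diff): the class is meant to be driven the way PyExecutor does below: enter it as a context manager, call checkpoint() at the top of every iteration, and wrap intentional blocking waits in pause(). A hypothetical driver loop (do_one_iteration, wait_for_input and on_hang are placeholders; the import path is assumed from the relative imports elsewhere in this commit):

import time

from tensorrt_llm._torch.pyexecutor.hang_detector import HangDetector  # path assumed


def do_one_iteration() -> None:
    time.sleep(0.1)  # placeholder unit of work


def wait_for_input() -> None:
    time.sleep(0.5)  # placeholder blocking call


def on_hang() -> None:
    # Runs on the detector's event loop thread; stacks are already dumped by _detect_hang().
    print("hang detected")


with HangDetector(timeout=10, on_detected=on_hang) as detector:
    for _ in range(20):
        detector.checkpoint()      # re-arm the timer at the start of each iteration
        do_one_iteration()
        with detector.pause():     # suspend detection around an intentional wait
            wait_for_input()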

View File

@@ -46,6 +46,7 @@ from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
from .guided_decoder import GuidedDecoder
from .handle_additional_outputs import HandleAdditionalOutputs
from .handle_logits import HandleLogits
from .hang_detector import HangDetector
from .kv_cache_connector import KvCacheConnectorManager
from .kv_cache_transceiver import KvCacheTransceiver
from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
@@ -137,6 +138,7 @@ class PyExecutor:
                 max_seq_len: Optional[int] = None,
                 peft_cache_config: Optional[PeftCacheConfig] = None,
                 virtual_memory_pools: Optional[dict] = None,
                 hang_detection_timeout: Optional[int] = None,
                 execution_stream: Optional[torch.cuda.Stream] = None):
        super(PyExecutor, self).__init__()
        self.device_id = torch.cuda.current_device()
@@ -280,6 +282,15 @@ class PyExecutor:
        self.adp_ctx_batching_wait_iters_count = 0
        self.batch_wait_iters_count = 0

        def on_detected():
            self._handle_errors(
                f"Hang detected on rank {self.global_rank} in PyExecutor.")
            self.shutdown_event.set()
            self.is_shutdown = True

        self.hang_detector = HangDetector(timeout=hang_detection_timeout,
                                          on_detected=on_detected)

        # request fetcher initialization
        self._set_global_steady_clock_offset()
        self.executor_request_queue = ExecutorRequestQueue(
@@ -290,6 +301,7 @@ class PyExecutor:
            max_num_active_requests=self.max_num_active_requests,
            enable_iter_perf_stats=self.enable_iter_perf_stats,
            batch_wait_timeout_ms=self.batch_wait_timeout_ms,
            hang_detector=self.hang_detector,
        )
        self.executor_request_queue.set_exclude_last_generation_logits(
            self.disable_overlap_scheduler, self.dist.pp_size)
@@ -476,6 +488,14 @@ class PyExecutor:
        """
        self.executor_request_queue.enqueue_shutdown_request()
        self.shutdown_event.wait()
        if self.hang_detector.detected():
            # Early return here to avoid waiting for hanging threads.
            # Since `on_detected` has sent the error message as response,
            # this worker will be asked to shutdown immediately.
            # Since the whole process will shutdown after this `shutdown` call,
            # all threads and memory pools will be freed properly.
            logger.error("Hang detected, shutting down immediately.")
            return
        self.worker_thread.join()
        self.worker_started = False
        for manager in self.resource_manager.resource_managers.values():
@@ -960,10 +980,11 @@ class PyExecutor:
        # ensure the context is created, otherwise, some MPI calls will fail.
        CUASSERT(cudart.cudaSetDevice(self.device_id))
        microbatch_id = 0
        with self._profiler() as profile_step:
        with self._profiler() as profile_step, self.hang_detector:
            iter_start_time = time.time()
            iter_stats = None
            while True:
                self.hang_detector.checkpoint()
                profile_step()
                if self.enable_iter_perf_stats:
                    iter_start_time = time.time()
@@ -1349,11 +1370,12 @@ class PyExecutor:
        torch.cuda.set_device(self.device_id)
        # ensure the context is created, otherwise, some MPI calls will fail.
        CUASSERT(cudart.cudaSetDevice(self.device_id))
        with self._profiler() as profile_step:
        with self._profiler() as profile_step, self.hang_detector:
            sample_state = None
            iter_start_time = time.time()
            iter_stats = None
            while True:
                self.hang_detector.checkpoint()
                profile_step()
                if self.enable_iter_perf_stats:
                    iter_start_time = time.time()
@@ -1551,13 +1573,14 @@ class PyExecutor:
        torch.cuda.set_device(self.device_id)
        # ensure the context is created, otherwise, some MPI calls will fail.
        CUASSERT(cudart.cudaSetDevice(self.device_id))
        with self._profiler() as profile_step:
        with self._profiler() as profile_step, self.hang_detector:
            iter_start_time = time.time()
            iter_stats = None
            target_inputs = None
            previous_tensors_device = None
            can_forward = False if self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver else True
            while True:
                self.hang_detector.checkpoint()
                profile_step()
                if self.enable_iter_perf_stats:
                    iter_start_time = time.time()

View File

@@ -21,8 +21,10 @@ import math
import os
import socket
import struct
import sys
import tempfile
import trace
import traceback
import weakref
from contextlib import contextmanager
from enum import EnumMeta
@@ -761,6 +763,13 @@ def is_sm_100f(sm_version=None):
    return sm_version == 100 or sm_version == 103


def print_all_stacks():
    """Print stack traces for all threads"""
    for thread_id, frame in sys._current_frames().items():
        logger.error(f"Thread {thread_id} stack trace:\n" +
                     "".join(traceback.format_stack(frame)))


def is_trace_enabled(env_var: str):
    value = os.environ.get(env_var, "-1")
    if value == "ALL":
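
Note (reviewer sketch, not part of the diff): print_all_stacks() is built on two standard-library primitives, sys._current_frames() and traceback.format_stack(). A standalone variant is shown below; the thread-name lookup via threading.enumerate() is an addition for readability and is not part of the diff.

import sys
import threading
import traceback


def dump_all_stacks() -> None:
    # Map thread ids to names so the output is easier to attribute.
    names = {t.ident: t.name for t in threading.enumerate()}
    for thread_id, frame in sys._current_frames().items():
        stack = "".join(traceback.format_stack(frame))
        print(f"Thread {thread_id} ({names.get(thread_id, 'unknown')}) stack trace:\n{stack}")


if __name__ == "__main__":
    # A parked daemon thread gives the dump something besides the main thread to show.
    threading.Thread(target=threading.Event().wait, daemon=True, name="parked").start()
    dump_all_stacks()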

View File

@@ -1,5 +1,7 @@
import gc
import os
import threading
import time
import traceback
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
@@ -9,7 +11,7 @@ import zmq
from tensorrt_llm.logger import logger
from .._utils import mpi_comm, mpi_rank
from .._utils import mpi_comm, mpi_rank, print_all_stacks
from ..bindings import executor as tllm
from ..builder import Engine
from ..llmapi.llm_args import BaseLlmArgs
@@ -153,6 +155,21 @@ def worker_main(
        hmac_key: Optional[bytes] = None,
) -> None:

    def _print_stacks():
        counter = 0
        while True:
            time.sleep(print_stacks_period)
            counter += 1
            logger.error(f"Printing stacks {counter} times")
            print_all_stacks()

    print_stacks_period = int(
        os.getenv("TRTLLM_WORKER_PRINT_STACKS_PERIOD", "-1"))
    if print_stacks_period > 0:
        print_stacks_thread = threading.Thread(target=_print_stacks,
                                               daemon=True)
        print_stacks_thread.start()

    mpi_comm().barrier()

    if llm_args is not None and llm_args.env_overrides: