Merge branch 'main' into spark-weekly-newcases

commit db09dafbc9
Larry Xu, 2026-01-13 16:39:38 +08:00 (committed by GitHub)
8 changed files with 184 additions and 13 deletions


@@ -156,6 +156,7 @@ jobs:
 "kaiyux",
 "kanghui0204",
 "karljang",
+"karthikvetrivel",
 "katec846",
 "Kefeng-Duan",
 "KingsleyLiu-NV",
@@ -193,6 +194,7 @@ jobs:
 "mlefeb01",
 "moraxu",
 "MrGeva",
+"mzweilz",
 "Naveassaf",
 "nekorobov",
 "netanel-haber",


@@ -874,7 +874,6 @@ def getMountListForSlurmTest(SlurmCluster cluster, boolean useSbatch = false)
         }
         mounts += [
             "${cluster.scratchPath}:/scratch.trt_llm_data:ro",
-            "/home/svc_tensorrt/.cache:/root/.cache",
         ]
     } else {
         throw new Exception("Unsupported container runtime: ${cluster.containerRuntime}")


@@ -16,6 +16,7 @@ import contextlib
 import ctypes
 import os
 import platform
+import threading
 import time
 from functools import wraps
 from pathlib import Path
@@ -34,7 +35,7 @@ if TYPE_CHECKING:
 else:
     Network = None

-from ._utils import str_dtype_to_trt
+from ._utils import print_all_stacks, str_dtype_to_trt
 from .bindings import MpiComm
 from .logger import logger
 from .plugin import _load_plugin_lib
@@ -82,6 +83,19 @@ def _init(log_level: object = None) -> None:
         MpiComm.local_init()

+    def _print_stacks():
+        counter = 0
+        while True:
+            time.sleep(print_stacks_period)
+            counter += 1
+            logger.error(f"Printing stacks {counter} times")
+            print_all_stacks()
+
+    print_stacks_period = int(os.getenv("TRTLLM_PRINT_STACKS_PERIOD", "-1"))
+    if print_stacks_period > 0:
+        print_stacks_thread = threading.Thread(target=_print_stacks, daemon=True)
+        print_stacks_thread.start()
+
     logger.info("TensorRT LLM inited.")
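
The watchdog above is opt-in: `_init` reads TRTLLM_PRINT_STACKS_PERIOD once at import time, so the variable must be set before `tensorrt_llm` is imported. A minimal sketch of enabling it (the demo script itself is hypothetical):

import os

# Dump every thread's stack every 60 seconds. The variable is read once
# in _init(), so it must be set before the import below.
os.environ["TRTLLM_PRINT_STACKS_PERIOD"] = "60"

import tensorrt_llm  # noqa: E402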


@@ -14,6 +14,7 @@ from tensorrt_llm._utils import mpi_disabled, nvtx_range
 from tensorrt_llm.mapping import CpType

 from ..distributed import Distributed
+from .hang_detector import HangDetector
 from .llm_request import (ExecutorRequest, LlmRequest,
                           executor_request_to_llm_request)
@@ -47,10 +48,17 @@ class RequestQueueItem:
 class ExecutorRequestQueue:
     """Handles fetching and processing of new requests from the request queue."""

-    def __init__(self, dist: Distributed, enable_attention_dp: bool,
-                 max_batch_size: int, max_beam_width: int,
-                 max_num_active_requests: int, enable_iter_perf_stats: bool,
-                 batch_wait_timeout_ms: float):
+    def __init__(
+        self,
+        dist: Distributed,
+        enable_attention_dp: bool,
+        max_batch_size: int,
+        max_beam_width: int,
+        max_num_active_requests: int,
+        enable_iter_perf_stats: bool,
+        batch_wait_timeout_ms: float,
+        hang_detector: Optional[HangDetector] = None,
+    ):
         self.dist = dist
         self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue()
         self.waiting_queue: deque[RequestQueueItem] = deque()
@@ -66,6 +74,7 @@ class ExecutorRequestQueue:
         self.active = True
         self.batch_wait_timeout_ms = batch_wait_timeout_ms
         self.send_requests_handler = None
+        self.hang_detector = hang_detector or HangDetector()

         # State tracking
         self.num_fetch_requests = 0
@@ -303,7 +312,8 @@ class ExecutorRequestQueue:
                 self.request_accumulated.clear()
                 # Reset timeout to 0 to avoid hanging when no new requests are available
                 timeout = datetime.timedelta(0)
-        new_requests.extend(self._get_from_request_queue(timeout))
+        with self.hang_detector.pause():
+            new_requests.extend(self._get_from_request_queue(timeout))

         # Broadcast requests and handle Python objects
         new_requests, py_request_objects = self._handle_request_broadcasting(
@@ -477,8 +487,9 @@ class ExecutorRequestQueue:
             # Preserve original `new_requests` on rank 0
             _ = self._broadcast_new_requests(new_requests, py_request_objects)
         else:
-            new_requests, py_request_objects = self._broadcast_new_requests(
-                new_requests, py_request_objects)
+            with self.hang_detector.pause():
+                new_requests, py_request_objects = self._broadcast_new_requests(
+                    new_requests, py_request_objects)

         return new_requests, py_request_objects
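
The intent of `pause()` in both call sites: blocking on an empty request queue or on a collective broadcast is expected, so the watchdog timer is cancelled for the duration and re-armed on exit. A minimal sketch of that contract, with hypothetical `request_queue` and `process` stand-ins:

detector = HangDetector(timeout=300)
detector.start()
while True:
    detector.checkpoint()           # fresh timeout budget per iteration
    with detector.pause():          # waiting for input is not a hang
        item = request_queue.get()  # may block indefinitely
    process(item)                   # covered by the watchdog again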


@@ -0,0 +1,96 @@
+import asyncio
+import threading
+from contextlib import contextmanager
+from typing import Callable, Optional
+
+from tensorrt_llm._utils import print_all_stacks
+from tensorrt_llm.logger import logger
+
+
+class HangDetector:
+
+    def __init__(
+        self, timeout: Optional[int] = None, on_detected: Optional[Callable[[], None]] = None
+    ):
+        self.timeout = timeout if timeout is not None else 300
+        assert self.timeout > 0, "timeout must be greater than 0"
+        self.on_detected = on_detected or (lambda: None)
+        self.task = None
+        self.loop = None
+        self.loop_thread = None
+        self.lock = threading.Lock()
+        self.active = False
+        self._detected = False
+
+    def start(self):
+        """Enable hang detection."""
+
+        def run_loop():
+            asyncio.set_event_loop(self.loop)
+            self.loop.run_forever()
+
+        self.active = True
+        self.loop = asyncio.new_event_loop()
+        self.loop_thread = threading.Thread(target=run_loop, daemon=True, name="hang_detector_loop")
+        self.loop_thread.start()
+
+    async def _detect_hang(self):
+        await asyncio.sleep(self.timeout)
+        with self.lock:
+            self._detected = True
+        logger.error(f"Hang detected after {self.timeout} seconds.")
+        print_all_stacks()
+        self.on_detected()
+
+    def detected(self):
+        """Return True if hang is detected."""
+        with self.lock:
+            return self._detected
+
+    def checkpoint(self):
+        """Reset hang detection timer."""
+        self.cancel_task()
+        if self.active:
+            self.task = asyncio.run_coroutine_threadsafe(self._detect_hang(), self.loop)
+
+    def cancel_task(self):
+        """Cancel the hang detection task."""
+        if self.task is not None and not self.task.done():
+            self.task.cancel()
+        self.task = None
+
+    @contextmanager
+    def pause(self):
+        """Pause hang detection in scope."""
+        try:
+            self.cancel_task()
+            yield
+        finally:
+            self.checkpoint()
+
+    def stop(self):
+        """Stop hang detection."""
+        self.active = False
+        self.cancel_task()
+        if self.loop is not None:
+            # Cancel all pending tasks before stopping the loop
+            def cancel_all_tasks():
+                for task in asyncio.all_tasks(self.loop):
+                    if not task.done():
+                        task.cancel()
+                self.loop.call_soon(self.loop.stop)
+
+            self.loop.call_soon_threadsafe(cancel_all_tasks)
+            if self.loop_thread is not None and self.loop_thread.is_alive():
+                self.loop_thread.join()
+            self.loop = None
+            self.loop_thread = None
+
+    def __enter__(self):
+        self.start()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.stop()
+        return False
+
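
Taken together, the class is a re-armable watchdog whose timer runs on a dedicated asyncio event-loop thread. A short usage sketch, assuming HangDetector is importable as defined above (timings are arbitrary):

import time

events = []

with HangDetector(timeout=2, on_detected=lambda: events.append("hang")) as detector:
    for _ in range(3):
        detector.checkpoint()  # each iteration gets a fresh 2-second budget
        time.sleep(0.5)        # fast work: the timer never fires
    time.sleep(3)              # a stall longer than the timeout...

assert detector.detected()     # ...trips the detector and runs the callback
assert events == ["hang"]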


@@ -46,6 +46,7 @@ from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
 from .guided_decoder import GuidedDecoder
 from .handle_additional_outputs import HandleAdditionalOutputs
 from .handle_logits import HandleLogits
+from .hang_detector import HangDetector
 from .kv_cache_connector import KvCacheConnectorManager
 from .kv_cache_transceiver import KvCacheTransceiver
 from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
@@ -137,6 +138,7 @@ class PyExecutor:
                  max_seq_len: Optional[int] = None,
                  peft_cache_config: Optional[PeftCacheConfig] = None,
                  virtual_memory_pools: Optional[dict] = None,
+                 hang_detection_timeout: Optional[int] = None,
                  execution_stream: Optional[torch.cuda.Stream] = None):
         super(PyExecutor, self).__init__()
         self.device_id = torch.cuda.current_device()
@@ -280,6 +282,15 @@ class PyExecutor:
         self.adp_ctx_batching_wait_iters_count = 0
         self.batch_wait_iters_count = 0

+        def on_detected():
+            self._handle_errors(
+                f"Hang detected on rank {self.global_rank} in PyExecutor.")
+            self.shutdown_event.set()
+            self.is_shutdown = True
+
+        self.hang_detector = HangDetector(timeout=hang_detection_timeout,
+                                          on_detected=on_detected)
+
         # request fetcher initialization
         self._set_global_steady_clock_offset()
         self.executor_request_queue = ExecutorRequestQueue(
@@ -290,6 +301,7 @@ class PyExecutor:
             max_num_active_requests=self.max_num_active_requests,
             enable_iter_perf_stats=self.enable_iter_perf_stats,
             batch_wait_timeout_ms=self.batch_wait_timeout_ms,
+            hang_detector=self.hang_detector,
         )
         self.executor_request_queue.set_exclude_last_generation_logits(
             self.disable_overlap_scheduler, self.dist.pp_size)
@@ -476,6 +488,14 @@ class PyExecutor:
         """
         self.executor_request_queue.enqueue_shutdown_request()
         self.shutdown_event.wait()
+        if self.hang_detector.detected():
+            # Early return here to avoid waiting on hanging threads.
+            # Since `on_detected` has already sent the error message as a
+            # response, this worker will be asked to shut down immediately.
+            # The whole process exits after this `shutdown` call, so all
+            # threads and memory pools are freed properly.
+            logger.error("Hang detected, shutting down immediately.")
+            return
         self.worker_thread.join()
         self.worker_started = False
         for manager in self.resource_manager.resource_managers.values():
@@ -960,10 +980,11 @@ class PyExecutor:
         # ensure the context is created, otherwise, some MPI calls will fail.
         CUASSERT(cudart.cudaSetDevice(self.device_id))
         microbatch_id = 0
-        with self._profiler() as profile_step:
+        with self._profiler() as profile_step, self.hang_detector:
             iter_start_time = time.time()
             iter_stats = None
             while True:
+                self.hang_detector.checkpoint()
                 profile_step()
                 if self.enable_iter_perf_stats:
                     iter_start_time = time.time()
@@ -1349,11 +1370,12 @@ class PyExecutor:
         torch.cuda.set_device(self.device_id)
         # ensure the context is created, otherwise, some MPI calls will fail.
         CUASSERT(cudart.cudaSetDevice(self.device_id))
-        with self._profiler() as profile_step:
+        with self._profiler() as profile_step, self.hang_detector:
             sample_state = None
             iter_start_time = time.time()
             iter_stats = None
             while True:
+                self.hang_detector.checkpoint()
                 profile_step()
                 if self.enable_iter_perf_stats:
                     iter_start_time = time.time()
@@ -1551,13 +1573,14 @@ class PyExecutor:
         torch.cuda.set_device(self.device_id)
         # ensure the context is created, otherwise, some MPI calls will fail.
         CUASSERT(cudart.cudaSetDevice(self.device_id))
-        with self._profiler() as profile_step:
+        with self._profiler() as profile_step, self.hang_detector:
             iter_start_time = time.time()
             iter_stats = None
             target_inputs = None
             previous_tensors_device = None
             can_forward = False if self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver else True
             while True:
+                self.hang_detector.checkpoint()
                 profile_step()
                 if self.enable_iter_perf_stats:
                     iter_start_time = time.time()
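
The same wiring repeats across all three executor loops: enter the detector's context once per loop, then re-arm it at the top of every iteration, so only a single iteration that stalls past the timeout counts as a hang. Condensed, with `run_one_iteration` as a hypothetical stand-in for the loop body:

with self._profiler() as profile_step, self.hang_detector:
    while True:
        # A stalled iteration trips on_detected, which reports the error
        # and sets shutdown_event so shutdown() can return early.
        self.hang_detector.checkpoint()
        profile_step()
        run_one_iteration()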


@@ -21,8 +21,10 @@ import math
 import os
 import socket
 import struct
+import sys
 import tempfile
 import trace
+import traceback
 import weakref
 from contextlib import contextmanager
 from enum import EnumMeta
@@ -761,6 +763,13 @@ def is_sm_100f(sm_version=None):
     return sm_version == 100 or sm_version == 103


+def print_all_stacks():
+    """Print stack traces for all threads."""
+    for thread_id, frame in sys._current_frames().items():
+        logger.error(f"Thread {thread_id} stack trace:\n" +
+                     "".join(traceback.format_stack(frame)))
+
+
 def is_trace_enabled(env_var: str):
     value = os.environ.get(env_var, "-1")
     if value == "ALL":
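
For reference, the same technique works with the standard library alone; a self-contained sketch (plain print in place of the tensorrt_llm logger):

import sys
import threading
import time
import traceback

# A background thread to make the dump interesting.
threading.Thread(target=lambda: time.sleep(5), daemon=True).start()
time.sleep(0.1)  # let the thread start and block in sleep

for thread_id, frame in sys._current_frames().items():
    print(f"Thread {thread_id} stack trace:\n" +
          "".join(traceback.format_stack(frame)))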


@@ -1,5 +1,7 @@
 import gc
 import os
+import threading
+import time
 import traceback
 from concurrent.futures import ProcessPoolExecutor
 from pathlib import Path
@@ -9,7 +11,7 @@ import zmq

 from tensorrt_llm.logger import logger

-from .._utils import mpi_comm, mpi_rank
+from .._utils import mpi_comm, mpi_rank, print_all_stacks
 from ..bindings import executor as tllm
 from ..builder import Engine
 from ..llmapi.llm_args import BaseLlmArgs
@@ -153,6 +155,21 @@ def worker_main(
     hmac_key: Optional[bytes] = None,
 ) -> None:

+    def _print_stacks():
+        counter = 0
+        while True:
+            time.sleep(print_stacks_period)
+            counter += 1
+            logger.error(f"Printing stacks {counter} times")
+            print_all_stacks()
+
+    print_stacks_period = int(
+        os.getenv("TRTLLM_WORKER_PRINT_STACKS_PERIOD", "-1"))
+    if print_stacks_period > 0:
+        print_stacks_thread = threading.Thread(target=_print_stacks,
+                                               daemon=True)
+        print_stacks_thread.start()
+
     mpi_comm().barrier()

     if llm_args is not None and llm_args.env_overrides:
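
This mirrors the periodic dump added to `_init` above, but keys on TRTLLM_WORKER_PRINT_STACKS_PERIOD so worker processes can be instrumented independently of the parent. A hedged sketch of enabling it, assuming workers inherit the parent's environment (the commented LLM construction is illustrative):

import os

# Instrument only the worker processes: dump their stacks every 120 seconds.
os.environ["TRTLLM_WORKER_PRINT_STACKS_PERIOD"] = "120"

from tensorrt_llm import LLM  # noqa: E402

# llm = LLM(model="...")  # workers spawned for this LLM run the dump thread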