Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
Merge branch 'main' into spark-weekly-newcases
commit db09dafbc9
.github/workflows/blossom-ci.yml (vendored): 2 changes
@@ -156,6 +156,7 @@ jobs:
              "kaiyux",
              "kanghui0204",
              "karljang",
              "karthikvetrivel",
              "katec846",
              "Kefeng-Duan",
              "KingsleyLiu-NV",
@@ -193,6 +194,7 @@ jobs:
              "mlefeb01",
              "moraxu",
              "MrGeva",
              "mzweilz",
              "Naveassaf",
              "nekorobov",
              "netanel-haber",
@@ -874,7 +874,6 @@ def getMountListForSlurmTest(SlurmCluster cluster, boolean useSbatch = false)
        }
        mounts += [
            "${cluster.scratchPath}:/scratch.trt_llm_data:ro",
            "/home/svc_tensorrt/.cache:/root/.cache",
        ]
    } else {
        throw new Exception("Unsupported container runtime: ${cluster.containerRuntime}")
@@ -16,6 +16,7 @@ import contextlib
import ctypes
import os
import platform
import threading
import time
from functools import wraps
from pathlib import Path
@@ -34,7 +35,7 @@ if TYPE_CHECKING:
else:
    Network = None

-from ._utils import str_dtype_to_trt
+from ._utils import print_all_stacks, str_dtype_to_trt
from .bindings import MpiComm
from .logger import logger
from .plugin import _load_plugin_lib
@@ -82,6 +83,19 @@ def _init(log_level: object = None) -> None:
    MpiComm.local_init()

+    def _print_stacks():
+        counter = 0
+        while True:
+            time.sleep(print_stacks_period)
+            counter += 1
+            logger.error(f"Printing stacks {counter} times")
+            print_all_stacks()
+
+    print_stacks_period = int(os.getenv("TRTLLM_PRINT_STACKS_PERIOD", "-1"))
+    if print_stacks_period > 0:
+        print_stacks_thread = threading.Thread(target=_print_stacks, daemon=True)
+        print_stacks_thread.start()
+
    logger.info("TensorRT LLM inited.")
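For reference, the new TRTLLM_PRINT_STACKS_PERIOD hook is easy to try in isolation. The sketch below mirrors the _print_stacks daemon added above; the env-var name comes from the diff, while dump_all_stacks is a local stand-in for tensorrt_llm._utils.print_all_stacks so the snippet runs without TensorRT-LLM installed.

import os
import sys
import threading
import time
import traceback

def dump_all_stacks():
    # Same mechanism as the print_all_stacks helper added in this commit:
    # walk every thread's current frame and format its stack.
    for thread_id, frame in sys._current_frames().items():
        print(f"Thread {thread_id} stack trace:\n" + "".join(traceback.format_stack(frame)))

period = int(os.getenv("TRTLLM_PRINT_STACKS_PERIOD", "-1"))
if period > 0:
    def _print_stacks():
        while True:
            time.sleep(period)
            dump_all_stacks()
    threading.Thread(target=_print_stacks, daemon=True).start()

# e.g. run as: TRTLLM_PRINT_STACKS_PERIOD=30 python this_script.py
time.sleep(1)  # placeholder for real work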
@@ -14,6 +14,7 @@ from tensorrt_llm._utils import mpi_disabled, nvtx_range
from tensorrt_llm.mapping import CpType

from ..distributed import Distributed
+from .hang_detector import HangDetector
from .llm_request import (ExecutorRequest, LlmRequest,
                          executor_request_to_llm_request)
@@ -47,10 +48,17 @@ class RequestQueueItem:
class ExecutorRequestQueue:
    """Handles fetching and processing of new requests from the request queue."""

-    def __init__(self, dist: Distributed, enable_attention_dp: bool,
-                 max_batch_size: int, max_beam_width: int,
-                 max_num_active_requests: int, enable_iter_perf_stats: bool,
-                 batch_wait_timeout_ms: float):
+    def __init__(
+        self,
+        dist: Distributed,
+        enable_attention_dp: bool,
+        max_batch_size: int,
+        max_beam_width: int,
+        max_num_active_requests: int,
+        enable_iter_perf_stats: bool,
+        batch_wait_timeout_ms: float,
+        hang_detector: Optional[HangDetector] = None,
+    ):
        self.dist = dist
        self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue()
        self.waiting_queue: deque[RequestQueueItem] = deque()
@@ -66,6 +74,7 @@ class ExecutorRequestQueue:
        self.active = True
        self.batch_wait_timeout_ms = batch_wait_timeout_ms
        self.send_requests_handler = None
+        self.hang_detector = hang_detector or HangDetector()

        # State tracking
        self.num_fetch_requests = 0
@@ -303,7 +312,8 @@ class ExecutorRequestQueue:
            self.request_accumulated.clear()
            # Reset timeout to 0 to avoid hanging when no new requests are available
            timeout = datetime.timedelta(0)
-        new_requests.extend(self._get_from_request_queue(timeout))
+        with self.hang_detector.pause():
+            new_requests.extend(self._get_from_request_queue(timeout))

        # Broadcast requests and handle Python objects
        new_requests, py_request_objects = self._handle_request_broadcasting(
@@ -477,8 +487,9 @@ class ExecutorRequestQueue:
            # Preserve original `new_requests` on rank 0
            _ = self._broadcast_new_requests(new_requests, py_request_objects)
        else:
-            new_requests, py_request_objects = self._broadcast_new_requests(
-                new_requests, py_request_objects)
+            with self.hang_detector.pause():
+                new_requests, py_request_objects = self._broadcast_new_requests(
+                    new_requests, py_request_objects)

        return new_requests, py_request_objects
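The two pause() wrappers above exist because _get_from_request_queue and _broadcast_new_requests can legitimately block for long stretches when a rank simply has no work; if the watchdog timer kept running across those waits it would flag false hangs. A minimal, generic sketch of the pattern (Watchdog and its names are illustrative, not the TensorRT-LLM API):

import queue
import threading
from contextlib import contextmanager

class Watchdog:
    """Toy watchdog: a timer that fires unless checkpoint() re-arms it in time."""

    def __init__(self, timeout: float):
        self.timeout = timeout
        self._timer = None

    def checkpoint(self):
        # (Re)arm the timer; called once per loop iteration.
        if self._timer is not None:
            self._timer.cancel()
        self._timer = threading.Timer(self.timeout, lambda: print("hang suspected"))
        self._timer.daemon = True
        self._timer.start()

    @contextmanager
    def pause(self):
        # Disarm while waiting on something that may legitimately block.
        if self._timer is not None:
            self._timer.cancel()
        try:
            yield
        finally:
            self.checkpoint()

watchdog = Watchdog(timeout=2.0)
requests: queue.Queue = queue.Queue()
watchdog.checkpoint()
with watchdog.pause():  # idle waiting must not count as a hang
    try:
        item = requests.get(timeout=5.0)
    except queue.Empty:
        item = None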
tensorrt_llm/_torch/pyexecutor/hang_detector.py (new file, 96 lines)
@@ -0,0 +1,96 @@
import asyncio
import threading
from contextlib import contextmanager
from typing import Callable, Optional

from tensorrt_llm._utils import print_all_stacks
from tensorrt_llm.logger import logger


class HangDetector:

    def __init__(
        self, timeout: Optional[int] = None, on_detected: Optional[Callable[[], None]] = None
    ):
        self.timeout = timeout if timeout is not None else 300
        assert self.timeout > 0, "timeout must be greater than 0"
        self.on_detected = on_detected or (lambda: None)
        self.task = None
        self.loop = None
        self.loop_thread = None
        self.lock = threading.Lock()
        self.active = False
        self._detected = False

    def start(self):
        """Enable hang detection."""

        def run_loop():
            asyncio.set_event_loop(self.loop)
            self.loop.run_forever()

        self.active = True
        self.loop = asyncio.new_event_loop()
        self.loop_thread = threading.Thread(target=run_loop, daemon=True, name="hang_detector_loop")
        self.loop_thread.start()

    async def _detect_hang(self):
        await asyncio.sleep(self.timeout)
        with self.lock:
            self._detected = True
        logger.error(f"Hang detected after {self.timeout} seconds.")
        print_all_stacks()
        self.on_detected()

    def detected(self):
        """Return True if hang is detected."""
        with self.lock:
            return self._detected

    def checkpoint(self):
        """Reset hang detection timer."""
        self.cancel_task()
        if self.active:
            self.task = asyncio.run_coroutine_threadsafe(self._detect_hang(), self.loop)

    def cancel_task(self):
        """Cancel the hang detection task."""
        if self.task is not None and not self.task.done():
            self.task.cancel()
        self.task = None

    @contextmanager
    def pause(self):
        """Pause hang detection in scope."""
        try:
            self.cancel_task()
            yield
        finally:
            self.checkpoint()

    def stop(self):
        """Stop hang detection."""
        self.active = False
        self.cancel_task()
        if self.loop is not None:
            # Cancel all pending tasks before stopping the loop
            def cancel_all_tasks():
                for task in asyncio.all_tasks(self.loop):
                    if not task.done():
                        task.cancel()
                self.loop.call_soon(self.loop.stop)

            self.loop.call_soon_threadsafe(cancel_all_tasks)

            if self.loop_thread is not None and self.loop_thread.is_alive():
                self.loop_thread.join()

            self.loop = None
            self.loop_thread = None

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.stop()
        return False
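A minimal usage sketch of the new class, assuming a TensorRT-LLM build that ships tensorrt_llm/_torch/pyexecutor/hang_detector.py; the loop below is an illustrative stand-in for an executor iteration, not the PyExecutor API:

import time

from tensorrt_llm._torch.pyexecutor.hang_detector import HangDetector

def on_detected():
    # Called from the detector's event-loop thread once the timeout elapses
    # without a checkpoint(); stacks have already been logged at that point.
    print("worker appears hung, requesting shutdown")

# __enter__/__exit__ call start()/stop(), mirroring how PyExecutor wraps its loops.
with HangDetector(timeout=5, on_detected=on_detected) as detector:
    for step in range(3):
        detector.checkpoint()   # re-arm the timer once per iteration
        time.sleep(0.1)         # stand-in for one executor iteration
        with detector.pause():  # disarm around legitimately blocking waits
            time.sleep(0.2)     # stand-in for an idle queue/broadcast wait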
@@ -46,6 +46,7 @@ from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
from .guided_decoder import GuidedDecoder
from .handle_additional_outputs import HandleAdditionalOutputs
from .handle_logits import HandleLogits
+from .hang_detector import HangDetector
from .kv_cache_connector import KvCacheConnectorManager
from .kv_cache_transceiver import KvCacheTransceiver
from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
@@ -137,6 +138,7 @@ class PyExecutor:
                 max_seq_len: Optional[int] = None,
                 peft_cache_config: Optional[PeftCacheConfig] = None,
                 virtual_memory_pools: Optional[dict] = None,
+                 hang_detection_timeout: Optional[int] = None,
                 execution_stream: Optional[torch.cuda.Stream] = None):
        super(PyExecutor, self).__init__()
        self.device_id = torch.cuda.current_device()
@@ -280,6 +282,15 @@ class PyExecutor:
        self.adp_ctx_batching_wait_iters_count = 0
        self.batch_wait_iters_count = 0

+        def on_detected():
+            self._handle_errors(
+                f"Hang detected on rank {self.global_rank} in PyExecutor.")
+            self.shutdown_event.set()
+            self.is_shutdown = True
+
+        self.hang_detector = HangDetector(timeout=hang_detection_timeout,
+                                          on_detected=on_detected)
+
        # request fetcher initialization
        self._set_global_steady_clock_offset()
        self.executor_request_queue = ExecutorRequestQueue(
@@ -290,6 +301,7 @@ class PyExecutor:
            max_num_active_requests=self.max_num_active_requests,
            enable_iter_perf_stats=self.enable_iter_perf_stats,
            batch_wait_timeout_ms=self.batch_wait_timeout_ms,
+            hang_detector=self.hang_detector,
        )
        self.executor_request_queue.set_exclude_last_generation_logits(
            self.disable_overlap_scheduler, self.dist.pp_size)
@@ -476,6 +488,14 @@ class PyExecutor:
        """
        self.executor_request_queue.enqueue_shutdown_request()
        self.shutdown_event.wait()
+        if self.hang_detector.detected():
+            # Early return here to avoid waiting for hanging threads.
+            # Since `on_detected` has already sent the error message as a response,
+            # this worker will be asked to shut down immediately.
+            # Since the whole process shuts down after this `shutdown` call,
+            # all threads and memory pools will be freed properly.
+            logger.error("Hang detected, shutting down immediately.")
+            return
        self.worker_thread.join()
        self.worker_started = False
        for manager in self.resource_manager.resource_managers.values():
@@ -960,10 +980,11 @@ class PyExecutor:
        # ensure the context is created, otherwise, some MPI calls will fail.
        CUASSERT(cudart.cudaSetDevice(self.device_id))
        microbatch_id = 0
-        with self._profiler() as profile_step:
+        with self._profiler() as profile_step, self.hang_detector:
            iter_start_time = time.time()
            iter_stats = None
            while True:
+                self.hang_detector.checkpoint()
                profile_step()
                if self.enable_iter_perf_stats:
                    iter_start_time = time.time()
@@ -1349,11 +1370,12 @@ class PyExecutor:
        torch.cuda.set_device(self.device_id)
        # ensure the context is created, otherwise, some MPI calls will fail.
        CUASSERT(cudart.cudaSetDevice(self.device_id))
-        with self._profiler() as profile_step:
+        with self._profiler() as profile_step, self.hang_detector:
            sample_state = None
            iter_start_time = time.time()
            iter_stats = None
            while True:
+                self.hang_detector.checkpoint()
                profile_step()
                if self.enable_iter_perf_stats:
                    iter_start_time = time.time()
@@ -1551,13 +1573,14 @@ class PyExecutor:
        torch.cuda.set_device(self.device_id)
        # ensure the context is created, otherwise, some MPI calls will fail.
        CUASSERT(cudart.cudaSetDevice(self.device_id))
-        with self._profiler() as profile_step:
+        with self._profiler() as profile_step, self.hang_detector:
            iter_start_time = time.time()
            iter_stats = None
            target_inputs = None
            previous_tensors_device = None
            can_forward = False if self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver else True
            while True:
+                self.hang_detector.checkpoint()
                profile_step()
                if self.enable_iter_perf_stats:
                    iter_start_time = time.time()
@@ -21,8 +21,10 @@ import math
import os
import socket
import struct
import sys
import tempfile
import trace
import traceback
import weakref
from contextlib import contextmanager
from enum import EnumMeta
@@ -761,6 +763,13 @@ def is_sm_100f(sm_version=None):
    return sm_version == 100 or sm_version == 103


+def print_all_stacks():
+    """Print stack traces for all threads"""
+    for thread_id, frame in sys._current_frames().items():
+        logger.error(f"Thread {thread_id} stack trace:\n" +
+                     "".join(traceback.format_stack(frame)))
+
+
def is_trace_enabled(env_var: str):
    value = os.environ.get(env_var, "-1")
    if value == "ALL":
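sys._current_frames() keys the dumps by native thread id only. If thread names would help when reading the output, they can be recovered from threading.enumerate(); a small illustrative helper, not part of this change:

import sys
import threading
import time
import traceback

def dump_stacks_with_names():
    # Map thread ident -> Thread object so each dumped stack can be labelled.
    names = {t.ident: t.name for t in threading.enumerate()}
    for thread_id, frame in sys._current_frames().items():
        label = names.get(thread_id, "unknown")
        print(f"Thread {thread_id} ({label}):\n" + "".join(traceback.format_stack(frame)))

threading.Thread(target=time.sleep, args=(0.5,), name="worker", daemon=True).start()
time.sleep(0.1)  # give the worker a moment to start
dump_stacks_with_names()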
@@ -1,5 +1,7 @@
import gc
import os
import threading
import time
import traceback
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
@@ -9,7 +11,7 @@ import zmq

from tensorrt_llm.logger import logger

-from .._utils import mpi_comm, mpi_rank
+from .._utils import mpi_comm, mpi_rank, print_all_stacks
from ..bindings import executor as tllm
from ..builder import Engine
from ..llmapi.llm_args import BaseLlmArgs
@@ -153,6 +155,21 @@ def worker_main(
    hmac_key: Optional[bytes] = None,
) -> None:

+    def _print_stacks():
+        counter = 0
+        while True:
+            time.sleep(print_stacks_period)
+            counter += 1
+            logger.error(f"Printing stacks {counter} times")
+            print_all_stacks()
+
+    print_stacks_period = int(
+        os.getenv("TRTLLM_WORKER_PRINT_STACKS_PERIOD", "-1"))
+    if print_stacks_period > 0:
+        print_stacks_thread = threading.Thread(target=_print_stacks,
+                                               daemon=True)
+        print_stacks_thread.start()
+
    mpi_comm().barrier()

    if llm_args is not None and llm_args.env_overrides: