mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
Merge branch 'main' into spark-weekly-newcases
This commit is contained in:
commit
db09dafbc9
2
.github/workflows/blossom-ci.yml
vendored
2
.github/workflows/blossom-ci.yml
vendored
@ -156,6 +156,7 @@ jobs:
|
|||||||
"kaiyux",
|
"kaiyux",
|
||||||
"kanghui0204",
|
"kanghui0204",
|
||||||
"karljang",
|
"karljang",
|
||||||
|
"karthikvetrivel",
|
||||||
"katec846",
|
"katec846",
|
||||||
"Kefeng-Duan",
|
"Kefeng-Duan",
|
||||||
"KingsleyLiu-NV",
|
"KingsleyLiu-NV",
|
||||||
@ -193,6 +194,7 @@ jobs:
|
|||||||
"mlefeb01",
|
"mlefeb01",
|
||||||
"moraxu",
|
"moraxu",
|
||||||
"MrGeva",
|
"MrGeva",
|
||||||
|
"mzweilz",
|
||||||
"Naveassaf",
|
"Naveassaf",
|
||||||
"nekorobov",
|
"nekorobov",
|
||||||
"netanel-haber",
|
"netanel-haber",
|
||||||
|
|||||||
@ -874,7 +874,6 @@ def getMountListForSlurmTest(SlurmCluster cluster, boolean useSbatch = false)
|
|||||||
}
|
}
|
||||||
mounts += [
|
mounts += [
|
||||||
"${cluster.scratchPath}:/scratch.trt_llm_data:ro",
|
"${cluster.scratchPath}:/scratch.trt_llm_data:ro",
|
||||||
"/home/svc_tensorrt/.cache:/root/.cache",
|
|
||||||
]
|
]
|
||||||
} else {
|
} else {
|
||||||
throw new Exception("Unsupported container runtime: ${cluster.containerRuntime}")
|
throw new Exception("Unsupported container runtime: ${cluster.containerRuntime}")
|
||||||
|
|||||||
@ -16,6 +16,7 @@ import contextlib
|
|||||||
import ctypes
|
import ctypes
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -34,7 +35,7 @@ if TYPE_CHECKING:
|
|||||||
else:
|
else:
|
||||||
Network = None
|
Network = None
|
||||||
|
|
||||||
from ._utils import str_dtype_to_trt
|
from ._utils import print_all_stacks, str_dtype_to_trt
|
||||||
from .bindings import MpiComm
|
from .bindings import MpiComm
|
||||||
from .logger import logger
|
from .logger import logger
|
||||||
from .plugin import _load_plugin_lib
|
from .plugin import _load_plugin_lib
|
||||||
@ -82,6 +83,19 @@ def _init(log_level: object = None) -> None:
|
|||||||
|
|
||||||
MpiComm.local_init()
|
MpiComm.local_init()
|
||||||
|
|
||||||
|
def _print_stacks():
|
||||||
|
counter = 0
|
||||||
|
while True:
|
||||||
|
time.sleep(print_stacks_period)
|
||||||
|
counter += 1
|
||||||
|
logger.error(f"Printing stacks {counter} times")
|
||||||
|
print_all_stacks()
|
||||||
|
|
||||||
|
print_stacks_period = int(os.getenv("TRTLLM_PRINT_STACKS_PERIOD", "-1"))
|
||||||
|
if print_stacks_period > 0:
|
||||||
|
print_stacks_thread = threading.Thread(target=_print_stacks, daemon=True)
|
||||||
|
print_stacks_thread.start()
|
||||||
|
|
||||||
logger.info("TensorRT LLM inited.")
|
logger.info("TensorRT LLM inited.")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -14,6 +14,7 @@ from tensorrt_llm._utils import mpi_disabled, nvtx_range
|
|||||||
from tensorrt_llm.mapping import CpType
|
from tensorrt_llm.mapping import CpType
|
||||||
|
|
||||||
from ..distributed import Distributed
|
from ..distributed import Distributed
|
||||||
|
from .hang_detector import HangDetector
|
||||||
from .llm_request import (ExecutorRequest, LlmRequest,
|
from .llm_request import (ExecutorRequest, LlmRequest,
|
||||||
executor_request_to_llm_request)
|
executor_request_to_llm_request)
|
||||||
|
|
||||||
@ -47,10 +48,17 @@ class RequestQueueItem:
|
|||||||
class ExecutorRequestQueue:
|
class ExecutorRequestQueue:
|
||||||
"""Handles fetching and processing of new requests from the request queue."""
|
"""Handles fetching and processing of new requests from the request queue."""
|
||||||
|
|
||||||
def __init__(self, dist: Distributed, enable_attention_dp: bool,
|
def __init__(
|
||||||
max_batch_size: int, max_beam_width: int,
|
self,
|
||||||
max_num_active_requests: int, enable_iter_perf_stats: bool,
|
dist: Distributed,
|
||||||
batch_wait_timeout_ms: float):
|
enable_attention_dp: bool,
|
||||||
|
max_batch_size: int,
|
||||||
|
max_beam_width: int,
|
||||||
|
max_num_active_requests: int,
|
||||||
|
enable_iter_perf_stats: bool,
|
||||||
|
batch_wait_timeout_ms: float,
|
||||||
|
hang_detector: Optional[HangDetector] = None,
|
||||||
|
):
|
||||||
self.dist = dist
|
self.dist = dist
|
||||||
self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue()
|
self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue()
|
||||||
self.waiting_queue: deque[RequestQueueItem] = deque()
|
self.waiting_queue: deque[RequestQueueItem] = deque()
|
||||||
@ -66,6 +74,7 @@ class ExecutorRequestQueue:
|
|||||||
self.active = True
|
self.active = True
|
||||||
self.batch_wait_timeout_ms = batch_wait_timeout_ms
|
self.batch_wait_timeout_ms = batch_wait_timeout_ms
|
||||||
self.send_requests_handler = None
|
self.send_requests_handler = None
|
||||||
|
self.hang_detector = hang_detector or HangDetector()
|
||||||
|
|
||||||
# State tracking
|
# State tracking
|
||||||
self.num_fetch_requests = 0
|
self.num_fetch_requests = 0
|
||||||
@ -303,7 +312,8 @@ class ExecutorRequestQueue:
|
|||||||
self.request_accumulated.clear()
|
self.request_accumulated.clear()
|
||||||
# Reset timeout to 0 to avoid hanging when no new requests are available
|
# Reset timeout to 0 to avoid hanging when no new requests are available
|
||||||
timeout = datetime.timedelta(0)
|
timeout = datetime.timedelta(0)
|
||||||
new_requests.extend(self._get_from_request_queue(timeout))
|
with self.hang_detector.pause():
|
||||||
|
new_requests.extend(self._get_from_request_queue(timeout))
|
||||||
|
|
||||||
# Broadcast requests and handle Python objects
|
# Broadcast requests and handle Python objects
|
||||||
new_requests, py_request_objects = self._handle_request_broadcasting(
|
new_requests, py_request_objects = self._handle_request_broadcasting(
|
||||||
@ -477,8 +487,9 @@ class ExecutorRequestQueue:
|
|||||||
# Preserve original `new_requests` on rank 0
|
# Preserve original `new_requests` on rank 0
|
||||||
_ = self._broadcast_new_requests(new_requests, py_request_objects)
|
_ = self._broadcast_new_requests(new_requests, py_request_objects)
|
||||||
else:
|
else:
|
||||||
new_requests, py_request_objects = self._broadcast_new_requests(
|
with self.hang_detector.pause():
|
||||||
new_requests, py_request_objects)
|
new_requests, py_request_objects = self._broadcast_new_requests(
|
||||||
|
new_requests, py_request_objects)
|
||||||
|
|
||||||
return new_requests, py_request_objects
|
return new_requests, py_request_objects
|
||||||
|
|
||||||
|
|||||||
96
tensorrt_llm/_torch/pyexecutor/hang_detector.py
Normal file
96
tensorrt_llm/_torch/pyexecutor/hang_detector.py
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
from tensorrt_llm._utils import print_all_stacks
|
||||||
|
from tensorrt_llm.logger import logger
|
||||||
|
|
||||||
|
|
||||||
|
class HangDetector:
    """Watchdog that reports a hang when no checkpoint arrives in time.

    A private asyncio event loop runs on a daemon thread. Every call to
    ``checkpoint()`` cancels the pending timer and schedules a new one on
    that loop; if a timer is allowed to expire, the detector records the
    hang, logs the stacks of all threads and invokes ``on_detected``.

    The detector is also a context manager: entering calls ``start()`` and
    leaving calls ``stop()``.

    Args:
        timeout: Seconds without a checkpoint before a hang is reported.
            ``None`` selects the default of 300.
        on_detected: Zero-argument callback invoked after a hang has been
            logged. Defaults to a no-op.
    """

    def __init__(
        self, timeout: Optional[int] = None, on_detected: Optional[Callable[[], None]] = None
    ):
        self.timeout = timeout if timeout is not None else 300
        assert self.timeout > 0, "timeout must be greater than 0"
        self.on_detected = on_detected or (lambda: None)
        # Handle of the currently scheduled timer, or None when no timer is armed.
        self.task = None
        # Event loop and the thread driving it; both are None unless started.
        self.loop = None
        self.loop_thread = None
        # Guards ``_detected``, which is written on the loop thread and read
        # from whichever thread calls ``detected()``.
        self.lock = threading.Lock()
        self.active = False
        self._detected = False

    def start(self):
        """Enable hang detection by starting the event loop thread."""
        loop = asyncio.new_event_loop()

        def run_loop():
            asyncio.set_event_loop(loop)
            loop.run_forever()

        self.active = True
        self.loop = loop
        self.loop_thread = threading.Thread(target=run_loop, daemon=True, name="hang_detector_loop")
        self.loop_thread.start()

    async def _detect_hang(self):
        """Sleep for ``timeout`` seconds, then report a hang.

        The coroutine is cancelled by ``cancel_task()`` whenever progress is
        observed, so reaching the code after the sleep means no checkpoint
        arrived in time.
        """
        await asyncio.sleep(self.timeout)
        # Hold the lock only for the flag update so that a callback which
        # itself calls ``detected()`` cannot deadlock on the non-reentrant lock.
        with self.lock:
            self._detected = True
        logger.error(f"Hang detected after {self.timeout} seconds.")
        print_all_stacks()
        try:
            self.on_detected()
        except Exception as exc:
            # Nobody reads the result of the scheduled future, so without this
            # log a failing callback would disappear silently.
            logger.error(f"Hang detector on_detected callback failed: {exc!r}")
            raise

    def detected(self):
        """Return True if hang is detected."""
        with self.lock:
            return self._detected

    def checkpoint(self):
        """Reset hang detection timer.

        Cancels the pending timer and, if the detector is active, schedules a
        fresh one on the event loop thread.
        """
        self.cancel_task()
        if self.active:
            self.task = asyncio.run_coroutine_threadsafe(self._detect_hang(), self.loop)

    def cancel_task(self):
        """Cancel the hang detection task."""
        if self.task is not None and not self.task.done():
            self.task.cancel()
        self.task = None

    @contextmanager
    def pause(self):
        """Pause hang detection in scope.

        The timer is cancelled on entry and re-armed on exit (only if the
        detector is active), even when the body raises.
        """
        try:
            self.cancel_task()
            yield
        finally:
            self.checkpoint()

    def stop(self):
        """Stop hang detection and release the event loop.

        Cancels every task still pending on the loop, stops the loop, waits
        for the loop thread to finish and closes the loop.
        """
        self.active = False
        self.cancel_task()

        loop = self.loop
        if loop is None:
            return
        thread = self.loop_thread

        def cancel_all_tasks():
            # Runs on the loop thread: cancel whatever is still pending, then
            # stop the loop once those cancellations have been scheduled.
            for task in asyncio.all_tasks(loop):
                if not task.done():
                    task.cancel()
            loop.call_soon(loop.stop)

        loop.call_soon_threadsafe(cancel_all_tasks)

        if thread is not None and thread.is_alive():
            thread.join()

        # A stopped loop still owns its selector and wake-up descriptors;
        # close it so repeated start/stop cycles do not leak them.
        if not loop.is_running():
            loop.close()

        self.loop = None
        self.loop_thread = None

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.stop()
        # Never suppress an exception raised inside the ``with`` body.
        return False
|
||||||
@ -46,6 +46,7 @@ from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
|
|||||||
from .guided_decoder import GuidedDecoder
|
from .guided_decoder import GuidedDecoder
|
||||||
from .handle_additional_outputs import HandleAdditionalOutputs
|
from .handle_additional_outputs import HandleAdditionalOutputs
|
||||||
from .handle_logits import HandleLogits
|
from .handle_logits import HandleLogits
|
||||||
|
from .hang_detector import HangDetector
|
||||||
from .kv_cache_connector import KvCacheConnectorManager
|
from .kv_cache_connector import KvCacheConnectorManager
|
||||||
from .kv_cache_transceiver import KvCacheTransceiver
|
from .kv_cache_transceiver import KvCacheTransceiver
|
||||||
from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
|
from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
|
||||||
@ -137,6 +138,7 @@ class PyExecutor:
|
|||||||
max_seq_len: Optional[int] = None,
|
max_seq_len: Optional[int] = None,
|
||||||
peft_cache_config: Optional[PeftCacheConfig] = None,
|
peft_cache_config: Optional[PeftCacheConfig] = None,
|
||||||
virtual_memory_pools: Optional[dict] = None,
|
virtual_memory_pools: Optional[dict] = None,
|
||||||
|
hang_detection_timeout: Optional[int] = None,
|
||||||
execution_stream: Optional[torch.cuda.Stream] = None):
|
execution_stream: Optional[torch.cuda.Stream] = None):
|
||||||
super(PyExecutor, self).__init__()
|
super(PyExecutor, self).__init__()
|
||||||
self.device_id = torch.cuda.current_device()
|
self.device_id = torch.cuda.current_device()
|
||||||
@ -280,6 +282,15 @@ class PyExecutor:
|
|||||||
self.adp_ctx_batching_wait_iters_count = 0
|
self.adp_ctx_batching_wait_iters_count = 0
|
||||||
self.batch_wait_iters_count = 0
|
self.batch_wait_iters_count = 0
|
||||||
|
|
||||||
|
def on_detected():
|
||||||
|
self._handle_errors(
|
||||||
|
f"Hang detected on rank {self.global_rank} in PyExecutor.")
|
||||||
|
self.shutdown_event.set()
|
||||||
|
self.is_shutdown = True
|
||||||
|
|
||||||
|
self.hang_detector = HangDetector(timeout=hang_detection_timeout,
|
||||||
|
on_detected=on_detected)
|
||||||
|
|
||||||
# request fetcher initialization
|
# request fetcher initialization
|
||||||
self._set_global_steady_clock_offset()
|
self._set_global_steady_clock_offset()
|
||||||
self.executor_request_queue = ExecutorRequestQueue(
|
self.executor_request_queue = ExecutorRequestQueue(
|
||||||
@ -290,6 +301,7 @@ class PyExecutor:
|
|||||||
max_num_active_requests=self.max_num_active_requests,
|
max_num_active_requests=self.max_num_active_requests,
|
||||||
enable_iter_perf_stats=self.enable_iter_perf_stats,
|
enable_iter_perf_stats=self.enable_iter_perf_stats,
|
||||||
batch_wait_timeout_ms=self.batch_wait_timeout_ms,
|
batch_wait_timeout_ms=self.batch_wait_timeout_ms,
|
||||||
|
hang_detector=self.hang_detector,
|
||||||
)
|
)
|
||||||
self.executor_request_queue.set_exclude_last_generation_logits(
|
self.executor_request_queue.set_exclude_last_generation_logits(
|
||||||
self.disable_overlap_scheduler, self.dist.pp_size)
|
self.disable_overlap_scheduler, self.dist.pp_size)
|
||||||
@ -476,6 +488,14 @@ class PyExecutor:
|
|||||||
"""
|
"""
|
||||||
self.executor_request_queue.enqueue_shutdown_request()
|
self.executor_request_queue.enqueue_shutdown_request()
|
||||||
self.shutdown_event.wait()
|
self.shutdown_event.wait()
|
||||||
|
if self.hang_detector.detected():
|
||||||
|
# Early return here to avoid waiting for hanging threads.
|
||||||
|
# Since `on_detected` has sent the error message as response,
|
||||||
|
# this worker will be asked to shutdown immediately.
|
||||||
|
# Since the whole process will shutdown after this `shutdown` call,
|
||||||
|
# All threads and memory pools will be freed properly.
|
||||||
|
logger.error("Hang detected, shutting down immediately.")
|
||||||
|
return
|
||||||
self.worker_thread.join()
|
self.worker_thread.join()
|
||||||
self.worker_started = False
|
self.worker_started = False
|
||||||
for manager in self.resource_manager.resource_managers.values():
|
for manager in self.resource_manager.resource_managers.values():
|
||||||
@ -960,10 +980,11 @@ class PyExecutor:
|
|||||||
# ensure the context is created, otherwise, some MPI calls will fail.
|
# ensure the context is created, otherwise, some MPI calls will fail.
|
||||||
CUASSERT(cudart.cudaSetDevice(self.device_id))
|
CUASSERT(cudart.cudaSetDevice(self.device_id))
|
||||||
microbatch_id = 0
|
microbatch_id = 0
|
||||||
with self._profiler() as profile_step:
|
with self._profiler() as profile_step, self.hang_detector:
|
||||||
iter_start_time = time.time()
|
iter_start_time = time.time()
|
||||||
iter_stats = None
|
iter_stats = None
|
||||||
while True:
|
while True:
|
||||||
|
self.hang_detector.checkpoint()
|
||||||
profile_step()
|
profile_step()
|
||||||
if self.enable_iter_perf_stats:
|
if self.enable_iter_perf_stats:
|
||||||
iter_start_time = time.time()
|
iter_start_time = time.time()
|
||||||
@ -1349,11 +1370,12 @@ class PyExecutor:
|
|||||||
torch.cuda.set_device(self.device_id)
|
torch.cuda.set_device(self.device_id)
|
||||||
# ensure the context is created, otherwise, some MPI calls will fail.
|
# ensure the context is created, otherwise, some MPI calls will fail.
|
||||||
CUASSERT(cudart.cudaSetDevice(self.device_id))
|
CUASSERT(cudart.cudaSetDevice(self.device_id))
|
||||||
with self._profiler() as profile_step:
|
with self._profiler() as profile_step, self.hang_detector:
|
||||||
sample_state = None
|
sample_state = None
|
||||||
iter_start_time = time.time()
|
iter_start_time = time.time()
|
||||||
iter_stats = None
|
iter_stats = None
|
||||||
while True:
|
while True:
|
||||||
|
self.hang_detector.checkpoint()
|
||||||
profile_step()
|
profile_step()
|
||||||
if self.enable_iter_perf_stats:
|
if self.enable_iter_perf_stats:
|
||||||
iter_start_time = time.time()
|
iter_start_time = time.time()
|
||||||
@ -1551,13 +1573,14 @@ class PyExecutor:
|
|||||||
torch.cuda.set_device(self.device_id)
|
torch.cuda.set_device(self.device_id)
|
||||||
# ensure the context is created, otherwise, some MPI calls will fail.
|
# ensure the context is created, otherwise, some MPI calls will fail.
|
||||||
CUASSERT(cudart.cudaSetDevice(self.device_id))
|
CUASSERT(cudart.cudaSetDevice(self.device_id))
|
||||||
with self._profiler() as profile_step:
|
with self._profiler() as profile_step, self.hang_detector:
|
||||||
iter_start_time = time.time()
|
iter_start_time = time.time()
|
||||||
iter_stats = None
|
iter_stats = None
|
||||||
target_inputs = None
|
target_inputs = None
|
||||||
previous_tensors_device = None
|
previous_tensors_device = None
|
||||||
can_forward = False if self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver else True
|
can_forward = False if self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver else True
|
||||||
while True:
|
while True:
|
||||||
|
self.hang_detector.checkpoint()
|
||||||
profile_step()
|
profile_step()
|
||||||
if self.enable_iter_perf_stats:
|
if self.enable_iter_perf_stats:
|
||||||
iter_start_time = time.time()
|
iter_start_time = time.time()
|
||||||
|
|||||||
@ -21,8 +21,10 @@ import math
|
|||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
import struct
|
import struct
|
||||||
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import trace
|
import trace
|
||||||
|
import traceback
|
||||||
import weakref
|
import weakref
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from enum import EnumMeta
|
from enum import EnumMeta
|
||||||
@ -761,6 +763,13 @@ def is_sm_100f(sm_version=None):
|
|||||||
return sm_version == 100 or sm_version == 103
|
return sm_version == 100 or sm_version == 103
|
||||||
|
|
||||||
|
|
||||||
|
def print_all_stacks():
    """Log the current call stack of every thread at error level."""
    frames_by_thread = sys._current_frames()
    for thread_id, frame in frames_by_thread.items():
        stack_text = "".join(traceback.format_stack(frame))
        logger.error(f"Thread {thread_id} stack trace:\n" + stack_text)
|
||||||
|
|
||||||
|
|
||||||
def is_trace_enabled(env_var: str):
|
def is_trace_enabled(env_var: str):
|
||||||
value = os.environ.get(env_var, "-1")
|
value = os.environ.get(env_var, "-1")
|
||||||
if value == "ALL":
|
if value == "ALL":
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
import gc
|
import gc
|
||||||
import os
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from concurrent.futures import ProcessPoolExecutor
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -9,7 +11,7 @@ import zmq
|
|||||||
|
|
||||||
from tensorrt_llm.logger import logger
|
from tensorrt_llm.logger import logger
|
||||||
|
|
||||||
from .._utils import mpi_comm, mpi_rank
|
from .._utils import mpi_comm, mpi_rank, print_all_stacks
|
||||||
from ..bindings import executor as tllm
|
from ..bindings import executor as tllm
|
||||||
from ..builder import Engine
|
from ..builder import Engine
|
||||||
from ..llmapi.llm_args import BaseLlmArgs
|
from ..llmapi.llm_args import BaseLlmArgs
|
||||||
@ -153,6 +155,21 @@ def worker_main(
|
|||||||
hmac_key: Optional[bytes] = None,
|
hmac_key: Optional[bytes] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
|
def _print_stacks():
|
||||||
|
counter = 0
|
||||||
|
while True:
|
||||||
|
time.sleep(print_stacks_period)
|
||||||
|
counter += 1
|
||||||
|
logger.error(f"Printing stacks {counter} times")
|
||||||
|
print_all_stacks()
|
||||||
|
|
||||||
|
print_stacks_period = int(
|
||||||
|
os.getenv("TRTLLM_WORKER_PRINT_STACKS_PERIOD", "-1"))
|
||||||
|
if print_stacks_period > 0:
|
||||||
|
print_stacks_thread = threading.Thread(target=_print_stacks,
|
||||||
|
daemon=True)
|
||||||
|
print_stacks_thread.start()
|
||||||
|
|
||||||
mpi_comm().barrier()
|
mpi_comm().barrier()
|
||||||
|
|
||||||
if llm_args is not None and llm_args.env_overrides:
|
if llm_args is not None and llm_args.env_overrides:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user