mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
185 lines
5.8 KiB
Python
185 lines
5.8 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import time
|
|
from contextlib import contextmanager
|
|
from typing import Any, Callable, Optional
|
|
|
|
|
|
class TimeoutManager:
    """
    A utility class for managing timeout in test cases.

    This class helps reduce boilerplate code for timeout handling in test cases
    by providing a simple interface to track remaining time and execute operations
    with automatic timeout checking.

    Only time spent inside ``timed_operation`` (and the ``execute_with_timeout`` /
    ``call_with_timeout`` helpers built on it) is deducted from the budget; time
    elapsed between tracked operations is not counted.
    """

    def __init__(self, initial_timeout: Optional[float] = None):
        """
        Initialize the timeout manager.

        Args:
            initial_timeout: Initial timeout value in seconds. If None, no timeout is enforced.
        """
        self._initial_timeout = initial_timeout
        self._remaining_timeout = initial_timeout
        # Not read anywhere in this class; kept so existing external readers
        # (if any -- verify before removing) keep working.
        self._start_time = None

    @property
    def remaining_timeout(self) -> Optional[float]:
        """
        Get the remaining timeout value in seconds.

        None means no timeout is enforced. The value may be zero or negative once
        the budget has been used up.
        """
        return self._remaining_timeout

    def reset(self, timeout: Optional[float] = None) -> None:
        """
        Reset the timeout manager with a new timeout value.

        Args:
            timeout: New timeout value in seconds. If None, uses the initial timeout.
        """
        self._remaining_timeout = timeout if timeout is not None else self._initial_timeout
        self._start_time = None

    def check_timeout(self, phase_name: str = "operation") -> None:
        """
        Check if timeout has been exceeded and raise TimeoutError if so.

        A remaining budget of exactly zero counts as exceeded. Nothing is checked
        when no timeout is enforced.

        Args:
            phase_name: Name of the current phase for error message.

        Raises:
            TimeoutError: If timeout has been exceeded.
        """
        if self._remaining_timeout is not None and self._remaining_timeout <= 0:
            raise TimeoutError(f"Timeout exceeded after {phase_name} phase!")

    @contextmanager
    def timed_operation(self, phase_name: str = "operation"):
        """
        Context manager for timing an operation and updating remaining timeout.

        The time spent in the ``with`` body is always deducted from the budget,
        even if the body raises. The budget is checked only when the body
        completes normally, so an exception raised by the operation itself
        propagates unchanged instead of being masked by a TimeoutError.

        Args:
            phase_name: Name of the phase for timeout checking.

        Yields:
            None

        Raises:
            TimeoutError: If the operation completes but the timeout has been exceeded.
        """
        if self._remaining_timeout is None:
            # No timeout enforcement
            yield
            return

        # Monotonic clock: time.time() follows the wall clock and can jump
        # (NTP sync, manual changes), yielding negative or huge durations.
        start_time = time.monotonic()
        try:
            yield
        finally:
            # Charge elapsed time even on failure so later phases see an
            # accurate budget.
            self._remaining_timeout -= time.monotonic() - start_time
        # Deliberately outside ``finally``: raising there would replace the
        # operation's own exception with a TimeoutError.
        self.check_timeout(phase_name)

    def execute_with_timeout(self,
                             operation: Callable[..., Any],
                             phase_name: str = "operation",
                             **kwargs) -> Any:
        """
        Execute an operation with timeout tracking.

        Args:
            operation: The operation to execute; called with ``**kwargs``.
            phase_name: Name of the phase for timeout checking.
            **kwargs: Additional keyword arguments to pass to the operation.

        Returns:
            The result of the operation.

        Raises:
            TimeoutError: If timeout is exceeded after the operation.
        """
        with self.timed_operation(phase_name):
            return operation(**kwargs)

    def call_with_timeout(self,
                          func: Callable,
                          *args,
                          phase_name: str = "operation",
                          **kwargs) -> Any:
        """
        Call a function with timeout tracking.

        Args:
            func: The function to call.
            *args: Positional arguments for the function.
            phase_name: Name of the phase for timeout checking (keyword-only).
            **kwargs: Keyword arguments for the function.

        Returns:
            The result of the function call.

        Raises:
            TimeoutError: If timeout is exceeded after the function call.
        """
        with self.timed_operation(phase_name):
            return func(*args, **kwargs)
|
|
|
|
|
def create_timeout_manager(
        timeout_from_marker: Optional[float] = None) -> TimeoutManager:
    """
    Build a TimeoutManager seeded with the timeout taken from a pytest marker.

    Args:
        timeout_from_marker: Timeout budget in seconds, or None to disable
            timeout enforcement.

    Returns:
        A freshly constructed TimeoutManager using that budget.
    """
    manager = TimeoutManager(initial_timeout=timeout_from_marker)
    return manager
|
|
|
|
|
|
# Convenience decorator for test functions
|
|
def with_timeout_management(func: Callable) -> Callable:
    """
    Decorate a test function so it receives a ready-made ``timeout_manager``.

    The wrapper reads the ``timeout_from_marker`` keyword argument (None when it
    is absent) and builds a TimeoutManager from it. It then calls the decorated
    function with that manager added as the ``timeout_manager`` keyword argument.
    Any ``timeout_manager`` the caller passed is overwritten. Positional
    arguments are forwarded untouched.

    Args:
        func: The test function to decorate. It must accept a
            ``timeout_manager`` keyword argument.

    Returns:
        The wrapping function, with ``func``'s metadata copied onto it.
    """
    import functools

    @functools.wraps(func)
    def _inject_timeout_manager(*args, **kwargs):
        # Only the keyword form of timeout_from_marker is consulted; a missing
        # value yields a manager that enforces no timeout.
        kwargs['timeout_manager'] = create_timeout_manager(
            kwargs.get('timeout_from_marker'))
        return func(*args, **kwargs)

    return _inject_timeout_manager
|