TensorRT-LLMs/tests/integration/defs/utils/timeout_manager.py
Ivy Zhang d101a6cebc
[https://nvbugs/5410279][test] resubmit timeout refactor (#6337)
Signed-off-by: Ivy Zhang <25222398+crazydemo@users.noreply.github.com>
2025-08-05 16:39:25 +08:00

185 lines
5.8 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from contextlib import contextmanager
from typing import Any, Callable, Optional
class TimeoutManager:
    """
    A utility class for managing timeout in test cases.

    This class helps reduce boilerplate code for timeout handling in test cases
    by providing a simple interface to track remaining time and execute operations
    with automatic timeout checking.

    Elapsed time is measured with ``time.monotonic()``, so wall-clock
    adjustments (NTP sync, manual clock changes) cannot distort the budget.
    """

    def __init__(self, initial_timeout: Optional[float] = None):
        """
        Initialize the timeout manager.

        Args:
            initial_timeout: Initial timeout value in seconds. If None, no timeout is enforced.
        """
        self._initial_timeout = initial_timeout
        self._remaining_timeout = initial_timeout
        # Not read by any method of this class; kept so that external code
        # referring to the attribute keeps working.
        self._start_time = None

    @property
    def remaining_timeout(self) -> Optional[float]:
        """Get the remaining timeout value in seconds (None when unbounded)."""
        return self._remaining_timeout

    def reset(self, timeout: Optional[float] = None) -> None:
        """
        Reset the timeout manager with a new timeout value.

        Args:
            timeout: New timeout value in seconds. If None, uses the initial timeout.
        """
        self._remaining_timeout = (timeout if timeout is not None else
                                   self._initial_timeout)
        self._start_time = None

    def check_timeout(self, phase_name: str = "operation") -> None:
        """
        Check if timeout has been exceeded and raise TimeoutError if so.

        A budget of exactly zero counts as exceeded. No check is made when
        the manager has no timeout (remaining timeout is None).

        Args:
            phase_name: Name of the current phase for error message.

        Raises:
            TimeoutError: If timeout has been exceeded.
        """
        if self._remaining_timeout is not None and self._remaining_timeout <= 0:
            raise TimeoutError(f"Timeout exceeded after {phase_name} phase!")

    @contextmanager
    def timed_operation(self, phase_name: str = "operation"):
        """
        Context manager for timing an operation and updating remaining timeout.

        The time spent inside the ``with`` body is always subtracted from the
        remaining budget, whether the body succeeds or raises.

        Args:
            phase_name: Name of the phase for timeout checking.

        Yields:
            None

        Raises:
            TimeoutError: If the budget is exhausted after the operation
                completes, or after it fails with an ordinary ``Exception``.
                In the failure case the original exception is kept as
                ``__context__``. ``BaseException``-only signals such as
                ``KeyboardInterrupt`` or ``SystemExit`` are re-raised unchanged.
        """
        if self._remaining_timeout is None:
            # No timeout enforcement
            yield
            return

        # Monotonic clock: immune to system clock adjustments during the op.
        start_time = time.monotonic()
        try:
            yield
        except BaseException as exc:
            self._remaining_timeout -= time.monotonic() - start_time
            # Ordinary failures after the budget is spent are still reported
            # as TimeoutError (historical behaviour). Interpreter-level signals
            # must not be disguised as a timeout, so only check for Exception.
            if isinstance(exc, Exception):
                self.check_timeout(phase_name)
            raise
        self._remaining_timeout -= time.monotonic() - start_time
        self.check_timeout(phase_name)

    def execute_with_timeout(self,
                             operation: Callable[[], Any],
                             phase_name: str = "operation",
                             **kwargs) -> Any:
        """
        Execute an operation with timeout tracking.

        Args:
            operation: The operation to execute.
            phase_name: Name of the phase for timeout checking.
            **kwargs: Additional arguments to pass to the operation.

        Returns:
            The result of the operation.

        Raises:
            TimeoutError: If timeout is exceeded after the operation.
        """
        with self.timed_operation(phase_name):
            return operation(**kwargs)

    def call_with_timeout(self,
                          func: Callable,
                          *args,
                          phase_name: str = "operation",
                          **kwargs) -> Any:
        """
        Call a function with timeout tracking.

        Note that ``phase_name`` is consumed here and is never forwarded to
        ``func``.

        Args:
            func: The function to call.
            *args: Positional arguments for the function.
            phase_name: Name of the phase for timeout checking.
            **kwargs: Keyword arguments for the function.

        Returns:
            The result of the function call.

        Raises:
            TimeoutError: If timeout is exceeded after the function call.
        """
        with self.timed_operation(phase_name):
            return func(*args, **kwargs)
def create_timeout_manager(
        timeout_from_marker: Optional[float] = None) -> TimeoutManager:
    """
    Build a TimeoutManager from a pytest timeout-marker value.

    Args:
        timeout_from_marker: Timeout in seconds taken from the pytest marker,
            or None to disable timeout enforcement.

    Returns:
        A new TimeoutManager whose initial budget is ``timeout_from_marker``.
    """
    manager = TimeoutManager(initial_timeout=timeout_from_marker)
    return manager
# Convenience decorator for test functions
def with_timeout_management(func: Callable) -> Callable:
    """
    Decorate a test function so it receives a ready-made TimeoutManager.

    On each call the wrapper does three things:

    1. Reads the ``timeout_from_marker`` keyword argument (None when absent).
    2. Builds a manager from it via ``create_timeout_manager``.
    3. Passes that manager to ``func`` as the ``timeout_manager`` keyword
       argument, overwriting any value the caller supplied.

    Only keyword arguments are inspected: a ``timeout_from_marker`` passed
    positionally is not seen.

    Args:
        func: The test function to decorate.

    Returns:
        The decorated function, carrying ``func``'s metadata via
        ``functools.wraps``.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        # The right-hand side is evaluated before the store, so the marker
        # value is read before ``timeout_manager`` is (re)assigned.
        kwargs['timeout_manager'] = create_timeout_manager(
            kwargs.get('timeout_from_marker'))
        return func(*args, **kwargs)

    return wrapper