# tensorrt_llm/llmapi/utils.py
import asyncio
import collections
import ctypes
import datetime
import hashlib
import inspect
import io
import math
import os
import re
import sys
import tempfile
import threading
import time
import traceback
import warnings
import weakref
from functools import cache, wraps
from pathlib import Path
from queue import Queue
from typing import (Any, Callable, Iterable, List, Optional, Tuple, Type,
get_type_hints)
import filelock
import huggingface_hub
import psutil
import torch
from huggingface_hub import snapshot_download
from pydantic import BaseModel
from tqdm.auto import tqdm
from tensorrt_llm.logger import Singleton, logger
def print_traceback_on_error(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
logger_debug(f"Exception in {func.__name__}: {e}\n", "red")
traceback.print_exc()
            raise
return wrapper
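
# Illustrative sketch of the decorator above (the `flaky_step` name is
# hypothetical): the wrapper logs the exception, prints the traceback, and
# re-raises so callers still see the original error.
def _example_print_traceback_on_error():
    @print_traceback_on_error
    def flaky_step(x: int) -> int:
        return 10 // x
    try:
        flaky_step(0)
    except ZeroDivisionError:
        pass  # re-raised after logging
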
def print_colored(message,
color: Optional[str] = None,
writer: io.TextIOWrapper = sys.stderr):
colors = dict(
grey="\x1b[38;20m",
yellow="\x1b[33;20m",
red="\x1b[31;20m",
bold_red="\x1b[31;1m",
bold_green="\033[1;32m",
green="\033[0;32m",
cyan="\033[0;36m",
)
reset = "\x1b[0m"
if color:
writer.write(colors[color] + message + reset)
else:
writer.write(message)
def get_current_location(skip_frames: int = 2) -> str:
"""
Get the current execution location in format 'module.class.function'.
Args:
skip_frames: Number of stack frames to skip (default 2 to skip this function and its caller)
Returns:
String in format 'module.class.function' or 'module.function' if not in a class
"""
stack = inspect.stack()
if len(stack) <= skip_frames:
return "unknown"
frame = stack[skip_frames]
module_name = frame.frame.f_globals.get('__name__', 'unknown')
function_name = frame.function
# Try to determine if we're in a class method
class_name = None
if 'self' in frame.frame.f_locals:
# This is likely an instance method
obj = frame.frame.f_locals['self']
class_name = obj.__class__.__name__
elif 'cls' in frame.frame.f_locals:
# This might be a class method
cls = frame.frame.f_locals['cls']
if inspect.isclass(cls):
class_name = cls.__name__
# Build the location string
if class_name:
return f"{module_name}.{class_name}.{function_name}"
else:
return f"{module_name}.{function_name}"
def logger_debug(message,
color: Optional[str] = None,
writer: io.TextIOWrapper = sys.stderr):
""" Print the message if the llmapi debug mode is enabled. Fallback to logger.debug if not. """
if enable_llmapi_debug():
timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
location = get_current_location()
cur_dualname = "..." + location[-47:] if len(
location) > 50 else location
print_colored(f"{timestamp} [{cur_dualname}]", "bold_green", writer)
print_colored(f" {message}\n", color, writer)
writer.flush()
else:
# Fallback to logger.debug
logger.debug(message)
def file_with_glob_exists(directory, glob) -> bool:
path = Path(directory)
for file_path in path.glob(glob):
if file_path.is_file():
return True
return False
def file_with_suffix_exists(directory, suffix) -> bool:
return file_with_glob_exists(directory, f'*{suffix}')
def get_device_count() -> int:
return torch.cuda.device_count() if torch.cuda.is_available() else 0
def get_total_gpu_memory(device: int) -> float:
return torch.cuda.get_device_properties(device).total_memory
class GpuArch:
@staticmethod
def get_arch() -> int:
return get_gpu_arch()
@staticmethod
def is_post_hopper() -> bool:
return get_gpu_arch() >= 9
@staticmethod
def is_post_ampere() -> bool:
return get_gpu_arch() >= 8
def get_gpu_arch(device: int = 0) -> int:
return torch.cuda.get_device_properties(device).major
class ContextManager:
''' A helper to create a context manager for a resource. '''
def __init__(self, resource):
self.resource = resource
def __enter__(self):
return self.resource.__enter__()
def __exit__(self, exc_type, exc_value, traceback):
return self.resource.__exit__(exc_type, exc_value, traceback)
def is_directory_empty(directory: Path) -> bool:
return not any(directory.iterdir())
class ExceptionHandler(metaclass=Singleton):
def __init__(self):
self._sys_excepthook: Callable = sys.excepthook
self._obj_refs_and_callbacks: List[Tuple[weakref.ReferenceType,
str]] = []
def __call__(self, exc_type, exc_value, traceback):
self._sys_excepthook(exc_type, exc_value, traceback)
for obj_ref, callback_name in self._obj_refs_and_callbacks:
if (obj := obj_ref()) is not None:
callback = getattr(obj, callback_name)
callback()
def register(self, obj: Any, callback_name: str):
assert callable(getattr(obj, callback_name, None))
self._obj_refs_and_callbacks.append((weakref.ref(obj), callback_name))
exception_handler = ExceptionHandler()
sys.excepthook = exception_handler
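
# Illustrative sketch: an object registers a named cleanup method; when an
# uncaught exception reaches sys.excepthook, callbacks of still-alive
# registrants are invoked (the `Server` name is hypothetical).
def _example_exception_handler_register():
    class Server:
        def shutdown(self):
            print_colored("shutting down after uncaught exception\n", "yellow")
    server = Server()
    # Only a weak reference is stored, so registration does not prolong the
    # object's lifetime; expired references are skipped in __call__.
    exception_handler.register(server, "shutdown")
    return server
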
# Use the system temporary directory to share the cache
temp_dir = tempfile.gettempdir()
def get_file_lock(model_name: str,
cache_dir: Optional[str] = None) -> filelock.FileLock:
# Hash the model name to avoid invalid characters in the lock file path
hashed_model_name = hashlib.sha256(model_name.encode()).hexdigest()
cache_dir = cache_dir or temp_dir
os.makedirs(cache_dir, exist_ok=True)
lock_file_path = os.path.join(cache_dir, f"{hashed_model_name}.lock")
return filelock.FileLock(lock_file_path)
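
# Illustrative sketch: serialize concurrent downloads of the same model across
# processes via the hashed lock file (the model name is just an example).
def _example_get_file_lock():
    with get_file_lock("meta-llama/Llama-3.1-8B"):
        # Only one process at a time enters this block for this model name;
        # others block on <tempdir>/<sha256-of-name>.lock until it is released.
        pass
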
class DisabledTqdm(tqdm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, disable=True)
def download_hf_model(model: str, revision: Optional[str] = None) -> Path:
ignore_patterns = ["original/**/*"]
with get_file_lock(model):
hf_folder = snapshot_download(
model,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
ignore_patterns=ignore_patterns,
revision=revision,
tqdm_class=DisabledTqdm)
return Path(hf_folder)
def download_hf_pretrained_config(model: str,
revision: Optional[str] = None) -> Path:
with get_file_lock(model):
hf_folder = snapshot_download(
model,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
revision=revision,
allow_patterns=["config.json"],
tqdm_class=DisabledTqdm)
return Path(hf_folder)
def append_docstring(docstring: str):
''' A decorator to append a docstring to a function. '''
def decorator(fn):
fn.__doc__ = (fn.__doc__ or '') + docstring
return fn
return decorator
def set_docstring(docstring: str):
    ''' A decorator to set the docstring of a function. '''
def decorator(fn):
fn.__doc__ = docstring
return fn
return decorator
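
# Illustrative sketch: append_docstring extends an existing docstring, while
# set_docstring overwrites it (the two decorated functions are hypothetical).
def _example_docstring_decorators():
    @append_docstring("\nExtra notes appended by a decorator.")
    def documented():
        """Original docstring."""
    @set_docstring("Replaced docstring.")
    def replaced():
        """This text is discarded."""
    assert documented.__doc__.endswith("appended by a decorator.")
    assert replaced.__doc__ == "Replaced docstring."
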
def get_directory_size_in_gb(directory: Path) -> float:
""" Get the size of the directory. """
if not (directory.is_dir() and directory.exists()):
raise ValueError(f"{directory} is not a directory.")
total_size = 0
for dirpath, dirnames, filenames in os.walk(directory):
for f in filenames:
fp = os.path.join(dirpath, f)
total_size += os.path.getsize(fp)
return total_size / 1024**3 # GB
class ManagedThread(threading.Thread):
""" A thread that will put exceptions into an external queue if the task fails.
There are two approaches to stop the thread:
1. Set stop_event to stop the loop
2. Let `task` return False
Args:
task (Callable[..., bool]): The task to run repeatedly in the thread, should return False if break the loop.
error_queue (Queue): The queue to put exceptions into if the task fails.
name (str): The name of the thread.
**kwargs: The arguments to pass to the task
"""
def __init__(self,
task: Callable[..., bool],
error_queue: Queue,
name: Optional[str] = None,
stop_event: Optional[threading.Event] = None,
**kwargs):
super().__init__(name=name)
self.task = task
self.error_queue = error_queue
self.kwargs = kwargs
self.daemon = True
self.stop_event = stop_event or threading.Event()
def run(self):
while not self.stop_event.is_set():
task = self.task
if isinstance(task, weakref.WeakMethod):
task = task()
if task is None:
# Normally, this should not happen.
logger.warning("WeakMethod is expired.")
break
try:
if not task(**self.kwargs):
break
except Exception as e:
logger.error(
f"Error in thread {self.name}: {e}\n{traceback.format_exc()}"
)
self.error_queue.put(e)
logger.info(f"Thread {self.name} stopped.")
def stop(self):
self.stop_event.set()
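
# Illustrative sketch of one of the two stop paths above: the task returning
# False ends the loop (setting `stop_event` is the other path). Names below
# are hypothetical.
def _example_managed_thread():
    errors: Queue = Queue()
    remaining = {"ticks": 3}
    def tick() -> bool:
        remaining["ticks"] -= 1
        return remaining["ticks"] > 0  # returning False stops the thread
    thread = ManagedThread(tick, error_queue=errors, name="ticker")
    thread.start()
    thread.join(timeout=5)
    # Exceptions raised by the task would be retrievable via errors.get_nowait().
    assert errors.empty()
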
_enable_llm_debug_ = None
def enable_llm_debug() -> bool:
    ''' Tell whether to enable the debug mode for the LLM class. '''
global _enable_llm_debug_
if _enable_llm_debug_ is None:
_enable_llm_debug_ = os.environ.get("TLLM_LLM_ENABLE_DEBUG", "0") == "1"
return _enable_llm_debug_
_enable_llmapi_debug_ = None
def enable_llmapi_debug() -> bool:
global _enable_llmapi_debug_
if _enable_llmapi_debug_ is None:
_enable_llmapi_debug_ = os.environ.get("TLLM_LLMAPI_ENABLE_DEBUG",
"0") == "1"
return _enable_llmapi_debug_
@cache
def enable_worker_single_process_for_tp1() -> bool:
    ''' Tell whether the worker should use a single process for TP1.
    This is helpful for return-logits performance and debugging. '''
return os.environ.get("TLLM_WORKER_USE_SINGLE_PROCESS", "0") == "1"
class AsyncQueue:
"""
A queue-style container that provides both sync and async interface.
This is used to provide a compatible interface for janus.Queue.
"""
class EventLoopShutdownError(Exception):
pass
class MixedSyncAsyncAPIError(Exception):
pass
def __init__(self):
self._q = collections.deque()
self._event = asyncio.Event()
self._tainted = False
self._sync_q = _SyncQueue(self)
@property
def sync_q(self):
return self._sync_q
def full(self) -> bool:
return len(self._q) == self._q.maxlen
def empty(self) -> bool:
return not self._q
def put(self, item) -> None:
self._q.append(item)
self._event.set()
    # Decoupled put and notify:
    # the deque is thread-safe, so we can put from outside the event loop, but
    # notify has to be scheduled in the event loop because the event is not
    # thread-safe. The notify may therefore run late, after the corresponding
    # data in the deque has already been consumed; avoid firing the event in
    # that case.
def put_nowait(self, item) -> None:
self._q.append(item)
def notify(self) -> None:
if self._q:
self._event.set()
def unsafe_get(self):
        # An unsafe get taints the queue, rendering it unusable by the async methods.
self._tainted = True
# Use exception to detect empty. Pre-check is not thread safe.
try:
return self._q.popleft()
except IndexError:
raise asyncio.QueueEmpty() from None
async def get(self, timeout=None):
if self._tainted:
raise AsyncQueue.MixedSyncAsyncAPIError()
# Blocking path: timeout is None (wait indefinitely)
if timeout is None:
            # Wait indefinitely until the queue is non-empty.
            # It is necessary to re-check emptiness after waking: multiple
            # waiting coroutines may be awakened simultaneously when a new item
            # enters the empty queue, and all of them would try to pop that one
            # item, raising IndexError.
while not self._q:
await self._event.wait()
# Blocking path: timeout > 0 (timed wait, retry with remaining time).
elif timeout > 0:
# Compute the deadline; if the queue is still empty after waking, continue waiting for the remaining time.
loop = asyncio.get_running_loop()
deadline = loop.time() + timeout
while not self._q:
remaining = deadline - loop.time()
if remaining <= 0:
raise asyncio.TimeoutError()
# This may raise asyncio.TimeoutError.
await asyncio.wait_for(self._event.wait(), timeout=remaining)
# Non-blocking path: timeout <= 0.
elif not self._q:
raise asyncio.QueueEmpty()
res = self._q.popleft()
if not self._q:
self._event.clear()
return res
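
# Illustrative sketch: a producer thread feeds AsyncQueue.sync_q while a
# coroutine consumes via the async get(); the notification is marshalled onto
# the consumer's event loop by sync_q.put().
async def _example_async_queue():
    aq = AsyncQueue()
    loop = asyncio.get_running_loop()
    def producer():
        aq.sync_q.put("item")  # thread-safe append + scheduled notify
    await loop.run_in_executor(None, producer)
    return await aq.get(timeout=5)
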
class _SyncQueue:
"""
A simplified Queue that provides a `put` method that is compatible with the asyncio event loop.
"""
def __init__(self,
queue: "AsyncQueue",
loop: Optional[asyncio.AbstractEventLoop] = None):
self._aq = queue
self._loop = loop or asyncio.get_event_loop()
async def _notify(self):
self._aq.notify()
def put(self, item) -> None:
self._aq.put_nowait(item)
if self._loop.is_running():
asyncio.run_coroutine_threadsafe(self._notify(), self._loop)
else:
raise AsyncQueue.EventLoopShutdownError()
def put_nowait(self, item) -> None:
""" Put item without notify the event. """
self._aq.put_nowait(item)
# Notify many queues in one coroutine, to cut down context switch overhead.
@staticmethod
async def _notify_many(queues: Iterable["_SyncQueue"]):
for queue in queues:
queue._aq.notify()
@staticmethod
def notify_many(loop: asyncio.AbstractEventLoop,
queues: List["_SyncQueue"]) -> None:
""" Notify the events in the loop. """
if loop.is_running():
asyncio.run_coroutine_threadsafe(
_SyncQueue._notify_many(frozenset(queues)), loop)
else:
raise AsyncQueue.EventLoopShutdownError()
@property
def loop(self) -> asyncio.AbstractEventLoop:
return self._loop
def full(self) -> bool:
return self._aq.full()
def get(self, timeout=None):
        # This is a WAR for the Jupyter scenario, where trt-llm detects an
        # existing event loop. That loop is launched by Jupyter rather than
        # trt-llm, so GenerationResult gets initialized with an AsyncQueue and
        # this sync get() is called unintentionally.
        warnings.warn(
            "LLM API is running in async mode because you have a running event loop,"
            " but you are using the sync API. This may lead to performance loss.")
# We can't call asyncio.run_coroutine_threadsafe(self._aq.get(), self.loop) and wait the returned Future,
# since we are in the same event loop, and we can't yield the thread while waiting result.
deadline = None if timeout is None else time.time() + timeout
while deadline is None or time.time() < deadline:
try:
return self._aq.unsafe_get()
except asyncio.QueueEmpty:
time.sleep(0.01)
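
# Illustrative sketch: batch put_nowait() calls on several sync queues, then
# wake all of their consumers with a single scheduled coroutine via
# notify_many, cutting context-switch overhead.
def _example_notify_many(sync_queues: List["_SyncQueue"], item) -> None:
    for q in sync_queues:
        q.put_nowait(item)  # enqueue without waking the loop yet
    if sync_queues:
        _SyncQueue.notify_many(sync_queues[0].loop, sync_queues)
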
def get_numa_aware_cpu_affinity(device_id):
'''Query NVML for NUMA-aware CPU affinity for the specified CUDA device.
Args:
device_id: The CUDA device ID to query for optimal CPU affinity.
Returns:
List of CPU IDs representing the optimal CPU affinity mask for the device.
Raises:
pynvml.NVMLError: If NVML operations fail or device_id is invalid.
'''
cpu_count = psutil.cpu_count()
    # If this is not a NUMA system, default to unconstrained CPU affinity
    # (NVML failures below propagate to the caller, per the docstring).
cpu_affinity = list(range(cpu_count))
if not os.path.isdir("/sys/devices/system/node/node1"):
return cpu_affinity
try:
# initialize NVML
import pynvml
pynvml.nvmlInit()
# Get the number of bits per ulong
c_ulong_bits = ctypes.sizeof(ctypes.c_ulong) * 8
# Determine how large our cpu set array from NVML needs to be
cpu_set_size = math.ceil(cpu_count / c_ulong_bits)
# Get the optimal CPU affinity for this device according to the NUMA
# topology
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
affinity_masks = pynvml.nvmlDeviceGetCpuAffinity(handle, cpu_set_size)
# Convert CPU masks to python list
cpu_affinity = []
for cpu_id in range(cpu_count):
mask_array_index = cpu_id // c_ulong_bits
mask_bit_index = cpu_id % c_ulong_bits
if affinity_masks[mask_array_index] & (1 << mask_bit_index):
cpu_affinity.append(cpu_id)
finally:
        try:
            pynvml.nvmlShutdown()
        except Exception:
            pass  # Ignore shutdown errors
return cpu_affinity
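
# Illustrative sketch of a plausible consumer of the helper above: pin the
# current process to the NUMA-local CPUs of the GPU it drives (Linux-style
# affinity via psutil; not part of the original module).
def _example_apply_numa_affinity(device_id: int = 0) -> None:
    cpus = get_numa_aware_cpu_affinity(device_id)
    psutil.Process().cpu_affinity(cpus)  # sets, not just queries, the affinity
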
def generate_api_docs_as_docstring(model: Type[BaseModel],
include_annotations=False,
indent="") -> str:
"""
Generates API documentation for a Pydantic BaseModel, formatted as a
Python docstring.
Args:
model: The Pydantic BaseModel class.
Returns:
A string containing the API documentation formatted as a docstring.
"""
docstring_lines = []
if include_annotations:
# Class docstring
if model.__doc__:
docstring_lines.append(model.__doc__.strip())
docstring_lines.append(
"") # Add a blank line after the class docstring
docstring_lines.append(f"{indent}Args:")
schema = model.schema()
type_hints = get_type_hints(model)
type_alias = {
'integer': 'int',
'number': 'float',
'boolean': 'bool',
'string': 'str',
'array': 'list',
}
for field_name, field_info in schema['properties'].items():
if field_name.startswith("_"): # skip private fields
continue
if field_info.get("status", None) == "deprecated":
continue
field_type = field_info.get('type', None)
field_description = field_info.get('description', '')
field_default = field_info.get('default', None)
field_required = field_name in schema.get('required', [])
# Get full type path from type hints if available
if field_type:
type_str = type_alias.get(field_type, field_type)
elif field_name in type_hints:
type_str = str(type_hints[field_name])
type_str = type_str.replace("typing.", "")
# Extract just the class name from full class path
for regex in [r"<class '([^']+)'>", r"<enum '([^']+)'>"]:
if (match := re.match(regex, type_str)) is not None:
type_str = match.group(1)
break
else:
type_str = field_type or 'Any'
# Format the argument documentation with 12 spaces indent for args
arg_line = f"{indent} {field_name} ({type_str}): "
if status := field_info.get("status", None):
arg_line += f":tag:`{status}` "
elif LABEL_STABLE_APIS:
arg_line += f":tag:`stable` "
if field_description:
arg_line += field_description.split('\n')[0] # First line with type
docstring_lines.append(arg_line)
# Add remaining description lines and default value with 16 spaces indent
if field_description and '\n' in field_description:
remaining_lines = field_description.split('\n')[1:]
for line in remaining_lines:
docstring_lines.append(f"{indent} {line}")
if not field_required or field_default is not None:
default_str = str(
field_default) if field_default is not None else "None"
docstring_lines[-1] += f" Defaults to {default_str}."
if include_annotations:
docstring_lines.append("") # Empty line before Returns
return_annotation = "None" # Default to None, adjust if needed
docstring_lines.append(
f"{indent}Returns:\n{indent} {return_annotation}")
return "\n".join(docstring_lines)
def get_type_repr(cls):
"""Handle built-in types gracefully. """
module_name = cls.__module__
if module_name == 'builtins': # Special case for built-in types
return cls.__qualname__
return f"{module_name}.{cls.__qualname__}"
LABEL_STABLE_APIS: bool = True
""" Whether to label the stable APIs with `stable` tags. """
class ApiParamTagger:
    ''' A helper to tag the API doc according to the status of each field.
    The status is set in the field's json_schema_extra.
    '''
def __call__(self, cls: Type[BaseModel]) -> None:
""" The main entry point to tag the api doc. """
if cls.__name__ in ["LlmArgs", "TorchLlmArgs"]:
# TODO: apply this to other classes
self._process_pydantic_model(cls)
def _process_pydantic_model(self, cls: Type[BaseModel]) -> None:
"""Process the Pydantic model to add tags to the fields.
"""
for field_name, field_info in cls.model_fields.items():
if field_info.json_schema_extra and 'status' in field_info.json_schema_extra:
status = field_info.json_schema_extra['status']
self._amend_pydantic_field_description_with_tags(
cls, [field_name], status)
else:
self._amend_pydantic_field_description_with_tags(
cls, [field_name], "stable")
def _amend_pydantic_field_description_with_tags(self, cls: Type[BaseModel],
field_names: list[str],
tag: str) -> None:
"""Amend the description of the fields with tags.
e.g. :tag:`beta` or :tag:`prototype`
Args:
cls: The Pydantic BaseModel class.
field_names: The names of the fields to amend.
tag: The tag to add to the fields.
"""
assert field_names
for field_name in field_names:
field = cls.model_fields[field_name]
cls.model_fields[
field_name].description = f":tag:`{tag}` {field.description}"
cls.model_rebuild(force=True)
def tag_llm_params():
from tensorrt_llm.llmapi.llm_args import LlmArgs
ApiParamTagger()(LlmArgs)
class ApiStatusRegistry:
    ''' A registry to store the status of APIs.
usage:
@ApiStatusRegistry.set_api_status("beta")
def my_method(self, *args, **kwargs):
pass
class App:
@ApiStatusRegistry.set_api_status("beta")
def my_method(self, *args, **kwargs):
pass
'''
method_to_status = {}
@classmethod
def set_api_status(cls, status: str):
def decorator(func):
# Use qualified name to support class methods
if func.__qualname__ in cls.method_to_status:
logger.debug(
f"Method {func.__qualname__} already has a status, skipping the decorator"
)
return func
cls.method_to_status[func.__qualname__] = status
func.__doc__ = cls.amend_api_doc_with_status_tags(func)
return func
return decorator
@classmethod
def get_api_status(cls, method: Callable) -> Optional[str]:
return cls.method_to_status.get(method.__qualname__, None)
@classmethod
def amend_api_doc_with_status_tags(cls, method: Callable) -> str:
status = cls.get_api_status(method)
if status is None:
return method.__doc__
return f":tag:`{status}` {method.__doc__}"
set_api_status = ApiStatusRegistry().set_api_status
get_api_status = ApiStatusRegistry().get_api_status