import hashlib import os import signal import sys import tempfile import traceback from dataclasses import dataclass, field from functools import wraps from pathlib import Path from typing import List, Optional, Union import filelock import huggingface_hub import torch from huggingface_hub import snapshot_download from tensorrt_llm.bindings import executor as tllme from tensorrt_llm.logger import set_level from ..bindings.executor import OutputConfig def print_traceback_on_error(func): @wraps(func) def wrapper(*args, **kwargs): try: return func(*args, **kwargs) except Exception as e: traceback.print_exc() raise e return wrapper class SamplingConfig(tllme.SamplingConfig): def __init__(self, end_id: Optional[int] = None, pad_id: Optional[int] = None, max_new_tokens: Optional[int] = None, *args, **kwargs): super().__init__(*args, **kwargs) # Patch some configs from cpp Request here, since they are attached to each prompt while # it we won't introduce the Request concept in the generate() API. self.end_id = end_id self.pad_id = pad_id if pad_id is not None else end_id self.max_new_tokens = max_new_tokens def _get_member_names(self): return ("beam_width", "top_k", "top_p", "top_p_min", "top_p_reset_ids", "top_p_decay", "random_seed", "temperature", "min_length", "beam_search_diversity_rate", "repetition_penalty", "presence_penalty", "frequency_penalty", "length_penalty", "early_stopping") def __eq__(self, other): return all(getattr(self, name) == getattr(other, name) for name in self._get_member_names()) and \ self.end_id == other.end_id and self.pad_id == other.pad_id and self.max_new_tokens == other.max_new_tokens @print_traceback_on_error def __setstate__(self, args: tuple) -> None: assert len(args) == 3 + len(self._get_member_names()) kwargs = { 'end_id': args[0], 'pad_id': args[1], 'max_new_tokens': args[2] } kwargs.update({ key: value for key, value in zip(self._get_member_names(), args[3:]) if value is not None }) # The C++ class's constructor is not properly called by pickle, so we need to call it manually. self.__init__(**kwargs) @print_traceback_on_error def __getstate__(self) -> tuple: args = (self.end_id, self.pad_id, self.max_new_tokens) + tuple( getattr(self, name) for name in self._get_member_names()) return args def __repr__(self): return f"SamplingConfig({', '.join(f'{name}={getattr(self, name)}' for name in self._get_member_names())})" class OutputConfig(OutputConfig): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Alter several default values for HLAPI usage. if "exclude_input_from_output" not in kwargs: self.exclude_input_from_output = True def __eq__(self, other: "OutputConfig") -> bool: return self.__getstate__() == other.__getstate__() def __getstate__(self): return { "exclude_input_from_output": self.exclude_input_from_output, "return_log_probs": self.return_log_probs, "return_context_logits": self.return_context_logits, "return_generation_logits": self.return_generation_logits, } def __setstate__(self, state): self.exclude_input_from_output = state["exclude_input_from_output"] self.return_log_probs = state["return_log_probs"] self.return_context_logits = state["return_context_logits"] self.return_generation_logits = state["return_generation_logits"] @dataclass class GenerationOutput: text: str = "" token_ids: Union[List[int], List[List[int]]] = field(default_factory=list) log_probs: Optional[List[float]] = None context_logits: Optional[torch.Tensor] = None generation_logits: Optional[torch.Tensor] = None def print_colored(message, color: str = None): colors = dict( grey="\x1b[38;20m", yellow="\x1b[33;20m", red="\x1b[31;20m", bold_red="\x1b[31;1m", bold_green="\033[1;32m", green="\033[0;32m", ) reset = "\x1b[0m" if color: sys.stderr.write(colors[color] + message + reset) else: sys.stderr.write(message) def file_with_glob_exists(directory, glob) -> bool: path = Path(directory) for file_path in path.glob(glob): if file_path.is_file(): return True return False def file_with_suffix_exists(directory, suffix) -> bool: return file_with_glob_exists(directory, f'*{suffix}') def get_device_count() -> int: return torch.cuda.device_count() if torch.cuda.is_available() else 0 def get_total_gpu_memory(device: int) -> float: return torch.cuda.get_device_properties(device).total_memory class GpuArch: @staticmethod def is_post_hopper() -> bool: return get_gpu_arch() >= 9 @staticmethod def is_post_ampere() -> bool: return get_gpu_arch() >= 8 @staticmethod def is_post_volta() -> bool: return get_gpu_arch() >= 7 def get_gpu_arch(device: int = 0) -> int: return torch.cuda.get_device_properties(device).major class ContextManager: ''' A helper to create a context manager for a resource. ''' def __init__(self, resource): self.resource = resource def __enter__(self): return self.resource.__enter__() def __exit__(self, exc_type, exc_value, traceback): return self.resource.__exit__(exc_type, exc_value, traceback) def is_directory_empty(directory: Path) -> bool: return not any(directory.iterdir()) def init_log_level(): ''' Set the log level if the environment variable is not set. ''' if "TLLM_LOG_LEVEL" not in os.environ: set_level("warning") os.environ["TLLM_LOG_LEVEL"] = "WARNING" def sigint_handler(signal, frame): sys.stderr.write("\nSIGINT received, quit LLM!\n") sys.exit(1) # Register the signal handler to handle SIGINT # This helps to deal with user's Ctrl+C signal.signal(signal.SIGINT, sigint_handler) # Use the system temporary directory to share the cache temp_dir = tempfile.gettempdir() def get_file_lock(model_name: str, cache_dir: Optional[str] = None) -> filelock.FileLock: # Hash the model name to avoid invalid characters in the lock file path hashed_model_name = hashlib.sha256(model_name.encode()).hexdigest() cache_dir = cache_dir or temp_dir os.makedirs(cache_dir, exist_ok=True) lock_file_path = os.path.join(cache_dir, f"{hashed_model_name}.lock") return filelock.FileLock(lock_file_path) def download_hf_model(model_name: str) -> Path: with get_file_lock(model_name): hf_folder = snapshot_download( model_name, local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE) return Path(hf_folder)