TensorRT-LLMs/tensorrt_llm/hlapi/utils.py
Kaiyu Xie b777bd6475
Update TensorRT-LLM (#1725)
* Update TensorRT-LLM

---------

Co-authored-by: RunningLeon <mnsheng@yeah.net>
Co-authored-by: Tlntin <TlntinDeng01@Gmail.com>
Co-authored-by: ZHENG, Zhen <zhengzhen.z@qq.com>
Co-authored-by: Pham Van Ngoan <ngoanpham1196@gmail.com>
Co-authored-by: Nathan Price <nathan@abridge.com>
Co-authored-by: Tushar Goel <tushar.goel.ml@gmail.com>
Co-authored-by: Mati <132419219+matichon-vultureprime@users.noreply.github.com>
2024-06-04 20:26:32 +08:00

234 lines
7.0 KiB
Python

import hashlib
import os
import signal
import sys
import tempfile
import traceback
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
from typing import List, Optional, Union
import filelock
import huggingface_hub
import torch
from huggingface_hub import snapshot_download
from tensorrt_llm.bindings import executor as tllme
from tensorrt_llm.logger import set_level
from ..bindings.executor import OutputConfig
def print_traceback_on_error(func):
    """Decorator that prints the full traceback to stderr before re-raising
    any exception from ``func``.

    In this file it is applied to pickle hooks (``__getstate__`` /
    ``__setstate__``), where a failure's traceback can otherwise be hard
    to see.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception:
            traceback.print_exc()
            # Bare `raise` re-raises the active exception with its original
            # traceback intact; `raise e` would add this frame to the chain.
            raise

    return wrapper
class SamplingConfig(tllme.SamplingConfig):
    """HLAPI sampling configuration.

    Extends the C++ executor binding's SamplingConfig with three per-prompt
    fields (``end_id``, ``pad_id``, ``max_new_tokens``) that the cpp Request
    normally carries, since the Request concept is not exposed in the
    generate() API.
    """

    def __init__(self,
                 end_id: Optional[int] = None,
                 pad_id: Optional[int] = None,
                 max_new_tokens: Optional[int] = None,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        # Patch some configs from the cpp Request here, since they are
        # attached to each prompt while we won't introduce the Request
        # concept in the generate() API.
        self.end_id = end_id
        # pad_id falls back to end_id when not explicitly provided.
        self.pad_id = pad_id if pad_id is not None else end_id
        self.max_new_tokens = max_new_tokens

    def _get_member_names(self):
        # Fields inherited from the C++ binding that take part in equality,
        # pickling and repr. Keep in sync with the binding's attributes.
        return ("beam_width", "top_k", "top_p", "top_p_min", "top_p_reset_ids",
                "top_p_decay", "random_seed", "temperature", "min_length",
                "beam_search_diversity_rate", "repetition_penalty",
                "presence_penalty", "frequency_penalty", "length_penalty",
                "early_stopping")

    def __eq__(self, other):
        # Comparing with an unrelated object used to raise AttributeError;
        # return NotImplemented instead so Python can fall back gracefully.
        if not isinstance(other, SamplingConfig):
            return NotImplemented
        return all(getattr(self, name) == getattr(other, name) for name in self._get_member_names()) and \
            self.end_id == other.end_id and self.pad_id == other.pad_id and self.max_new_tokens == other.max_new_tokens

    @print_traceback_on_error
    def __setstate__(self, args: tuple) -> None:
        # Layout produced by __getstate__: the three HLAPI fields first,
        # then the inherited members in _get_member_names() order.
        assert len(args) == 3 + len(self._get_member_names())
        kwargs = {
            'end_id': args[0],
            'pad_id': args[1],
            'max_new_tokens': args[2]
        }
        kwargs.update({
            key: value
            for key, value in zip(self._get_member_names(), args[3:])
            if value is not None
        })
        # The C++ class's constructor is not properly called by pickle, so we
        # need to call it manually.
        self.__init__(**kwargs)

    @print_traceback_on_error
    def __getstate__(self) -> tuple:
        # Serialize as a flat tuple; see __setstate__ for the layout.
        args = (self.end_id, self.pad_id, self.max_new_tokens) + tuple(
            getattr(self, name) for name in self._get_member_names())
        return args

    def __repr__(self):
        return f"SamplingConfig({', '.join(f'{name}={getattr(self, name)}' for name in self._get_member_names())})"
class OutputConfig(OutputConfig):
    """HLAPI wrapper over the executor's OutputConfig that adds picklability
    and HLAPI-friendly defaults."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alter several default values for HLAPI usage: unless the caller
        # explicitly set it, do not echo the prompt tokens in the output.
        if "exclude_input_from_output" not in kwargs:
            self.exclude_input_from_output = True

    def __eq__(self, other: "OutputConfig") -> bool:
        # Comparing with an unrelated object used to rely on `other` having a
        # compatible __getstate__; return NotImplemented for foreign types.
        if not isinstance(other, OutputConfig):
            return NotImplemented
        return self.__getstate__() == other.__getstate__()

    def __getstate__(self):
        # Explicit state dict: the C++ binding is not handled by pickle's
        # default protocol.
        return {
            "exclude_input_from_output": self.exclude_input_from_output,
            "return_log_probs": self.return_log_probs,
            "return_context_logits": self.return_context_logits,
            "return_generation_logits": self.return_generation_logits,
        }

    def __setstate__(self, state):
        self.exclude_input_from_output = state["exclude_input_from_output"]
        self.return_log_probs = state["return_log_probs"]
        self.return_context_logits = state["return_context_logits"]
        self.return_generation_logits = state["return_generation_logits"]
@dataclass
class GenerationOutput:
    """Container for the result of a single generation request."""
    # Decoded output text.
    text: str = ""
    # Generated token ids; the nested-list form presumably holds one list
    # per beam — confirm against the caller that fills this in.
    token_ids: Union[List[int], List[List[int]]] = field(default_factory=list)
    # Log probabilities; None unless requested (likely tied to
    # OutputConfig.return_log_probs — verify).
    log_probs: Optional[List[float]] = None
    # Logits tensors; None unless requested (likely tied to
    # OutputConfig.return_context_logits / return_generation_logits — verify).
    context_logits: Optional[torch.Tensor] = None
    generation_logits: Optional[torch.Tensor] = None
def print_colored(message: str, color: Optional[str] = None):
    """Write ``message`` to stderr, optionally wrapped in an ANSI color code.

    Args:
        message: Text to write; no newline is appended.
        color: One of 'grey', 'yellow', 'red', 'bold_red', 'bold_green',
            'green'. When None (or empty), the message is written as-is.

    Raises:
        KeyError: If ``color`` is not one of the known color names.
    """
    colors = dict(
        grey="\x1b[38;20m",
        yellow="\x1b[33;20m",
        red="\x1b[31;20m",
        bold_red="\x1b[31;1m",
        bold_green="\033[1;32m",
        green="\033[0;32m",
    )
    reset = "\x1b[0m"
    if color:
        sys.stderr.write(colors[color] + message + reset)
    else:
        sys.stderr.write(message)
def file_with_glob_exists(directory, glob) -> bool:
    """Return True iff `directory` contains at least one regular file
    matching the glob pattern `glob`."""
    return any(entry.is_file() for entry in Path(directory).glob(glob))
def file_with_suffix_exists(directory, suffix) -> bool:
    """Return True iff `directory` contains a file whose name ends with
    `suffix` (delegates to file_with_glob_exists)."""
    pattern = f'*{suffix}'
    return file_with_glob_exists(directory, pattern)
def get_device_count() -> int:
    """Number of visible CUDA devices, or 0 when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return 0
    return torch.cuda.device_count()
def get_total_gpu_memory(device: int) -> float:
    """Total memory of CUDA device `device`, in bytes."""
    properties = torch.cuda.get_device_properties(device)
    return properties.total_memory
class GpuArch:
    """Predicates over the compute-capability major version of the default
    GPU (device 0, via get_gpu_arch)."""

    @staticmethod
    def is_post_hopper() -> bool:
        return GpuArch._at_least(9)

    @staticmethod
    def is_post_ampere() -> bool:
        return GpuArch._at_least(8)

    @staticmethod
    def is_post_volta() -> bool:
        return GpuArch._at_least(7)

    @staticmethod
    def _at_least(major: int) -> bool:
        # Shared comparison against the default device's major version.
        return get_gpu_arch() >= major
def get_gpu_arch(device: int = 0) -> int:
    """Compute-capability major version of CUDA device `device`."""
    properties = torch.cuda.get_device_properties(device)
    return properties.major
class ContextManager:
    """Adapter that forwards the context-manager protocol to a wrapped
    resource object."""

    def __init__(self, resource):
        self.resource = resource

    def __enter__(self):
        return self.resource.__enter__()

    def __exit__(self, exc_type, exc_value, tb):
        # Forward the result so the resource can suppress exceptions if it
        # chooses to.
        return self.resource.__exit__(exc_type, exc_value, tb)
def is_directory_empty(directory: Path) -> bool:
    """Return True iff `directory` contains no entries at all
    (files, subdirectories, or anything else)."""
    first_entry = next(directory.iterdir(), None)
    return first_entry is None
def init_log_level():
    """Default the log level to 'warning' when the TLLM_LOG_LEVEL
    environment variable has not been set by the user."""
    if os.environ.get("TLLM_LOG_LEVEL") is not None:
        return
    set_level("warning")
    # Record the choice so child processes see a consistent level.
    os.environ["TLLM_LOG_LEVEL"] = "WARNING"
def sigint_handler(signum, frame):
    """SIGINT handler: announce the interrupt on stderr and exit with code 1.

    Args:
        signum: Signal number delivered by the runtime.
        frame: Current stack frame (unused).
    """
    # Previously the first parameter was named `signal`, shadowing the
    # stdlib `signal` module inside the handler.
    sys.stderr.write("\nSIGINT received, quit LLM!\n")
    sys.exit(1)
# Register the signal handler to handle SIGINT
# This helps to deal with user's Ctrl+C: the process exits via
# sigint_handler above instead of raising KeyboardInterrupt.
signal.signal(signal.SIGINT, sigint_handler)
# Use the system temporary directory to share the cache
# (get_file_lock below places its lock files here by default).
temp_dir = tempfile.gettempdir()
def get_file_lock(model_name: str,
                  cache_dir: Optional[str] = None) -> filelock.FileLock:
    """Build an inter-process file lock keyed on `model_name`.

    The model name is hashed so characters that are invalid in file names
    cannot leak into the lock path. Lock files live in `cache_dir`, falling
    back to the system temp directory.
    """
    # Hash the model name to avoid invalid characters in the lock file path.
    digest = hashlib.sha256(model_name.encode()).hexdigest()
    lock_dir = cache_dir or temp_dir
    os.makedirs(lock_dir, exist_ok=True)
    lock_path = os.path.join(lock_dir, f"{digest}.lock")
    return filelock.FileLock(lock_path)
def download_hf_model(model_name: str) -> Path:
    """Download (or reuse) the HuggingFace snapshot for `model_name` and
    return its local folder.

    A per-model file lock serializes concurrent downloads of the same model
    across processes; HF_HUB_OFFLINE restricts the lookup to local files.
    """
    lock = get_file_lock(model_name)
    with lock:
        folder = snapshot_download(
            model_name,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE)
    return Path(folder)