TensorRT-LLMs/tensorrt_llm/hlapi/utils.py

import gc
import sys
import traceback
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
from typing import List, Optional

import torch


@dataclass
class GenerationOutput:
    ''' Container for the output of a single generation request. '''
    text: str = ""
    token_ids: List[int] = field(default_factory=list)
    logprobs: List[float] = field(default_factory=list)
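
# Example (illustrative sketch, not part of the original module): accumulate
# one request's output incrementally.
#
#   out = GenerationOutput()
#   out.text += "Hello"
#   out.token_ids.append(42)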


def print_colored(message: str, color: Optional[str] = None):
    ''' Write `message` to stderr, wrapped in the ANSI escape sequence for
    `color` when one of the known color names is given.
    '''
    colors = dict(
        grey="\x1b[38;20m",
        yellow="\x1b[33;20m",
        red="\x1b[31;20m",
        bold_red="\x1b[31;1m",
        bold_green="\x1b[1;32m",
    )
    reset = "\x1b[0m"
    if color:
        sys.stderr.write(colors[color] + message + reset)
    else:
        sys.stderr.write(message)
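
# Example usage (illustrative sketch; the messages are arbitrary):
#
#   print_colored("engine build finished\n", "bold_green")
#   print_colored("plain progress message\n")  # no color: written as-is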


def file_with_suffix_exists(directory, suffix) -> bool:
    ''' Return True if `directory` contains at least one regular file whose
    name ends with `suffix`.
    '''
    path = Path(directory)
    for file_path in path.glob(f'*{suffix}'):
        if file_path.is_file():
            return True
    return False
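
# Example usage (illustrative sketch; the directory and suffix are arbitrary):
#
#   if file_with_suffix_exists("/tmp/engines", ".engine"):
#       ...  # reuse a prebuilt engine instead of building one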


def print_traceback_on_error(func):
    ''' Decorator that prints the full traceback of any exception raised by
    `func` before re-raising it.
    '''

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception:
            traceback.print_exc()
            raise

    return wrapper
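
# Example usage (illustrative sketch; `build_engine` is a hypothetical
# function, not part of this module):
#
#   @print_traceback_on_error
#   def build_engine(config):
#       ...  # any exception is printed with its traceback, then re-raised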


def get_device_count() -> int:
    ''' Return the number of visible CUDA devices, or 0 if CUDA is unavailable. '''
    return torch.cuda.device_count() if torch.cuda.is_available() else 0
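
# Example usage (illustrative sketch):
#
#   n_gpus = get_device_count()
#   if n_gpus == 0:
#       print_colored("no CUDA device found, falling back to CPU\n", "yellow")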


def release_gc():
    ''' Explicitly and immediately release memory held by the Python garbage
    collector and by PyTorch's CUDA caching allocator.

    This is useful when some state may remain in memory even after the
    variables referencing it have been deleted.
    '''
    gc.collect()
    torch.cuda.empty_cache()
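
# Example usage (illustrative sketch; `runner` stands in for any large
# CUDA-backed object, not a name from this module):
#
#   del runner     # drop the last reference
#   release_gc()   # force collection and return cached CUDA blocks to the driver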