# TensorRT-LLM/tensorrt_llm/hlapi/utils.py

import gc
import sys
import traceback
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
from typing import List, Optional

import torch


@dataclass
class GenerationOutput:
    ''' Holds the text and token-level results of a single generation. '''
    text: str = ""
    token_ids: List[int] = field(default_factory=list)
    logprobs: List[float] = field(default_factory=list)
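
# Illustrative usage sketch (not part of the original module): a streaming
# decode loop could accumulate its results into a GenerationOutput. The
# values below are made up.
#
#   output = GenerationOutput()
#   output.text += "Hello"
#   output.token_ids.append(42)
#   output.logprobs.append(-0.12)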


def print_colored(message: str, color: Optional[str] = None):
    ''' Write ``message`` to stderr, wrapped in an ANSI color code if ``color`` is given. '''
    colors = dict(
        grey="\x1b[38;20m",
        yellow="\x1b[33;20m",
        red="\x1b[31;20m",
        bold_red="\x1b[31;1m",
        bold_green="\033[1;32m",
    )
    reset = "\x1b[0m"
    if color:
        sys.stderr.write(colors[color] + message + reset)
    else:
        sys.stderr.write(message)
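
# Illustrative usage sketch (not part of the original module); note that a
# color name outside the dict above would raise a KeyError:
#
#   print_colored("Engine built successfully.\n", "bold_green")
#   print_colored("Warning: falling back to fp32.\n", "yellow")
#   print_colored("Plain, uncolored message.\n")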


def file_with_suffix_exists(directory, suffix) -> bool:
    ''' Return True if ``directory`` contains at least one file ending with ``suffix``. '''
    path = Path(directory)
    return any(file_path.is_file() for file_path in path.glob(f'*{suffix}'))
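
# Illustrative usage sketch (not part of the original module); ``engine_dir``
# is a hypothetical path:
#
#   if file_with_suffix_exists(engine_dir, '.engine'):
#       ...  # reuse the prebuilt engine instead of rebuilding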


def print_traceback_on_error(func):
    ''' Decorator that prints the full traceback before re-raising any exception. '''

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception:
            traceback.print_exc()
            raise

    return wrapper
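
# Illustrative usage sketch (not part of the original module): decorating a
# hypothetical worker entry point so failures in a subprocess still surface
# a full traceback before propagating.
#
#   @print_traceback_on_error
#   def worker(rank: int):
#       ...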


def get_device_count() -> int:
    ''' Number of visible CUDA devices, or 0 when CUDA is unavailable. '''
    return torch.cuda.device_count() if torch.cuda.is_available() else 0


def get_total_gpu_memory(device: int) -> float:
    ''' Total memory of ``device`` in bytes. '''
    return float(torch.cuda.get_device_properties(device).total_memory)
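
# Illustrative usage sketch (not part of the original module): the value is
# in bytes, so convert before displaying.
#
#   gib = get_total_gpu_memory(0) / (1024 ** 3)
#   print_colored(f"GPU 0 total memory: {gib:.1f} GiB\n", "grey")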


def release_gc():
    ''' Explicitly and immediately release memory held by the Python garbage
    collector and by PyTorch's CUDA caching allocator. Useful when some state
    may be kept alive even after the referencing variables are deleted.
    '''
    gc.collect()
    torch.cuda.empty_cache()
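
# Illustrative usage sketch (not part of the original module): drop the last
# references first, then force cached memory back to the allocator/driver.
#
#   del llm  # ``llm`` is a hypothetical object holding CUDA tensors
#   release_gc()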