import gc
import sys
import traceback
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
from typing import List, Optional

import torch


@dataclass
class GenerationOutput:
    ''' Holds the result of a single generation request. '''
    text: str = ""
    token_ids: List[int] = field(default_factory=list)
    logprobs: List[float] = field(default_factory=list)


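# A hypothetical sketch of how GenerationOutput might be filled during a
# streaming decode loop (`step_token`, `step_logprob`, and `tokenizer` are
# illustrative names, not part of this module):
#
#   output = GenerationOutput()
#   output.token_ids.append(step_token)
#   output.logprobs.append(step_logprob)
#   output.text = tokenizer.decode(output.token_ids)

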
def print_colored(message, color: Optional[str] = None):
    ''' Write `message` to stderr, optionally wrapped in an ANSI color code. '''
    colors = dict(
        grey="\x1b[38;20m",
        yellow="\x1b[33;20m",
        red="\x1b[31;20m",
        bold_red="\x1b[31;1m",
        bold_green="\x1b[1;32m",
    )
    reset = "\x1b[0m"

    if color:
        sys.stderr.write(colors[color] + message + reset)
    else:
        sys.stderr.write(message)


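# Example: report progress in green, warnings in yellow. An unknown color
# name raises KeyError, since `colors[color]` is a plain dict lookup.
#
#   print_colored("engine build finished\n", "bold_green")
#   print_colored("falling back to CPU\n", "yellow")

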
def file_with_suffix_exists(directory, suffix) -> bool:
    ''' Return True if `directory` contains at least one file ending in `suffix`. '''
    path = Path(directory)
    for file_path in path.glob(f'*{suffix}'):
        if file_path.is_file():
            return True
    return False


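# Example: skip a rebuild when a serialized engine is already present. The
# directory path and suffix here are purely illustrative.
#
#   if file_with_suffix_exists('/tmp/engine_dir', '.engine'):
#       print_colored("found existing engine, skipping build\n", "grey")

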
def print_traceback_on_error(func):
    ''' Decorator that prints the full traceback before re-raising any exception.

    Useful when `func` runs in a context (e.g. a worker process) where the
    traceback might otherwise be swallowed or truncated.
    '''

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception:
            traceback.print_exc()
            raise  # re-raise with the original traceback intact

    return wrapper


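# Example usage (the decorated function is illustrative): a failure inside
# `run_session` prints its traceback before the exception propagates.
#
#   @print_traceback_on_error
#   def run_session(prompt: str):
#       ...

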
def get_device_count() -> int:
    ''' Return the number of visible CUDA devices, or 0 if CUDA is unavailable. '''
    return torch.cuda.device_count() if torch.cuda.is_available() else 0


def release_gc():
    ''' Explicitly and immediately release memory held by the Python garbage
    collector and the PyTorch CUDA caching allocator.

    Useful when some state may be kept alive in memory even after the
    variables referencing it are deleted.
    '''
    gc.collect()
    if torch.cuda.is_available():  # defensive guard for CPU-only environments
        torch.cuda.empty_cache()
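

# Minimal illustrative smoke test: running this file directly exercises the
# helpers above. The tensor count and size are arbitrary.
if __name__ == "__main__":
    print_colored(f"visible CUDA devices: {get_device_count()}\n", "bold_green")

    buffers = [torch.zeros(1024) for _ in range(8)]
    del buffers
    release_gc()  # collect the dropped buffers and empty the CUDA cache, if any

    print_colored("done\n", "grey")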