mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-26 21:53:30 +08:00
* Update TensorRT-LLM --------- Co-authored-by: Bhuvanesh Sridharan <bhuvan.sridharan@gmail.com> Co-authored-by: Morgan Funtowicz <funtowiczmo@gmail.com> Co-authored-by: Eddie-Wang1120 <wangjinheng1120@163.com> Co-authored-by: meghagarwal <16129366+megha95@users.noreply.github.com>
130 lines
3.4 KiB
Python
130 lines
3.4 KiB
Python
import sys
|
|
import traceback
|
|
from dataclasses import dataclass, field
|
|
from functools import wraps
|
|
from pathlib import Path
|
|
from typing import List, Optional, Union
|
|
|
|
import torch
|
|
|
|
import tensorrt_llm.bindings as tllm
|
|
|
|
|
|
class SamplingConfig(tllm.SamplingConfig):
|
|
''' The sampling config for the generation. '''
|
|
|
|
# TODO[chunweiy]: switch to the cpp executor's once ready
|
|
def __init__(self,
|
|
end_id: Optional[int] = None,
|
|
pad_id: Optional[int] = None,
|
|
beam_width: int = 1,
|
|
max_new_tokens: Optional[int] = None) -> None:
|
|
super().__init__(beam_width)
|
|
self.end_id = end_id
|
|
self.pad_id = pad_id if pad_id is not None else end_id
|
|
self.max_new_tokens = max_new_tokens
|
|
|
|
def __setstate__(self, arg0: tuple) -> None:
|
|
self.end_id = arg0[0]
|
|
self.pad_id = arg0[1]
|
|
self.max_new_tokens = arg0[2]
|
|
super().__setstate__(arg0[3:])
|
|
|
|
def __getstate__(self) -> tuple:
|
|
return (self.end_id, self.pad_id,
|
|
self.max_new_tokens) + super().__getstate__()
|
|
|
|
def get_attr_names(self):
|
|
return list(self.__dict__.keys()) + [
|
|
"beam_search_diversity_rate",
|
|
"beam_width",
|
|
"early_stopping",
|
|
"frequency_penalty",
|
|
"length_penalty",
|
|
"min_length",
|
|
"presence_penalty",
|
|
"random_seed",
|
|
"repetition_penalty",
|
|
"temperature",
|
|
"top_k",
|
|
"top_p",
|
|
"top_p_decay",
|
|
"top_p_min",
|
|
"top_p_reset_ids",
|
|
]
|
|
|
|
def __repr__(self):
|
|
return f"SamplingConfig(" + ", ".join(
|
|
f"{k}={getattr(self, k)}" for k in self.get_attr_names()
|
|
if getattr(self, k) is not None) + ")"
|
|
|
|
|
|
@dataclass
|
|
class GenerationOutput:
|
|
text: str = ""
|
|
token_ids: Union[List[int], List[List[int]]] = field(default_factory=list)
|
|
logprobs: List[float] = field(default_factory=list)
|
|
|
|
|
|
def print_colored(message, color: str = None):
|
|
colors = dict(
|
|
grey="\x1b[38;20m",
|
|
yellow="\x1b[33;20m",
|
|
red="\x1b[31;20m",
|
|
bold_red="\x1b[31;1m",
|
|
bold_green="\033[1;32m",
|
|
green="\033[0;32m",
|
|
)
|
|
reset = "\x1b[0m"
|
|
|
|
if color:
|
|
sys.stderr.write(colors[color] + message + reset)
|
|
else:
|
|
sys.stderr.write(message)
|
|
|
|
|
|
def file_with_glob_exists(directory, glob) -> bool:
|
|
path = Path(directory)
|
|
for file_path in path.glob(glob):
|
|
if file_path.is_file():
|
|
return True
|
|
return False
|
|
|
|
|
|
def file_with_suffix_exists(directory, suffix) -> bool:
|
|
return file_with_glob_exists(directory, f'*{suffix}')
|
|
|
|
|
|
def print_traceback_on_error(func):
|
|
|
|
@wraps(func)
|
|
def wrapper(*args, **kwargs):
|
|
try:
|
|
return func(*args, **kwargs)
|
|
except Exception as e:
|
|
traceback.print_exc()
|
|
raise e
|
|
|
|
return wrapper
|
|
|
|
|
|
def get_device_count() -> int:
|
|
return torch.cuda.device_count() if torch.cuda.is_available() else 0
|
|
|
|
|
|
def get_total_gpu_memory(device: int) -> float:
|
|
return torch.cuda.get_device_properties(device).total_memory
|
|
|
|
|
|
class ContextManager:
|
|
''' A helper to create a context manager for a resource. '''
|
|
|
|
def __init__(self, resource):
|
|
self.resource = resource
|
|
|
|
def __enter__(self):
|
|
return self.resource.__enter__()
|
|
|
|
def __exit__(self, exc_type, exc_value, traceback):
|
|
return self.resource.__exit__(exc_type, exc_value, traceback)
|