from abc import ABC
from copy import deepcopy
from typing import Callable, List, Optional, Union

import openai
from transformers import AutoTokenizer

from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.executor import GenerationExecutor
from tensorrt_llm.llmapi.llm import LLM
from tensorrt_llm.llmapi.llm_args import KvCacheConfig
from tensorrt_llm.sampling_params import SamplingParams

from .task import GenerationTask, Task, TaskStatus

ExecutorCls = GenerationExecutor


class Worker(ABC):

    # Users can call this API to register, add, or override the handler
    # for a given Task subclass.
    def register_task_handler(self, task_cls: type[Task],
                              handler: Callable[[object, Task], TaskStatus]):
        worker_cls = type(self)
        worker_cls.task_handlers[task_cls] = handler

    async def run_task(self, task: Task) -> TaskStatus:
        worker_cls = type(self)
        if type(task) not in worker_cls.task_handlers:
            return TaskStatus.WORKER_NOT_SUPPORTED
        return await worker_cls.task_handlers[type(task)](self, task)

    task_handlers = {}

    def shutdown(self):
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.shutdown()
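

# Usage sketch (illustrative only, kept as a comment so it is not executed on
# import): a minimal custom worker built on the registration mechanism above.
# The `EchoWorker` name and the way the task object is obtained are
# assumptions, not part of this module.
#
#     class EchoWorker(Worker):
#
#         async def echo_handler(self, task: GenerationTask) -> TaskStatus:
#             task.output_str = task.input_str
#             return TaskStatus.SUCCESS
#
#         task_handlers = {GenerationTask: echo_handler}
#
#     # run_task dispatches on type(task); task types without a registered
#     # handler return TaskStatus.WORKER_NOT_SUPPORTED.
#     # status = await EchoWorker().run_task(generation_task)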


# Helper function:
# set params[key] to the first non-None value in candidate_values.
def add_param_if_not_none(params, key, candidate_values):
    for value in candidate_values:
        if value is not None:
            params[key] = value
            return


# Helper function:
# set the attribute `attr` on obj to the first non-None value in candidate_values.
def add_attr_if_not_none(obj, attr, candidate_values):
    for value in candidate_values:
        if value is not None:
            setattr(obj, attr, value)
            return
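

# Example (illustrative): a call such as
#     add_param_if_not_none(params, "max_tokens",
#                           [task.max_tokens, self.max_tokens])
# gives the per-task value priority, falls back to the worker default, and
# leaves `params` untouched when both candidates are None.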


# Worker for the standard OpenAI API.
class OpenaiWorker(Worker):

    def __init__(
        self,
        async_client: openai.AsyncOpenAI,
        model: str,
        max_tokens: int = 2048,
        temperature: float = 0.9,
        top_p: Optional[float] = None,
    ):
        self.model = model
        self.async_client = async_client
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.top_p = top_p

    def combine_params_with_generation_task(self, params: dict,
                                            task: GenerationTask):
        params["prompt"] = task.input_str

        add_param_if_not_none(params, "max_tokens",
                              [task.max_tokens, self.max_tokens])
        add_param_if_not_none(params, "temperature",
                              [task.temperature, self.temperature])
        add_param_if_not_none(params, "top_p", [task.top_p, self.top_p])

    def fill_generation_task_with_response(self, task: GenerationTask,
                                           response: openai.Completion):
        task.output_str = response.choices[0].text
        task.logprobs = response.choices[0].logprobs

    async def generation_handler(self, task: GenerationTask) -> TaskStatus:
        params = {}

        # Set required parameters
        params["model"] = self.model

        self.combine_params_with_generation_task(params, task)

        # Make the API call
        try:
            response = await self.async_client.completions.create(**params)
            self.fill_generation_task_with_response(task, response)

            return TaskStatus.SUCCESS

        except Exception as e:
            # Handle errors
            print('OpenAI client raised an exception: ' + str(e))
            return TaskStatus.WORKER_EXECEPTION

    def shutdown(self):
        # The OpenAI client does not require explicit cleanup.
        pass

    task_handlers = {GenerationTask: generation_handler}
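

# Usage sketch (illustrative only): wiring an OpenaiWorker to an
# OpenAI-compatible completions endpoint. The base URL, API key, and model
# name below are assumptions, not part of this module.
#
#     client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
#                                 api_key="dummy")
#     worker = OpenaiWorker(async_client=client, model="my-model")
#     # given a GenerationTask with input_str set:
#     # status = await worker.run_task(task)  # TaskStatus.SUCCESS on success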


# Worker inheriting from OpenaiWorker that adds the extra sampling
# parameters supported by the TRT-LLM OpenAI server.
class TRTOpenaiWorker(OpenaiWorker):

    # Only manages the TRT-LLM OpenAI server's extra parameters.
    def __init__(self, top_k: Optional[float] = None, **kwargs):
        self.top_k = top_k
        super().__init__(**kwargs)

    def combine_params_with_generation_task(self, params: dict,
                                            task: GenerationTask):
        super().combine_params_with_generation_task(params, task)
        extra_body = {}
        add_param_if_not_none(extra_body, "top_k", [task.top_k, self.top_k])
        params["extra_body"] = extra_body
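

# Note (illustrative): the OpenAI client forwards `extra_body` fields as
# additional JSON in the request body, so a TRT-LLM server receives `top_k`
# alongside the standard completion parameters. Assuming the `client` from the
# sketch above:
#
#     # worker = TRTOpenaiWorker(async_client=client, model="my-model", top_k=20)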


class TRTLLMWorker(Worker):

    def __init__(
        self,
        llm: LLM,
        tokenizer: AutoTokenizer,
        max_tokens: int = 2048,
        temperature: float = 0.9,
        top_p: Optional[float] = None,
        topk: Optional[float] = None,
        stop: Optional[Union[str, List[str]]] = None,
    ):
        self.llm = llm
        self.tokenizer = tokenizer
        self.default_sampling_params = SamplingParams(max_tokens=max_tokens,
                                                      temperature=temperature,
                                                      top_p=top_p,
                                                      top_k=topk,
                                                      stop=stop)

        self.default_sampling_params._setup(self.tokenizer)
        self.own_llm = False

    @classmethod
    def init_with_new_llm(cls,
                          model_dir: str,
                          backend: Optional[str] = None,
                          max_batch_size: int = 32,
                          max_num_tokens: int = 4096,
                          kv_cache_free_gpu_memory_fraction: float = 0.9,
                          enable_overlap_scheduler: bool = True,
                          **kwargs):
        pytorch_backend_config = PyTorchConfig(
            mixed_decoder=True,
            enable_overlap_scheduler=enable_overlap_scheduler,
        )
        kv_cache_config = KvCacheConfig(
            free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction)

        tokenizer = AutoTokenizer.from_pretrained(
            model_dir,
            legacy=False,
            padding_side='left',
            truncation_side='left',
            trust_remote_code=False,
            use_fast=True,
        )

        llm = LLM(model_dir,
                  backend=backend,
                  tokenizer=tokenizer,
                  pytorch_backend_config=pytorch_backend_config,
                  kv_cache_config=kv_cache_config,
                  max_batch_size=max_batch_size,
                  max_num_tokens=max_num_tokens)

        worker = cls(llm, tokenizer, **kwargs)
        worker.own_llm = True
        return worker

    def combine_sampling_params_with_generation_task(self,
                                                     task: GenerationTask):
        sampling_params = deepcopy(self.default_sampling_params)

        add_attr_if_not_none(sampling_params, "max_tokens", [task.max_tokens])
        add_attr_if_not_none(sampling_params, "temperature",
                             [task.temperature])
        add_attr_if_not_none(sampling_params, "top_p", [task.top_p])
        add_attr_if_not_none(sampling_params, "top_k", [task.top_k])
        add_attr_if_not_none(sampling_params, "return_context_logits",
                             [task.return_context_logits])

        return sampling_params

    async def generation_handler(self, task: GenerationTask) -> TaskStatus:
        sampling_params = self.combine_sampling_params_with_generation_task(
            task)

        result = await self.llm.generate_async(task.input_str,
                                               sampling_params=sampling_params)

        task.output_tokens = result.outputs[0].token_ids
        task.cumulative_logprob = result.outputs[0].cumulative_logprob
        task.logprobs = result.outputs[0].logprobs
        task.output_str = result.outputs[0].text
        task.context_logits = result.context_logits

        # TODO: error handling
        return TaskStatus.SUCCESS

    def shutdown(self):
        if self.own_llm:
            self.llm.shutdown()

    task_handlers = {GenerationTask: generation_handler}
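

# Usage sketch (illustrative only; the model path and backend value are
# placeholders, and running this requires a GPU with the model available):
#
#     worker = TRTLLMWorker.init_with_new_llm("/path/to/model",
#                                             backend="pytorch",
#                                             max_batch_size=16)
#     # given a GenerationTask with input_str set:
#     # status = await worker.run_task(task)
#     worker.shutdown()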