from abc import ABC
from typing import Callable

import openai
from transformers import AutoTokenizer

from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm.executor import GenerationExecutor
from tensorrt_llm.llmapi.llm_args import KvCacheConfig
from tensorrt_llm.sampling_params import SamplingParams

from .task import GenerationTask, Task, TaskStatus

ExecutorCls = GenerationExecutor


class Worker(ABC):

    # Users can call this API to register, add, or override the handler
    # for a given Task type.
    def register_task_handler(self, task_cls: type[Task],
                              handler: Callable[[object, Task], TaskStatus]):
        worker_cls = type(self)
        worker_cls.task_handlers[task_cls] = handler

    async def run_task(self, task: Task) -> TaskStatus:
        worker_cls = type(self)
        if type(task) not in worker_cls.task_handlers:
            return TaskStatus.WORKER_NOT_SUPPORTED
        return await worker_cls.task_handlers[type(task)](self, task)

    task_handlers = {}

    def shutdown(self):
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.shutdown()


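# Illustrative sketch, not part of the original module: how a caller could attach a
# handler for a custom Task subclass via register_task_handler. "EchoTask" and
# "echo_handler" are hypothetical names used only for this example.
def _example_register_custom_handler(worker: Worker):

    class EchoTask(Task):
        # minimal hypothetical task mirroring GenerationTask's input/output fields
        input_str: str = ""
        output_str: str = ""

    async def echo_handler(worker_self, task: EchoTask) -> TaskStatus:
        # a handler receives the worker instance and the task, mutates the task
        # in place, and reports a TaskStatus
        task.output_str = task.input_str
        return TaskStatus.SUCCESS

    # after this call, run_task on a matching worker dispatches EchoTask
    # instances to echo_handler
    worker.register_task_handler(EchoTask, echo_handler)

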
# Helper function:
# set params[key] to the first non-None value in candidate_values.
def add_param_if_not_none(params, key, candidate_values):
    for value in candidate_values:
        if value is not None:
            params[key] = value
            return


# Helper function:
# set the attribute `attr` of `obj` to the first non-None value in candidate_values.
def add_attr_if_not_none(obj, attr, candidate_values):
    for value in candidate_values:
        if value is not None:
            setattr(obj, attr, value)
            return


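# Illustrative sketch, not part of the original module: the helpers above keep only
# the first non-None candidate, so callers can express "task value, else default"
# in one line. The values below are made up for the example.
def _example_param_fallback():
    params = {}
    task_temperature = None  # e.g. the task did not set a temperature
    default_temperature = 0.7  # hypothetical worker-level default
    add_param_if_not_none(params, "temperature",
                          [task_temperature, default_temperature])
    return params  # {"temperature": 0.7}

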
# Worker for the standard OpenAI API
class OpenaiWorker(Worker):

    def __init__(
        self,
        async_client: openai.AsyncOpenAI,
        model: str,
    ):
        self.model = model
        self.async_client = async_client

    def convert_task_params(self, task: GenerationTask):
        params = {
            "model": self.model,
            "prompt": task.input_str,
        }
        add_param_if_not_none(params, "max_tokens", [task.max_tokens])
        add_param_if_not_none(params, "temperature", [task.temperature])
        add_param_if_not_none(params, "top_p", [task.top_p])
        return params

    def fill_generation_task_with_response(self, task: GenerationTask,
                                           response: openai.Completion):
        task.output_str = response.choices[0].text
        task.logprobs = response.choices[0].logprobs

    async def generation_handler(self, task: GenerationTask) -> TaskStatus:
        params = self.convert_task_params(task)

        # Make the API call
        try:
            response = await self.async_client.completions.create(**params)
            self.fill_generation_task_with_response(task, response)

            return TaskStatus.SUCCESS

        except Exception as e:
            # Handle errors
            print('OpenAI client raised an exception: ' + str(e))
            return TaskStatus.WORKER_EXECEPTION

    def shutdown(self):
        # The OpenAI client doesn't require explicit cleanup
        pass

    task_handlers = {GenerationTask: generation_handler}


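# Illustrative sketch, not part of the original module: driving an OpenaiWorker
# against an OpenAI-compatible completions endpoint. The URL, API key, and model
# name are placeholders, and GenerationTask is assumed to be default-constructible
# with the fields used above.
async def _example_openai_worker():
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="dummy")
    task = GenerationTask()
    task.input_str = "The capital of France is"
    task.max_tokens = 16
    with OpenaiWorker(client, model="my-model") as worker:
        status = await worker.run_task(task)
    return status, task.output_str

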
# Worker that inherits from OpenaiWorker and adds
# params specific to the TRT-LLM OpenAI server.
class TRTOpenaiWorker(OpenaiWorker):

    def convert_task_params(self, task: GenerationTask):
        params = super().convert_task_params(task)
        if task.top_k is not None:
            # top_k is not part of the standard completions schema,
            # so it is forwarded through the client's extra_body mechanism
            params["extra_body"] = {"top_k": task.top_k}
        return params


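# Illustrative sketch, not part of the original module: inspecting the request
# params a TRTOpenaiWorker would send when a task sets top_k. The model name is a
# placeholder; async_client is unused by convert_task_params, so None suffices
# here purely for inspection.
def _example_trt_extra_body():
    task = GenerationTask()  # assuming default-constructible, as above
    task.input_str = "Hello"
    task.top_k = 8
    worker = TRTOpenaiWorker(async_client=None, model="my-model")
    # top_k travels in params["extra_body"], alongside the standard fields
    return worker.convert_task_params(task)

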
class TRTLLMWorker(Worker):

    def __init__(
        self,
        llm: LLM,
        tokenizer: AutoTokenizer,
    ):
        self.llm = llm
        self.tokenizer = tokenizer
        self.own_llm = False

    @classmethod
    def init_with_new_llm(
        cls,
        model_dir: str,
        backend: str = "pytorch",
        max_batch_size: int = 32,
        max_num_tokens: int = 4096,
        kv_cache_free_gpu_memory_fraction: float = 0.9,
        disable_overlap_scheduler: bool = False,
    ):
        kv_cache_config = KvCacheConfig(
            free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction)

        tokenizer = AutoTokenizer.from_pretrained(
            model_dir,
            legacy=False,
            padding_side='left',
            truncation_side='left',
            trust_remote_code=False,
            use_fast=True,
        )

        llm = LLM(model_dir,
                  backend=backend,
                  tokenizer=tokenizer,
                  mixed_sampler=True,
                  disable_overlap_scheduler=disable_overlap_scheduler,
                  kv_cache_config=kv_cache_config,
                  max_batch_size=max_batch_size,
                  max_num_tokens=max_num_tokens)

        worker = cls(llm, tokenizer)
        worker.own_llm = True
        return worker

    def convert_task_params(self, task: GenerationTask):
        sampling_params = SamplingParams(
            max_tokens=task.max_tokens,
            temperature=task.temperature,
            top_p=task.top_p,
            top_k=task.top_k,
            return_context_logits=task.return_context_logits)
        return sampling_params

    async def generation_handler(self, task: GenerationTask) -> TaskStatus:
        sampling_params = self.convert_task_params(task)

        result = await self.llm.generate_async(task.input_str,
                                               sampling_params=sampling_params)

        task.output_tokens = result.outputs[0].token_ids
        task.cumulative_logprob = result.outputs[0].cumulative_logprob
        task.logprobs = result.outputs[0].logprobs
        task.output_str = result.outputs[0].text
        task.context_logits = result.context_logits

        # TODO: error handling
        return TaskStatus.SUCCESS

    def shutdown(self):
        if self.own_llm:
            self.llm.shutdown()

    task_handlers = {GenerationTask: generation_handler}
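

# Illustrative sketch, not part of the original module: running an in-process
# TRTLLMWorker end to end. The model directory is a placeholder, and GenerationTask
# is assumed to be default-constructible with the fields used above.
async def _example_trtllm_worker():
    task = GenerationTask()
    task.input_str = "The capital of France is"
    task.max_tokens = 16
    # init_with_new_llm builds the LLM and tokenizer; own_llm=True makes shutdown()
    # release the engine when the with-block exits
    with TRTLLMWorker.init_with_new_llm("/path/to/model_dir") as worker:
        status = await worker.run_task(task)
    return status, task.output_str
    # e.g. run with: asyncio.run(_example_trtllm_worker())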