TensorRT-LLM/tensorrt_llm/scaffolding/worker.py

from abc import ABC
from typing import Callable

import openai
from transformers import AutoTokenizer

from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm.executor import GenerationExecutor
from tensorrt_llm.llmapi.llm_args import KvCacheConfig
from tensorrt_llm.sampling_params import SamplingParams

from .task import GenerationTask, Task, TaskStatus

ExecutorCls = GenerationExecutor
class Worker(ABC):

    # Users can call this API to register/add/override a task handler function.
    def register_task_handler(self, task_cls: type[Task],
                              handler: Callable[[object, Task], TaskStatus]):
        worker_cls = type(self)
        worker_cls.task_handlers[task_cls] = handler

    async def run_task(self, task: Task) -> TaskStatus:
        worker_cls = type(self)
        if type(task) not in worker_cls.task_handlers:
            return TaskStatus.WORKER_NOT_SUPPORTED
        return await worker_cls.task_handlers[type(task)](self, task)

    task_handlers = {}

    def shutdown(self):
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.shutdown()
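
# A minimal usage sketch (hypothetical CustomTask/custom_handler names, not part
# of this module) showing how a handler registered at runtime is dispatched by
# run_task based on the task's concrete type:
#
#   async def custom_handler(worker: Worker, task: "CustomTask") -> TaskStatus:
#       ...  # fill in the task's output fields here
#       return TaskStatus.SUCCESS
#
#   worker.register_task_handler(CustomTask, custom_handler)
#   status = await worker.run_task(CustomTask(...))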
# Helper function:
# assign the first non-None value in candidate_values to params[key].
def add_param_if_not_none(params, key, candidate_values):
    for value in candidate_values:
        if value is not None:
            params[key] = value
            return


# Helper function:
# assign the first non-None value in candidate_values to the object's attribute.
def add_attr_if_not_none(obj, attr, candidate_values):
    for value in candidate_values:
        if value is not None:
            setattr(obj, attr, value)
            return
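
# A quick illustration (hypothetical values) of the precedence these helpers
# implement: the first non-None candidate wins and later candidates are
# ignored, so callers can list overrides before defaults.
#
#   params = {}
#   add_param_if_not_none(params, "temperature", [None, 0.7, 0.2])
#   assert params == {"temperature": 0.7}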
# Worker for the standard OpenAI API.
class OpenaiWorker(Worker):

    def __init__(
        self,
        async_client: openai.AsyncOpenAI,
        model: str,
    ):
        self.model = model
        self.async_client = async_client

    def convert_task_params(self, task: GenerationTask):
        params = {
            "model": self.model,
            "prompt": task.input_str,
        }
        add_param_if_not_none(params, "max_tokens", [task.max_tokens])
        add_param_if_not_none(params, "temperature", [task.temperature])
        add_param_if_not_none(params, "top_p", [task.top_p])
        return params

    def fill_generation_task_with_response(self, task: GenerationTask,
                                           response: openai.Completion):
        task.output_str = response.choices[0].text
        task.logprobs = response.choices[0].logprobs

    async def generation_handler(self, task: GenerationTask) -> TaskStatus:
        params = self.convert_task_params(task)

        # Make the API call
        try:
            response = await self.async_client.completions.create(**params)
            self.fill_generation_task_with_response(task, response)
            return TaskStatus.SUCCESS
        except Exception as e:
            # Handle errors
            print('OpenAI client raised an exception: ' + str(e))
            return TaskStatus.WORKER_EXECEPTION

    def shutdown(self):
        # The OpenAI client doesn't require explicit cleanup.
        pass

    task_handlers = {GenerationTask: generation_handler}
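
# A minimal wiring sketch (assumes an OpenAI-compatible server at the
# hypothetical URL below; AsyncOpenAI is the openai>=1.0 async client):
#
#   client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
#                               api_key="dummy")
#   worker = OpenaiWorker(client, model="my-model")
#   status = await worker.run_task(generation_task)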
# Worker that inherits from OpenaiWorker and adds params specific to the
# TRT-LLM OpenAI server.
class TRTOpenaiWorker(OpenaiWorker):

    def convert_task_params(self, task: GenerationTask):
        params = super().convert_task_params(task)
        if task.top_k is not None:
            params["extra_body"] = {"top_k": task.top_k}
        return params
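
# The standard completions API has no top_k parameter, so it is forwarded
# through the client's extra_body passthrough. For example (hypothetical
# values), a task with top_k=20 produces:
#
#   {"model": "...", "prompt": "...", "extra_body": {"top_k": 20}}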
class TRTLLMWorker(Worker):

    def __init__(
        self,
        llm: LLM,
        tokenizer: AutoTokenizer,
    ):
        self.llm = llm
        self.tokenizer = tokenizer
        self.own_llm = False

    @classmethod
    def init_with_new_llm(
        cls,
        model_dir: str,
        backend: str = "pytorch",
        max_batch_size: int = 32,
        max_num_tokens: int = 4096,
        kv_cache_free_gpu_memory_fraction: float = 0.9,
        disable_overlap_scheduler: bool = False,
    ):
        kv_cache_config = KvCacheConfig(
            free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction, )

        tokenizer = AutoTokenizer.from_pretrained(
            model_dir,
            legacy=False,
            padding_side='left',
            truncation_side='left',
            trust_remote_code=False,
            use_fast=True,
        )

        llm = LLM(model_dir,
                  backend=backend,
                  tokenizer=tokenizer,
                  mixed_sampler=True,
                  disable_overlap_scheduler=disable_overlap_scheduler,
                  kv_cache_config=kv_cache_config,
                  max_batch_size=max_batch_size,
                  max_num_tokens=max_num_tokens)

        worker = cls(llm, tokenizer)
        worker.own_llm = True
        return worker
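
    # A minimal construction sketch (hypothetical model path), to be run inside
    # an async function. init_with_new_llm sets own_llm = True, so shutdown()
    # (called on context-manager exit) also tears down the LLM it created:
    #
    #   with TRTLLMWorker.init_with_new_llm("/models/Llama-3.1-8B-Instruct") as worker:
    #       status = await worker.run_task(generation_task)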
    def convert_task_params(self, task: GenerationTask):
        sampling_params = SamplingParams(
            max_tokens=task.max_tokens,
            temperature=task.temperature,
            top_p=task.top_p,
            top_k=task.top_k,
            return_context_logits=task.return_context_logits)
        return sampling_params

    async def generation_handler(self, task: GenerationTask) -> TaskStatus:
        sampling_params = self.convert_task_params(task)

        result = await self.llm.generate_async(task.input_str,
                                               sampling_params=sampling_params)

        task.output_tokens = result.outputs[0].token_ids
        task.cumulative_logprob = result.outputs[0].cumulative_logprob
        task.logprobs = result.outputs[0].logprobs
        task.output_str = result.outputs[0].text
        task.context_logits = result.context_logits

        # TODO: error handling
        return TaskStatus.SUCCESS

    def shutdown(self):
        if self.own_llm:
            self.llm.shutdown()

    task_handlers = {GenerationTask: generation_handler}
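
# End-to-end sketch (hypothetical path and task construction; the exact
# GenerationTask factory lives in .task and is not shown here):
#
#   import asyncio
#
#   async def main():
#       worker = TRTLLMWorker.init_with_new_llm("/models/Llama-3.1-8B-Instruct")
#       task = GenerationTask()          # assumes a default-constructible task
#       task.input_str = "Hello, world"  # prompt consumed by generation_handler
#       status = await worker.run_task(task)
#       print(status, task.output_str)
#       worker.shutdown()
#
#   asyncio.run(main())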