TensorRT-LLMs/tensorrt_llm/scaffolding/task.py
WeiHaocheng ff35af77ea
feat: refactor scaffolding worker and support openai api worker (#3166)
Signed-off-by: Fred Wei <20514172+WeiHaocheng@users.noreply.github.com>
Signed-off-by: fredw <20514172+WeiHaocheng@users.noreply.github.com>
2025-04-01 18:31:52 +08:00

73 lines
2.1 KiB
Python

from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional, Union
class ScaffoldingOutput:
    """Container for the result text produced by a scaffolding run.

    Constructed empty; the producer assigns ``output_str`` afterwards
    (e.g. ``GenerationTask.create_scaffolding_output``).
    """

    def __init__(self):
        # No result yet; filled in by whoever builds this output object.
        self.output_str: Optional[str] = None
@dataclass
class Task:
    """Base unit of work that Scaffolding delivers to a Worker.

    Subclasses (e.g. ``GenerationTask``) add their own input and result
    fields on top of these common ones.
    """

    # Reserved for custom input params.
    custom_input_params: Optional[dict] = None
    # Scaffolding delivers the task to the Worker by worker_tag.
    # The default is None, so the annotation must be Optional[str], not str.
    worker_tag: Optional[str] = field(default=None)
    # Reserved for custom output params.
    custom_output_params: Optional[dict] = None
class TaskStatus(Enum):
    """Status values; the member names suggest they report how a Worker
    handled a Task (confirm against the worker implementations)."""

    SUCCESS = "success"
    WORKER_NOT_SUPPORTED = "worker_not_supported"
    # Historical (misspelled) member name. It stays the canonical member so
    # existing references, ``.name`` checks and iteration are unaffected.
    WORKER_EXECEPTION = "worker_exception"
    # Correctly spelled alias. Because it has the same value, Enum makes it an
    # alias: TaskStatus.WORKER_EXCEPTION is TaskStatus.WORKER_EXECEPTION.
    WORKER_EXCEPTION = "worker_exception"
@dataclass
class GenerationTask(Task):
    """Text-generation task: prompt/sampling inputs plus result fields.

    Input fields describe what to generate. Result fields default to None
    (or an empty list for ``logprobs``) until they are filled in.
    """

    # input field
    input_tokens: Optional[List[int]] = field(default=None)
    input_str: Optional[str] = field(default=None)
    skip_tokenizer: bool = False
    skip_detokenizer: bool = False

    # sampling params (None presumably means "leave to the worker" -- confirm)
    max_tokens: Optional[int] = field(default=None)
    temperature: Optional[float] = field(default=None)
    top_p: Optional[float] = field(default=None)
    top_k: Optional[int] = field(default=None)

    # Overrides Task.worker_tag so it may also hold a Controller.WorkerTag.
    # Using Controller.WorkerTag is suggested; either way, users must ensure
    # the value can be found in the ScaffoldingLlm's workers map.
    # Redefining the field keeps its original position in the field order.
    worker_tag: Optional[Union[str, "Controller.WorkerTag"]] = None

    # result field
    output_tokens: Optional[List[int]] = None
    output_str: Optional[str] = None
    cumulative_logprob: Optional[float] = None
    logprobs: List[float] = field(default_factory=list)

    @staticmethod
    def create_from_prompt(prompt: str) -> "GenerationTask":
        """Build a GenerationTask whose input is the raw prompt string.

        Args:
            prompt: Text stored in ``input_str``.

        Returns:
            A new GenerationTask with tokenization and detokenization left
            enabled (both skip flags False) and all other fields at defaults.
        """
        task = GenerationTask()
        task.input_str = prompt
        task.skip_tokenizer = False
        task.skip_detokenizer = False
        return task

    def create_scaffolding_output(self) -> "ScaffoldingOutput":
        """Return a new ScaffoldingOutput carrying this task's ``output_str``."""
        output = ScaffoldingOutput()
        output.output_str = self.output_str
        return output
@dataclass
class RewardTask(Task):
# input field
input_tokens: Optional[List[int]] = field(default=None)
input_str: Optional[str] = field(default=None)