mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
Signed-off-by: fredw (generated by with_the_same_user script) <20514172+WeiHaocheng@users.noreply.github.com>
69 lines · 2.7 KiB · Python
import copy
|
|
from enum import Enum
|
|
from typing import List
|
|
|
|
from tensorrt_llm.scaffolding import Controller, GenerationTask, Task
|
|
from tensorrt_llm.scaffolding.contrib import StreamGenerationTask
|
|
|
|
|
|
class NativeStreamGenerationController(Controller):
    """Controller that drives generation through StreamGenerationTasks.

    Each incoming GenerationTask is cloned into a StreamGenerationTask that
    the worker advances a few tokens at a time; after every streaming round
    the partial output is copied back onto the original task, and generation
    is cancelled once the output grows past ``output_threshold`` tokens.
    """

    # Token-count threshold: once a stream task has generated more than this
    # many tokens, the controller sets its cancel_flag so the worker stops
    # that generation.
    output_threshold = 100

    # Number of tokens the worker generates before returning control to the
    # controller on each streaming round trip.
    stream_step = 10

    class WorkerTag(Enum):
        # Tag used to route these tasks to the streaming-capable worker.
        STREAM = "stream"

    def __init__(self, custom_sampling_params: dict = None):
        super().__init__()
        # Deep-copy so later mutation by the caller cannot leak into the
        # sampling params this controller applies to tasks.
        self.custom_sampling_params = copy.deepcopy(custom_sampling_params)

    def set_output_threshold(self, output_threshold: int):
        """Override the token count at which generation is cancelled."""
        self.output_threshold = output_threshold

    def set_stream_step(self, stream_step: int):
        """Override how many tokens the worker emits per streaming step."""
        self.stream_step = stream_step

    def process(self, tasks: List[Task], **kwargs):
        """Stream-generate for every task, mirroring partial output back.

        Converts each GenerationTask into a StreamGenerationTask, then
        repeatedly yields the still-running stream tasks to the worker and
        copies their partial results onto the corresponding original tasks
        until all of them have finished (``end_flag``) or been cancelled.

        Args:
            tasks: the tasks to generate for; each must be a GenerationTask.
            kwargs: may carry ``custom_sampling_params`` which, when truthy,
                takes precedence over the controller-level default.

        Raises:
            ValueError: if any task is not a GenerationTask.
        """
        stream_tasks = []
        for task in tasks:
            if not isinstance(task, GenerationTask):
                # Bug fix: the previous message claimed "exactly one" task
                # was required, but this controller accepts any number of
                # tasks — the real requirement is that every task is a
                # GenerationTask.
                raise ValueError(
                    "NativeStreamGenerationController requires every task "
                    "to be a GenerationTask")
            task.worker_tag = self.WorkerTag.STREAM
            # Per-call sampling params (kwargs) win over the constructor's.
            if kwargs.get("custom_sampling_params"):
                task.custom_sampling_params = kwargs.get(
                    "custom_sampling_params")
            elif self.custom_sampling_params:
                task.custom_sampling_params = self.custom_sampling_params

            stream_task = StreamGenerationTask()
            # Clone the task's state so the streaming copy can diverge
            # without mutating the caller's task object.
            stream_task.__dict__ = copy.deepcopy(task.__dict__)
            stream_task.streaming_step = self.stream_step
            stream_tasks.append(stream_task)

        # indices[i] is the position in ``tasks`` that stream_tasks[i]
        # was cloned from, so partial results land on the right task.
        indices = list(range(len(stream_tasks)))

        while stream_tasks:
            # Hand the pending stream tasks to the worker; we are resumed
            # once it has produced the next chunk of tokens for each.
            yield stream_tasks

            remaining_tasks = []
            remaining_indices = []
            for idx, stream_task in zip(indices, stream_tasks):
                # Mirror the latest partial results onto the original task.
                original = tasks[idx]
                original.output_str = stream_task.output_str
                original.output_tokens = stream_task.output_tokens
                original.cumulative_logprob = stream_task.cumulative_logprob
                original.logprobs = stream_task.logprobs

                if len(stream_task.output_tokens) > self.output_threshold:
                    # Output is long enough: ask the worker to stop this one.
                    stream_task.cancel_flag = True
                if not stream_task.end_flag:
                    # Not finished yet — keep it for the next round.
                    remaining_tasks.append(stream_task)
                    remaining_indices.append(idx)

            stream_tasks = remaining_tasks
            indices = remaining_indices