TensorRT-LLMs/examples/scaffolding/contrib/AsyncGeneration/stream_generation_controller.py
WeiHaocheng 3fc2a16920
feat(part 2): Enhance the integrated robustness of scaffolding with __init__.py #3305 (#3731)
Signed-off-by: fredw (generated by with_the_same_user script) <20514172+WeiHaocheng@users.noreply.github.com>
2025-04-24 18:47:03 +08:00

69 lines
2.7 KiB
Python

import copy
from enum import Enum
from typing import List
from tensorrt_llm.scaffolding import Controller, GenerationTask, Task
from tensorrt_llm.scaffolding.contrib import StreamGenerationTask
class NativeStreamGenerationController(Controller):
    """Controller that drives streaming generation workers.

    Wraps each incoming ``GenerationTask`` in a ``StreamGenerationTask``,
    repeatedly yields the in-flight streaming tasks to the worker, and
    mirrors each partial result back onto the original task until every
    stream has finished (or is cancelled for exceeding ``output_threshold``).
    """

    # When a worker has generated more than this many tokens for a task,
    # the controller sets the task's cancel_flag so the worker stops.
    output_threshold = 100
    # Number of tokens the worker accumulates before returning one
    # intermediate (streaming) result to the controller.
    stream_step = 10

    class WorkerTag(Enum):
        STREAM = "stream"

    def __init__(self, custom_sampling_params: dict = None):
        """Optionally store default sampling params applied to every task.

        The dict is deep-copied so later mutation by the caller cannot
        leak into tasks processed by this controller.
        """
        super().__init__()
        self.custom_sampling_params = copy.deepcopy(custom_sampling_params)

    def set_output_threshold(self, output_threshold: int):
        """Override the per-task token limit on this instance."""
        self.output_threshold = output_threshold

    def set_stream_step(self, stream_step: int):
        """Override the per-task streaming chunk size on this instance."""
        self.stream_step = stream_step

    def process(self, tasks: List[Task], **kwargs):
        """Generator: stream-generate all ``tasks`` until completion.

        Args:
            tasks: tasks to run; every element must be a ``GenerationTask``.
            kwargs: may contain ``custom_sampling_params`` which, when
                truthy, takes precedence over the instance-level default.

        Yields:
            The list of still-running ``StreamGenerationTask`` objects for
            the worker to advance. After each resumption, partial outputs
            are copied back onto the corresponding original task.

        Raises:
            ValueError: if any task is not a ``GenerationTask``.
        """
        # Loop-invariant: resolve the sampling-params override once.
        override_sampling_params = kwargs.get("custom_sampling_params")

        # Pair each original task with its streaming counterpart so results
        # can be mirrored back without parallel index bookkeeping.
        pairs = []
        for task in tasks:
            if not isinstance(task, GenerationTask):
                # NOTE: the method handles any number of tasks; the real
                # constraint is the type of each one.
                raise ValueError(
                    "NativeStreamGenerationController only accepts GenerationTask instances"
                )
            task.worker_tag = self.WorkerTag.STREAM
            if override_sampling_params:
                task.custom_sampling_params = override_sampling_params
            elif self.custom_sampling_params:
                task.custom_sampling_params = self.custom_sampling_params
            stream_task = StreamGenerationTask()
            # Clone the task's state wholesale; deepcopy keeps the stream
            # task independent of the original's mutable fields.
            stream_task.__dict__ = copy.deepcopy(task.__dict__)
            stream_task.streaming_step = self.stream_step
            pairs.append((task, stream_task))

        while pairs:
            # Hand the in-flight streaming tasks to the worker.
            yield [stream_task for _, stream_task in pairs]
            remaining = []
            for task, stream_task in pairs:
                # Mirror the latest partial results onto the original task.
                task.output_str = stream_task.output_str
                task.output_tokens = stream_task.output_tokens
                task.cumulative_logprob = stream_task.cumulative_logprob
                task.logprobs = stream_task.logprobs
                # Ask the worker to stop tasks that exceeded the limit.
                if len(stream_task.output_tokens) > self.output_threshold:
                    stream_task.cancel_flag = True
                if not stream_task.end_flag:
                    remaining.append((task, stream_task))
            pairs = remaining