import copy
from enum import Enum
from typing import List

from tensorrt_llm.scaffolding import Controller, GenerationTask, Task
from tensorrt_llm.scaffolding.contrib.AsyncGeneration import (
    StreamGenerationTask,
)


class NativeStreamGenerationController(Controller):
    """Controller that drives generation tasks through streaming workers.

    Each incoming ``GenerationTask`` is wrapped in a ``StreamGenerationTask``.
    The wrappers are yielded repeatedly; after every round their partial
    results are copied back onto the original tasks, and wrappers whose
    ``end_flag`` is set are dropped from the next round.
    """

    # Once a stream task holds more than this many output tokens, the
    # controller sets its ``cancel_flag``.
    output_threshold = 100
    # Step size handed to ``StreamGenerationTask.create_from_generation_task``
    # (presumably the number of tokens a worker returns per round -- confirm
    # against StreamGenerationTask).
    stream_step = 10

    class WorkerTag(Enum):
        """Worker tags assigned to tasks by this controller."""

        STREAM = "stream"

    def __init__(self, custom_sampling_params: dict = None):
        """Store a private deep copy of the default sampling parameters.

        Args:
            custom_sampling_params: Optional sampling parameters applied to
                every task unless ``process`` receives an override.
        """
        super().__init__()
        # Deep-copied so later changes to the caller's dict do not leak in.
        self.custom_sampling_params = copy.deepcopy(custom_sampling_params)

    def set_output_threshold(self, output_threshold: int):
        """Override the class-level output token threshold on this instance."""
        self.output_threshold = output_threshold

    def set_stream_step(self, stream_step: int):
        """Override the class-level stream step on this instance."""
        self.stream_step = stream_step

    def process(self, tasks: List[Task], **kwargs):
        """Yield the still-running stream tasks until all of them have ended.

        Args:
            tasks: Tasks to run; every element must be a ``GenerationTask``.
                Their ``worker_tag`` and result fields are updated in place.
            **kwargs: A truthy ``custom_sampling_params`` entry takes
                precedence over the parameters given to the constructor.

        Yields:
            The list of ``StreamGenerationTask`` objects that are not
            finished yet.

        Raises:
            ValueError: If any element of ``tasks`` is not a
                ``GenerationTask``.
        """
        sampling_override = kwargs.get("custom_sampling_params")

        active = []
        for task in tasks:
            if not isinstance(task, GenerationTask):
                raise ValueError(
                    "NativeStreamGenerationController requires exactly one GenerationTask"
                )
            task.worker_tag = self.WorkerTag.STREAM

            # Per-call parameters win over the controller-level defaults.
            if sampling_override:
                task.custom_sampling_params = sampling_override
            elif self.custom_sampling_params:
                task.custom_sampling_params = self.custom_sampling_params

            active.append(
                StreamGenerationTask.create_from_generation_task(
                    task, self.stream_step))

        # positions[k] is the index in ``tasks`` that active[k] was built from.
        positions = list(range(len(active)))

        while active:
            yield active

            survivors = []
            survivor_positions = []
            for position, stream_task in zip(positions, active):
                # Mirror the partial results onto the caller-visible task.
                target = tasks[position]
                target.output_str = stream_task.output_str
                target.output_tokens = stream_task.output_tokens
                target.cumulative_logprob = stream_task.cumulative_logprob
                target.logprobs = stream_task.logprobs

                # Request cancellation once the output exceeds the threshold.
                if len(stream_task.output_tokens) > self.output_threshold:
                    stream_task.cancel_flag = True

                if stream_task.end_flag:
                    continue
                survivors.append(stream_task)
                survivor_positions.append(position)

            active = survivors
            positions = survivor_positions