diff --git a/examples/scaffolding/contrib/AsyncGeneration/README.md b/examples/scaffolding/contrib/AsyncGeneration/README.md deleted file mode 100644 index c92a287e31..0000000000 --- a/examples/scaffolding/contrib/AsyncGeneration/README.md +++ /dev/null @@ -1,10 +0,0 @@ - -This example shows how to use the `StreamGenerationTask` and `stream_generation_handler` to enable efficient streaming-based generation workflows. - -How to run the example? - -```bash -python stream_generation_run.py -``` - -See more detail on [tensorrt_llm/scaffolding/contrib/AsyncGeneration/README.md](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/scaffolding/contrib/AsyncGeneration/README.md). diff --git a/examples/scaffolding/contrib/AsyncGeneration/stream_generation_run.py b/examples/scaffolding/contrib/AsyncGeneration/stream_generation_run.py deleted file mode 100644 index 99b680e979..0000000000 --- a/examples/scaffolding/contrib/AsyncGeneration/stream_generation_run.py +++ /dev/null @@ -1,104 +0,0 @@ -import argparse - -from stream_generation_controller import NativeStreamGenerationController - -from tensorrt_llm.scaffolding import ScaffoldingLlm, TRTLLMWorker -from tensorrt_llm.scaffolding.contrib import (StreamGenerationTask, - stream_generation_handler) - - -def parse_arguments(): - parser = argparse.ArgumentParser() - # .e.g. DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B - parser.add_argument( - '--model_dir', - type=str, - required=True, - help="Path to the directory containing the generation model") - parser.add_argument('--run_type', type=str, default='original') - args = parser.parse_args() - return args - - -def test(prompts, proposer_worker): - prototype_controller = NativeStreamGenerationController( - sampling_params={"temperature": 0.9}) - - llm = ScaffoldingLlm( - prototype_controller, - {NativeStreamGenerationController.WorkerTag.STREAM: proposer_worker}, - ) - results = llm.generate(prompts) - for result in results: - print(result.output.output_str) - print(f'test main shutting down...') - llm.shutdown() - print(f'test worker shutting down...') - proposer_worker.shutdown() - print(f'test main shut down done') - - -def test_step(prompts, proposer_worker): - prototype_controller = NativeStreamGenerationController() - prototype_controller.set_stream_step(20) - - llm = ScaffoldingLlm( - prototype_controller, - {NativeStreamGenerationController.WorkerTag.STREAM: proposer_worker}, - ) - results = llm.generate(prompts) - for result in results: - print(result.output.output_str) - print(f'test step main shutting down...') - llm.shutdown() - print(f'test step worker shutting down...') - proposer_worker.shutdown() - print(f'test step main shut down done') - - -def test_cancel(prompts, proposer_worker): - prototype_controller = NativeStreamGenerationController() - prototype_controller.set_output_threshold(200) - - llm = ScaffoldingLlm( - prototype_controller, - {NativeStreamGenerationController.WorkerTag.STREAM: proposer_worker}, - ) - results = llm.generate(prompts) - for result in results: - print(result.output.output_str) - print(f'test cancel main shutting down...') - llm.shutdown() - print(f'test cancel worker shutting down...') - proposer_worker.shutdown() - print(f'test cancel main shut down done') - - -def main(): - args = parse_arguments() - - prompts = [ - "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. 
How many clips did Natalia sell altogether in April and May?\r\n\r\n", - "There exist real numbers $x$ and $y$, both greater than 1, such that $\\log_x\\left(y^x\\right)=\\log_y\\left(x^{4y}\\right)=10$. Find $xy$.", - "Find the largest possible real part of \\[(75+117i)z+\\frac{96+144i}{z}\\]where $z$ is a complex number with $|z|=4$.", - ] - llm_worker = TRTLLMWorker.init_with_new_llm( - args.model_dir, - backend="pytorch", - max_batch_size=32, - max_num_tokens=4096, - ) - - print(f'main llm worker init done') - llm_worker.register_task_handler(StreamGenerationTask, - stream_generation_handler) - if args.run_type == 'original': - test(prompts, llm_worker) - elif args.run_type == 'step': - test_step(prompts, llm_worker) - elif args.run_type == 'cancel': - test_cancel(prompts, llm_worker) - - -if __name__ == "__main__": - main() diff --git a/examples/scaffolding/contrib/Dynasor/scaffolding_dynasor_run.py b/examples/scaffolding/contrib/Dynasor/scaffolding_dynasor_run.py index 1844dea813..853383a11e 100644 --- a/examples/scaffolding/contrib/Dynasor/scaffolding_dynasor_run.py +++ b/examples/scaffolding/contrib/Dynasor/scaffolding_dynasor_run.py @@ -65,15 +65,14 @@ def main(): if args.streaming: async def task(prompt: str): - i = 0 + step = 0 async for result in llm.generate_async(prompt): - i += 1 - print(">>>", i, result) - async for output in result.cur_output: - print(">>>", i, len(output.outputs[0].token_ids), "\n", - output.outputs[0].text) - print(f">>> final output {len(result.outputs[0].token_ids)}\n", - result.outputs[0].text) + step += 1 + tokens_num = len( + result.outputs[0].token_ids + ) if result.outputs[0].token_ids is not None else 0 + print(">>>", step, tokens_num, "\n", result.outputs[0].text) + print(f">>> final output {tokens_num}\n", result.outputs[0].text) # Need to provide LLM's event loop to get results in the middle of the whole process. 
asyncio.run_coroutine_threadsafe(task(prompts[0]), llm.loop).result() diff --git a/examples/scaffolding/contrib/mcp/mcptest.py b/examples/scaffolding/contrib/mcp/mcptest.py index ebfa29012d..bfd5332437 100644 --- a/examples/scaffolding/contrib/mcp/mcptest.py +++ b/examples/scaffolding/contrib/mcp/mcptest.py @@ -4,8 +4,8 @@ import asyncio from openai import AsyncOpenAI from tensorrt_llm.scaffolding import OpenaiWorker, ScaffoldingLlm -from tensorrt_llm.scaffolding.contrib import (ChatTask, MCPController, - MCPWorker, chat_handler) +from tensorrt_llm.scaffolding.contrib.mcp import (ChatTask, MCPController, + MCPWorker, chat_handler) def parse_arguments(): @@ -28,7 +28,7 @@ def parse_arguments(): from openai import AsyncOpenAI from tensorrt_llm.scaffolding import OpenaiWorker, ScaffoldingLlm -from tensorrt_llm.scaffolding.contrib import MCPController, MCPWorker +from tensorrt_llm.scaffolding.contrib.mcp import MCPController, MCPWorker async def main(): @@ -41,7 +41,7 @@ async def main(): ] API_KEY = args.API_KEY urls = [ - "http://0.0.0.0:8080/sse", "http://0.0.0.0:8081/sse", + #"http://0.0.0.0:8080/sse", "http://0.0.0.0:8081/sse", "http://0.0.0.0:8082/sse" ] print(f"API_KEY {API_KEY}") @@ -61,7 +61,7 @@ async def main(): future = llm.generate_async(prompts[0]) result = await future.aresult() - print(f"\nresult is {result.output.output_str}\n") + print(f"\nresult is {result.outputs[0].text}\n") print(f'main shutting down...') llm.shutdown() diff --git a/examples/scaffolding/run_basic_generation.py b/examples/scaffolding/run_basic_generation.py index 7acec78570..20155f41f0 100644 --- a/examples/scaffolding/run_basic_generation.py +++ b/examples/scaffolding/run_basic_generation.py @@ -53,14 +53,12 @@ def test_async(prompt, proposer_worker): prototype_controller, {NativeGenerationController.WorkerTag.GENERATION: proposer_worker}, ) - i = 0 + step = 0 async for result in llm.generate_async(prompt): - i += 1 - print(">>>", i, result) - async for output in result.cur_output: - print(">>>", i, len(output.outputs[0].token_ids), "\n", - output.outputs[0].text) + step += 1 + print(">>>", step, len(result.outputs[0].token_ids), "\n", + result.outputs[0].text) print(f">>> final output {len(result.outputs[0].token_ids)}\n", result.outputs[0].text) diff --git a/examples/scaffolding/token_budget_majority_vote.py b/examples/scaffolding/token_budget_majority_vote.py index 9ea684f677..c335445c28 100644 --- a/examples/scaffolding/token_budget_majority_vote.py +++ b/examples/scaffolding/token_budget_majority_vote.py @@ -104,7 +104,7 @@ def main(): prompt = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. 
How many clips did Natalia sell altogether in April and May?\r\n\r\n" result = llm.generate(prompt) - extracted_answer = extract_answer_from_boxed(result.output.output_str) + extracted_answer = extract_answer_from_boxed(result.outputs[0].text) print(f'extracted_answer={extracted_answer}') llm.shutdown(shutdown_workers=True) diff --git a/tensorrt_llm/scaffolding/__init__.py b/tensorrt_llm/scaffolding/__init__.py index 3b50884558..342b5496bd 100644 --- a/tensorrt_llm/scaffolding/__init__.py +++ b/tensorrt_llm/scaffolding/__init__.py @@ -25,6 +25,7 @@ __all__ = [ "GenerationTask", "StreamGenerationTask", "RewardTask", + "StreamGenerationTask", "Worker", "OpenaiWorker", "TRTOpenaiWorker", diff --git a/tensorrt_llm/scaffolding/contrib/AsyncGeneration/README.md b/tensorrt_llm/scaffolding/contrib/AsyncGeneration/README.md deleted file mode 100644 index ea3d1cc8a6..0000000000 --- a/tensorrt_llm/scaffolding/contrib/AsyncGeneration/README.md +++ /dev/null @@ -1,46 +0,0 @@ -## Overview - -`StreamGenerationTask` is an extension of `GenerationTask` designed for token-level streaming generation in asynchronous LLM workflows using TensorRT-LLM. It enables the controller to receive partial results during generation, which is critical for real-time or latency-sensitive applications such as chatbots, speech generation, or UI-interactive systems. - ---- - -## Features - -- ✅ Supports **streamed token delivery** by step (e.g., `streaming_step=1`). -- ✅ Supports **cancellation** of generation using a flag (`cancel_flag=True`). -- ✅ Tracks **stream completion status** (`end_flag=True` when done). -- ✅ Integrated with a streaming-capable LLM interface (`generate_async`). - ---- - -## Fields in `StreamGenerationTask` - -| Field | Description | -|-------|-------------| -| `cancel_flag` | If `True`, the generation will be cancelled on the worker side. | -| `streaming_step` | Number of new tokens required before returning control to the controller. If set to `0`, the task is returned immediately if any new tokens are available. | -| `request_handle` | Internal handle for the streaming generation (used by the worker). | -| `end_flag` | Indicates whether generation is finished. | -| `output_str` / `output_tokens` / `logprobs` | Outputs after each generation step. | - ---- - -## Usage in Controller/Worker - -The Controller can utilize `StreamGenerationTask` to enable efficient streaming-based generation workflows: -- It sends tasks to the worker, which returns them when the number of newly generated tokens reaches the specified `streaming_step`. -- It can cancel long-running tasks by setting `task.cancel_flag = True` when the number of generated tokens exceeds a predefined threshold. - -To support this behavior on the worker side, we have implemented `stream_generation_handler` and you need to register it with the worker in your project. This handler should process `StreamGenerationTask` instances step-by-step and update relevant fields such as `output_tokens`, `output_str`. - -This design allows the controller and worker to coordinate generation in a token-efficient and responsive manner, ideal for real-time applications. - -You can see more details in `stream_generation_controller.py` and `stream_generation_task.py` from `examples/scaffolding/contrib/AsyncGeneration`. - -## Notes -Remember to register the `stream_generation_handler` with the `TRTLLMWorker`. - -## TODO - -- Add error handling for failed `request_handle`. -- Support retry or backoff mechanism if generation stalls. 
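For reference, a minimal controller-side sketch of the streaming loop the README above describes, using only the `StreamGenerationTask` fields kept by this change (`streaming_step`, `cancel_flag`, `end_flag`, `output_str`, `output_tokens`). The `ExampleStreamController` name and the exact `Controller.process` generator signature are illustrative assumptions, not part of this patch; the deleted `NativeStreamGenerationController` under `examples/scaffolding/contrib/AsyncGeneration` is the actual reference implementation.

```python
from tensorrt_llm.scaffolding import Controller, StreamGenerationTask


class ExampleStreamController(Controller):
    """Hypothetical controller, for illustration only."""

    class WorkerTag:
        STREAM = "stream_generation"

    def __init__(self, streaming_step: int = 20, output_threshold: int = 200):
        super().__init__()
        self.streaming_step = streaming_step
        self.output_threshold = output_threshold

    def process(self, tasks, **kwargs):
        # Build a streaming task from the incoming prompt.
        task = StreamGenerationTask()
        task.input_str = tasks[0].input_str
        task.worker_tag = self.WorkerTag.STREAM
        task.streaming_step = self.streaming_step

        # The worker-side stream_generation_handler hands the task back once
        # at least `streaming_step` new tokens exist, fills
        # output_str/output_tokens, and sets end_flag (or aborts the request
        # when cancel_flag is set).
        while not task.end_flag:
            yield [task]
            if (task.output_tokens is not None
                    and len(task.output_tokens) >= self.output_threshold):
                task.cancel_flag = True  # ask the worker to abort on the next step

        tasks[0].output_str = task.output_str
        tasks[0].output_tokens = task.output_tokens
```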
diff --git a/tensorrt_llm/scaffolding/contrib/AsyncGeneration/__init__.py b/tensorrt_llm/scaffolding/contrib/AsyncGeneration/__init__.py deleted file mode 100644 index d56b274356..0000000000 --- a/tensorrt_llm/scaffolding/contrib/AsyncGeneration/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .stream_generation import StreamGenerationTask, stream_generation_handler - -__all__ = ["stream_generation_handler", "StreamGenerationTask"] diff --git a/tensorrt_llm/scaffolding/contrib/Dynasor/dynasor_controller.py b/tensorrt_llm/scaffolding/contrib/Dynasor/dynasor_controller.py index 2e78188dc3..2b0bf09b83 100644 --- a/tensorrt_llm/scaffolding/contrib/Dynasor/dynasor_controller.py +++ b/tensorrt_llm/scaffolding/contrib/Dynasor/dynasor_controller.py @@ -126,18 +126,10 @@ class DynasorGenerationController(Controller): probe_answers[-self.certainty_threshold:]) == self.certainty_threshold and sum(probe_certain_count) == self.certainty_threshold): - tasks[0].result = probe_task.result - # If the current prompt indicates the chain-of-thought phase has ended, use one type of suffix. - if "" in current_prompt: - tasks[0].output_str = (current_prompt + self.answer_suffix + - probe_answers[-1] + "}\n\\]") - return - else: - # Otherwise, use the suffix with marker to transition clearly. - tasks[0].output_str = (current_prompt + - self.answer_suffix_with_marker + - probe_answers[-1] + "}\n\\]") - return + suffix = self.answer_suffix if "" in current_prompt else self.answer_suffix_with_marker + suffix += probe_answers[-1] + "}\n\\]" + current_prompt += suffix + break # If not confident, do another round of generation # Append the newly generated text from the proposer to the current prompt for the next iteration. @@ -145,7 +137,6 @@ class DynasorGenerationController(Controller): # If the maximum token limit is reached without satisfying the certainty condition, # output the accumulated prompt as the final output. 
- tasks[0].result = proposer_task.result tasks[0].output_str = current_prompt return diff --git a/tensorrt_llm/scaffolding/contrib/mcp/chat_handler.py b/tensorrt_llm/scaffolding/contrib/mcp/chat_handler.py index 90f10e0808..b8bde0809a 100644 --- a/tensorrt_llm/scaffolding/contrib/mcp/chat_handler.py +++ b/tensorrt_llm/scaffolding/contrib/mcp/chat_handler.py @@ -19,11 +19,9 @@ def add_param_if_not_none(params, key, candidate_values): def combine_params_with_chat_task(worker, params: dict, task: ChatTask): params["messages"] = task.messages - add_param_if_not_none(params, "max_tokens", - [task.max_tokens, worker.max_tokens]) - add_param_if_not_none(params, "temperature", - [task.temperature, worker.temperature]) - add_param_if_not_none(params, "top_p", [task.top_p, worker.top_p]) + add_param_if_not_none(params, "max_tokens", [task.max_tokens]) + add_param_if_not_none(params, "temperature", [task.temperature]) + add_param_if_not_none(params, "top_p", [task.top_p]) add_param_if_not_none(params, "tools", [task.tools]) diff --git a/tensorrt_llm/scaffolding/controller.py b/tensorrt_llm/scaffolding/controller.py index 2d34b76ef6..614685f8f4 100644 --- a/tensorrt_llm/scaffolding/controller.py +++ b/tensorrt_llm/scaffolding/controller.py @@ -67,7 +67,7 @@ class NativeGenerationController(Controller): for key, value in self.sampling_params.items(): if getattr(task, key) is None: setattr(task, key, value) - task.streaming = self.streaming + task.streaming_output_flag = self.streaming yield tasks diff --git a/tensorrt_llm/scaffolding/result.py b/tensorrt_llm/scaffolding/result.py index 9ebb978d9b..51b67bbf94 100644 --- a/tensorrt_llm/scaffolding/result.py +++ b/tensorrt_llm/scaffolding/result.py @@ -1,80 +1,67 @@ import asyncio -from typing import Mapping, Optional +from dataclasses import dataclass +from typing import Any, List, Mapping, Optional, Union -from tensorrt_llm.executor.result import GenerationResult + +@dataclass +class ScaffoldingOutput: + text: str + token_ids: List[int] class ScaffoldingResult: - def __init__(self, streaming_event: Optional[asyncio.Event] = None): + def __init__(self): super().__init__() self.aqueue = asyncio.Queue() - self.cur_output: GenerationResult = None + #self.cur_output: GenerationResult = None + self.outputs = [] + # only support one output for now, so use an empty obj to init + self.outputs.append(ScaffoldingOutput("", [])) self._done = False self.task_collections = None - self.streaming_event = streaming_event - def set_output(self, output: GenerationResult): + def set_output(self, output: Union[ScaffoldingOutput, Any]): + if isinstance(output, ScaffoldingOutput): + self.set_output_streaming(output) + # terminate + self.set_output_streaming(None) + + def set_output_streaming(self, output: Union[ScaffoldingOutput, Any]): self.aqueue.put_nowait(output) - self._done = True - - async def set_output_async(self, output: GenerationResult): - await self.aqueue.put(output) def set_task_collections(self, task_collections: Mapping[str, "TaskCollection"]): self.task_collections = task_collections - @property - def outputs(self): - return self.cur_output.outputs if self.cur_output else None - - @property - def finished(self) -> bool: - return self.cur_output is not None and self.cur_output.finished - async def _aresult_step(self): # TODO: error handling or raise exception? 
- response = await self.aqueue.get() - if response is None: - raise Exception("ScaffoldingLlm execution failed") - self._handle_response(response) + obj = await self.aqueue.get() + if obj is None: + self._done = True + else: # obj is ScaffoldingOutput + self.outputs[0] = obj def result(self, timeout: Optional[float] = None) -> "ScaffoldingResult": - if not self.finished: + if not self._done: loop = asyncio.get_event_loop() asyncio.run_coroutine_threadsafe(self.aresult(), loop).result() return self async def aresult(self) -> "ScaffoldingResult": - while not self.finished: + while not self._done: await self._aresult_step() return self def __await__(self): return self.aresult().__await__() - def __iter__(self): - return self - - def __next__(self): - if self._done and self.finished: - raise StopIteration - - self._result_step() - return self - def __aiter__(self): return self async def __anext__(self): - if self.finished: - self.streaming_event.set() if self.streaming_event else None - if self._done and self.finished: + if self._done: raise StopAsyncIteration await self._aresult_step() return self - - def _handle_response(self, response: GenerationResult): - self.cur_output = response diff --git a/tensorrt_llm/scaffolding/scaffolding_llm.py b/tensorrt_llm/scaffolding/scaffolding_llm.py index 13afccf073..f9755e8e37 100644 --- a/tensorrt_llm/scaffolding/scaffolding_llm.py +++ b/tensorrt_llm/scaffolding/scaffolding_llm.py @@ -34,7 +34,6 @@ class ScaffoldingLlm: self.task_queue = asyncio.Queue() self.main_loop_stop_event = asyncio.Event() self.shutdown_event = asyncio.Event() - self.streaming_event = asyncio.Event() if self.own_loop: self._run_main_loop_thread() else: @@ -82,10 +81,10 @@ class ScaffoldingLlm: ] await asyncio.gather(*async_tasks) for task in tasks: - if getattr(task, 'streaming', False): - await request.result.set_output_async(task.result) - self.streaming_event.clear() - await self.streaming_event.wait() + if task.streaming_output_flag: + for output in task.streaming_output_list: + request.result.set_output_streaming(output) + task.streaming_output_list = [] async def _handle_parallel_process(self, tasks: ParallelProcess, @@ -172,7 +171,7 @@ class ScaffoldingLlm: self.main_loop_thread.start() def generate_async(self, prompt: str) -> ScaffoldingResult: - result = ScaffoldingResult(self.streaming_event) + result = ScaffoldingResult() async def put_request(): try: diff --git a/tensorrt_llm/scaffolding/task.py b/tensorrt_llm/scaffolding/task.py index eb3fe4bec1..7ee8397193 100644 --- a/tensorrt_llm/scaffolding/task.py +++ b/tensorrt_llm/scaffolding/task.py @@ -5,20 +5,36 @@ from typing import Any, Dict, List, Optional, Union import torch -from tensorrt_llm.executor.result import GenerationResult, TokenLogprobs +from tensorrt_llm.executor.result import TokenLogprobs + +from .result import ScaffoldingOutput @dataclass class Task: - # Reserve for custom input params. - custom_input_params: Optional[dict] = None - # Scaffolding delivers the task to the Worker by worker_tag. worker_tag: str = field(default=None) + # For streaming output. + streaming_output_flag: bool = field(default=False) + streaming_output_list: list[Any] = field(default_factory=list) + + # Reserve for custom input params. + custom_input_params: Optional[dict] = None + # Reserve for custom output params. 
custom_output_params: Optional[dict] = None + @staticmethod + def create_from_prompt(prompt: str) -> "Task": + pass + + def create_scaffolding_output(self) -> ScaffoldingOutput: + pass + + def create_scaffolding_output_stream(self) -> List[ScaffoldingOutput]: + pass + class TaskStatus(Enum): SUCCESS = "success" @@ -33,7 +49,7 @@ class GenerationTask(Task): input_str: Optional[str] = None skip_tokenizer: bool = False skip_detokenizer: bool = False - streaming: bool = False + #streaming: bool = False # sampling params for openai # Ordered by official OpenAI API documentation @@ -63,48 +79,14 @@ class GenerationTask(Task): worker_tag: Union[str, "Controller.WorkerTag"] = None # result field - # link to TRTLLM's GenerationResult, for async update in streaming mode - _result: Optional[GenerationResult] = None + output_str: Optional[str] = None + output_tokens: Optional[List[int]] = None + # TODO: support openai API format context logits + context_logits: Optional[torch.Tensor] = None + # TODO: don't not use TokenLogprobs for general support + logprobs: Optional[TokenLogprobs] = None customized_result_fields: Dict[str, Any] = field(default_factory=dict) - @property - def result(self) -> GenerationResult: - return self._result - - @result.setter - def result(self, result: GenerationResult) -> None: - self._result = result - - @property - def outputs(self) -> Optional[List[dict]]: - return self._result.outputs if self._result else None - - @property - def output_tokens(self) -> List[int]: - return self._result.outputs[0].token_ids if self._result else None - - @property - def output_str(self) -> Optional[str]: - return self._result.outputs[0].text if self._result else None - - @output_str.setter - def output_str(self, output) -> Optional[str]: - assert self.result - self._result.outputs[0].text = output - - @property - def cumulative_logprob(self) -> Optional[float]: - return self._result.outputs[ - 0].cumulative_logprob if self._result else None - - @property - def logprobs(self) -> Optional[TokenLogprobs]: - return self._result.outputs[0].logprobs if self._result else None - - @property - def context_logits(self) -> Optional[torch.Tensor]: - return self._result.context_logits if self._result else None - @staticmethod def create_from_prompt(prompt: str) -> "GenerationTask": task = GenerationTask() @@ -113,8 +95,8 @@ class GenerationTask(Task): task.skip_detokenizer = False return task - def create_scaffolding_output(self) -> GenerationResult: - return self._result + def create_scaffolding_output(self) -> ScaffoldingOutput: + return ScaffoldingOutput(self.output_str, self.output_tokens) @dataclass @@ -148,3 +130,21 @@ class RewardTask(Task): # input field input_tokens: Optional[List[int]] = field(default=None) input_str: Optional[str] = field(default=None) + + +@dataclass +class StreamGenerationTask(GenerationTask): + # input field + # if the flag is set to True, the worker will cancel the generation work + cancel_flag: Optional[bool] = field(default=False) + # the task will be returned to the controller with at least new streaming_step tokens + # if the streaming_step is set to 0, + # the task will be returned to the controller immediately with + # new tokens that have already been generated. 
+ streaming_step: Optional[int] = field(default=1) + + #result field + # worker set this field and identify the same task by this field + request_handle: Any = field(default=None) + # worker set this field to True when the generation is finished + end_flag: bool = field(default=False) diff --git a/tensorrt_llm/scaffolding/worker.py b/tensorrt_llm/scaffolding/worker.py index 95d9d23995..95a56f57fd 100644 --- a/tensorrt_llm/scaffolding/worker.py +++ b/tensorrt_llm/scaffolding/worker.py @@ -1,4 +1,5 @@ import asyncio +import copy from abc import ABC from typing import Callable, Optional @@ -6,10 +7,11 @@ import openai from transformers import AutoTokenizer from tensorrt_llm import LLM -from tensorrt_llm.executor import GenerationExecutor +from tensorrt_llm.executor import GenerationExecutor, GenerationResult from tensorrt_llm.llmapi.llm_args import KvCacheConfig, SchedulerConfig from tensorrt_llm.sampling_params import SamplingParams +from .result import ScaffoldingOutput from .task import GenerationTask, StreamGenerationTask, Task, TaskStatus ExecutorCls = GenerationExecutor @@ -96,7 +98,7 @@ class OpenaiWorker(Worker): def fill_generation_task_with_response(self, task: GenerationTask, response: openai.Completion): task.output_str = response.choices[0].text - task.logprobs = response.choices[0].logprobs + task.output_tokens = response.choices[0].token_ids async def generation_handler(self, task: GenerationTask) -> TaskStatus: params = self.convert_task_params(task) @@ -190,55 +192,70 @@ class TRTLLMWorker(Worker): logprobs=task.num_logprobs) return sampling_params + async def streaming_generate_helper(self, generate_result, step_at_least, + streaming_output_list): + step = 0 + while not generate_result._done: + async_task = asyncio.create_task(generate_result._aresult_step()) + if step_at_least and step >= step_at_least and not async_task.done( + ): + async_task.cancel() + break + await async_task + step += 1 + # do not put the last token to the streaming_output_list + if streaming_output_list is not None and not generate_result._done: + streaming_output_list.append( + ScaffoldingOutput( + generate_result.outputs[0].text, + copy.deepcopy(generate_result.outputs[0].token_ids))) + + def fill_task_with_result(self, task: GenerationTask, + result: GenerationResult): + task.output_str = result.outputs[0].text + task.output_tokens = result.outputs[0].token_ids + task.context_logits = result.context_logits + task.logprobs = result.outputs[0].logprobs + async def generation_handler(self, task: GenerationTask) -> TaskStatus: sampling_params = self.convert_task_params(task) - # If the task is streaming, we will return result directly for - # async iteration outside. Otherwise, we will wait. 
- if task.streaming: + if task.streaming_output_flag: result = self.llm.generate_async(task.input_str, sampling_params=sampling_params, streaming=True) + await self.streaming_generate_helper(result, None, + task.streaming_output_list) else: result = await self.llm.generate_async( task.input_str, sampling_params=sampling_params) - task.result = result + #task.result = result + self.fill_task_with_result(task, result) # TODO: error handle return TaskStatus.SUCCESS async def stream_generation_handler( self, task: StreamGenerationTask) -> TaskStatus: - - async def get_step_or_more_tokens(task: StreamGenerationTask): - if task.cancel_flag: - task.end_flag = True - task.request_handle.abort() - return TaskStatus.SUCCESS - - for _ in range(task.streaming_step): - await task.request_handle._aresult_step() - if task.request_handle._done: - break - - while not task.request_handle._done: - async_task = asyncio.create_task( - task.request_handle._aresult_step()) - if not async_task.done(): - async_task.cancel() - break - - if task.request_handle._done: - task.end_flag = True - - if getattr(task, 'end_flag', False): - return TaskStatus.SUCCESS + sampling_params = self.convert_task_params(task) if task.request_handle is None: - sampling_params = self.convert_task_params(task) task.request_handle = self.llm.generate_async( task.input_str, sampling_params=sampling_params, streaming=True) - task._result = task.request_handle - await get_step_or_more_tokens(task) + + if task.cancel_flag: + task.end_flag = True + task.request_handle.abort() + return TaskStatus.SUCCESS + + await self.streaming_generate_helper( + task.request_handle, task.streaming_step, + task.streaming_output_queue if task.streaming_output_flag else None) + + self.fill_task_with_result(task, task.request_handle) + + if task.request_handle._done: + task.end_flag = True + return TaskStatus.SUCCESS def shutdown(self): if self.own_llm: diff --git a/tests/unittest/scaffolding/test_bench.py b/tests/unittest/scaffolding/test_bench.py index a65584d4c4..4350b46507 100644 --- a/tests/unittest/scaffolding/test_bench.py +++ b/tests/unittest/scaffolding/test_bench.py @@ -56,6 +56,6 @@ def test_scaffolding_benchmark(): assert len(results) == requests_num assert len(requests_execution_time) == requests_num - assert results[0].cur_output == OUTPUT_STR + assert results[0].outputs[0].text == OUTPUT_STR assert results[0].task_collections[ "bench_dummy_collection"].output_len == len(OUTPUT_STR) diff --git a/tests/unittest/scaffolding/test_scaffolding.py b/tests/unittest/scaffolding/test_scaffolding.py index b736ea6425..d6d3229365 100644 --- a/tests/unittest/scaffolding/test_scaffolding.py +++ b/tests/unittest/scaffolding/test_scaffolding.py @@ -49,8 +49,8 @@ def test_unbatched_scaffolding_sync(default_prompt, deepseek_distill_7b_path): scaffolding_llm = create_scaffolding_llm_with_native_generation_controller( deepseek_distill_7b_path) result = scaffolding_llm.generate(default_prompt) - assert isinstance(result.output.output_str, str) and len( - result.output.output_str) > 0, "Output should be a non-empty string" + assert isinstance(result.outputs[0].text, str) and len( + result.outputs[0].text) > 0, "Output should be a non-empty string" scaffolding_llm.shutdown(shutdown_workers=True) @@ -62,8 +62,8 @@ def test_batched_scaffolding_sync(default_prompt, deepseek_distill_7b_path): results = scaffolding_llm.generate(prompts) assert len(results) == batch_size for result in results: - assert isinstance(result.output.output_str, str) and len( - 
result.output.output_str) > 0, "Output should be a non-empty string" + assert isinstance(result.outputs[0].text, str) and len( + result.outputs[0].text) > 0, "Output should be a non-empty string" scaffolding_llm.shutdown(shutdown_workers=True) @@ -74,8 +74,8 @@ def test_async_scaffolding_generation(default_prompt, deepseek_distill_7b_path): deepseek_distill_7b_path) future = scaffolding_llm.generate_async(default_prompt) result = await future.aresult() - assert isinstance(result.output.output_str, str) and len( - result.output.output_str) > 0, "Output should be a non-empty string" + assert isinstance(result.outputs[0].text, str) and len( + result.outputs[0].text) > 0, "Output should be a non-empty string" scaffolding_llm.shutdown(shutdown_workers=True) import asyncio @@ -86,6 +86,6 @@ def test_majority_vote(default_prompt, deepseek_distill_7b_path): scaffolding_llm = create_scaffolding_llm_with_majority_vote_controller( deepseek_distill_7b_path, samples_num=3) result = scaffolding_llm.generate(default_prompt) - assert isinstance(result.output.output_str, str) and len( - result.output.output_str) > 0, "Output should be a non-empty string" + assert isinstance(result.outputs[0].text, str) and len( + result.outputs[0].text) > 0, "Output should be a non-empty string" scaffolding_llm.shutdown(shutdown_workers=True)
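A short usage sketch of the refactored result API, mirroring `examples/scaffolding/run_basic_generation.py` and `token_budget_majority_vote.py`: `ScaffoldingResult` now carries a list of `ScaffoldingOutput` objects, so callers read `result.outputs[0].text` / `result.outputs[0].token_ids` instead of the old `result.output.output_str`. The model path is a placeholder, and the `streaming=True` constructor flag is an assumption standing in for whatever sets the controller's `self.streaming` attribute seen in `controller.py`.

```python
import asyncio

from tensorrt_llm.scaffolding import (NativeGenerationController,
                                      ScaffoldingLlm, TRTLLMWorker)

# Placeholder model path; worker setup follows the existing examples.
worker = TRTLLMWorker.init_with_new_llm("/path/to/model",
                                        backend="pytorch",
                                        max_batch_size=32,
                                        max_num_tokens=4096)

# `streaming=True` is assumed to back the controller's `self.streaming` flag,
# which this patch maps onto `task.streaming_output_flag`.
controller = NativeGenerationController(sampling_params={"max_tokens": 1024},
                                        streaming=True)
llm = ScaffoldingLlm(
    controller,
    {NativeGenerationController.WorkerTag.GENERATION: worker},
)

# Synchronous path: read the new outputs[0] fields.
result = llm.generate("What is 2 + 2?")
print(result.outputs[0].text)


# Streaming path: iterate the ScaffoldingResult itself; each step refreshes
# outputs[0] with the latest ScaffoldingOutput pushed by the worker via
# set_output_streaming.
async def stream(prompt: str):
    async for partial in llm.generate_async(prompt):
        print(len(partial.outputs[0].token_ids), partial.outputs[0].text)


# The coroutine must run on the ScaffoldingLlm event loop (llm.loop).
asyncio.run_coroutine_threadsafe(stream("Count from 1 to 10."),
                                 llm.loop).result()

llm.shutdown(shutdown_workers=True)
```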