[None][feat] Refactor scaffolding streaming feature and fix openai wo… (#8622)

Signed-off-by: Fred Wei <20514172+WeiHaocheng@users.noreply.github.com>
This commit is contained in:
WeiHaocheng 2025-10-30 16:02:40 +08:00 committed by GitHub
parent a4f75399b9
commit cc286687c4
18 changed files with 164 additions and 337 deletions

View File

@ -1,10 +0,0 @@
This example shows how to use the `StreamGenerationTask` and `stream_generation_handler` to enable efficient streaming-based generation workflows.
How to run the example?
```bash
python stream_generation_run.py --model_dir <path_to_generation_model>
```
See more detail on [tensorrt_llm/scaffolding/contrib/AsyncGeneration/README.md](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/scaffolding/contrib/AsyncGeneration/README.md).

View File

@ -1,104 +0,0 @@
import argparse
from stream_generation_controller import NativeStreamGenerationController
from tensorrt_llm.scaffolding import ScaffoldingLlm, TRTLLMWorker
from tensorrt_llm.scaffolding.contrib import (StreamGenerationTask,
stream_generation_handler)
def parse_arguments():
parser = argparse.ArgumentParser()
# e.g. DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B
parser.add_argument(
'--model_dir',
type=str,
required=True,
help="Path to the directory containing the generation model")
parser.add_argument('--run_type', type=str, default='original')
args = parser.parse_args()
return args
def test(prompts, proposer_worker):
prototype_controller = NativeStreamGenerationController(
sampling_params={"temperature": 0.9})
llm = ScaffoldingLlm(
prototype_controller,
{NativeStreamGenerationController.WorkerTag.STREAM: proposer_worker},
)
results = llm.generate(prompts)
for result in results:
print(result.output.output_str)
print(f'test main shutting down...')
llm.shutdown()
print(f'test worker shutting down...')
proposer_worker.shutdown()
print(f'test main shut down done')
def test_step(prompts, proposer_worker):
prototype_controller = NativeStreamGenerationController()
prototype_controller.set_stream_step(20)
llm = ScaffoldingLlm(
prototype_controller,
{NativeStreamGenerationController.WorkerTag.STREAM: proposer_worker},
)
results = llm.generate(prompts)
for result in results:
print(result.output.output_str)
print(f'test step main shutting down...')
llm.shutdown()
print(f'test step worker shutting down...')
proposer_worker.shutdown()
print(f'test step main shut down done')
def test_cancel(prompts, proposer_worker):
prototype_controller = NativeStreamGenerationController()
prototype_controller.set_output_threshold(200)
llm = ScaffoldingLlm(
prototype_controller,
{NativeStreamGenerationController.WorkerTag.STREAM: proposer_worker},
)
results = llm.generate(prompts)
for result in results:
print(result.output.output_str)
print(f'test cancel main shutting down...')
llm.shutdown()
print(f'test cancel worker shutting down...')
proposer_worker.shutdown()
print(f'test cancel main shut down done')
def main():
args = parse_arguments()
prompts = [
"Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\r\n\r\n",
"There exist real numbers $x$ and $y$, both greater than 1, such that $\\log_x\\left(y^x\\right)=\\log_y\\left(x^{4y}\\right)=10$. Find $xy$.",
"Find the largest possible real part of \\[(75+117i)z+\\frac{96+144i}{z}\\]where $z$ is a complex number with $|z|=4$.",
]
llm_worker = TRTLLMWorker.init_with_new_llm(
args.model_dir,
backend="pytorch",
max_batch_size=32,
max_num_tokens=4096,
)
print(f'main llm worker init done')
llm_worker.register_task_handler(StreamGenerationTask,
stream_generation_handler)
if args.run_type == 'original':
test(prompts, llm_worker)
elif args.run_type == 'step':
test_step(prompts, llm_worker)
elif args.run_type == 'cancel':
test_cancel(prompts, llm_worker)
if __name__ == "__main__":
main()

View File

@ -65,15 +65,14 @@ def main():
if args.streaming:
async def task(prompt: str):
i = 0
step = 0
async for result in llm.generate_async(prompt):
i += 1
print(">>>", i, result)
async for output in result.cur_output:
print(">>>", i, len(output.outputs[0].token_ids), "\n",
output.outputs[0].text)
print(f">>> final output {len(result.outputs[0].token_ids)}\n",
result.outputs[0].text)
step += 1
tokens_num = len(
result.outputs[0].token_ids
) if result.outputs[0].token_ids is not None else 0
print(">>>", step, tokens_num, "\n", result.outputs[0].text)
print(f">>> final output {tokens_num}\n", result.outputs[0].text)
# Need to provide LLM's event loop to get results in the middle of the whole process.
asyncio.run_coroutine_threadsafe(task(prompts[0]), llm.loop).result()

View File

@ -4,8 +4,8 @@ import asyncio
from openai import AsyncOpenAI
from tensorrt_llm.scaffolding import OpenaiWorker, ScaffoldingLlm
from tensorrt_llm.scaffolding.contrib import (ChatTask, MCPController,
MCPWorker, chat_handler)
from tensorrt_llm.scaffolding.contrib.mcp import (ChatTask, MCPController,
MCPWorker, chat_handler)
def parse_arguments():
@ -28,7 +28,7 @@ def parse_arguments():
from openai import AsyncOpenAI
from tensorrt_llm.scaffolding import OpenaiWorker, ScaffoldingLlm
from tensorrt_llm.scaffolding.contrib import MCPController, MCPWorker
from tensorrt_llm.scaffolding.contrib.mcp import MCPController, MCPWorker
async def main():
@ -41,7 +41,7 @@ async def main():
]
API_KEY = args.API_KEY
urls = [
"http://0.0.0.0:8080/sse", "http://0.0.0.0:8081/sse",
#"http://0.0.0.0:8080/sse", "http://0.0.0.0:8081/sse",
"http://0.0.0.0:8082/sse"
]
print(f"API_KEY {API_KEY}")
@ -61,7 +61,7 @@ async def main():
future = llm.generate_async(prompts[0])
result = await future.aresult()
print(f"\nresult is {result.output.output_str}\n")
print(f"\nresult is {result.outputs[0].text}\n")
print(f'main shutting down...')
llm.shutdown()

View File

@ -53,14 +53,12 @@ def test_async(prompt, proposer_worker):
prototype_controller,
{NativeGenerationController.WorkerTag.GENERATION: proposer_worker},
)
i = 0
step = 0
async for result in llm.generate_async(prompt):
i += 1
print(">>>", i, result)
async for output in result.cur_output:
print(">>>", i, len(output.outputs[0].token_ids), "\n",
output.outputs[0].text)
step += 1
print(">>>", step, len(result.outputs[0].token_ids), "\n",
result.outputs[0].text)
print(f">>> final output {len(result.outputs[0].token_ids)}\n",
result.outputs[0].text)

View File

@ -104,7 +104,7 @@ def main():
prompt = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\r\n\r\n"
result = llm.generate(prompt)
extracted_answer = extract_answer_from_boxed(result.output.output_str)
extracted_answer = extract_answer_from_boxed(result.outputs[0].text)
print(f'extracted_answer={extracted_answer}')
llm.shutdown(shutdown_workers=True)

View File

@ -25,6 +25,7 @@ __all__ = [
"GenerationTask",
"StreamGenerationTask",
"RewardTask",
"StreamGenerationTask",
"Worker",
"OpenaiWorker",
"TRTOpenaiWorker",

View File

@ -1,46 +0,0 @@
## Overview
`StreamGenerationTask` is an extension of `GenerationTask` designed for token-level streaming generation in asynchronous LLM workflows using TensorRT-LLM. It enables the controller to receive partial results during generation, which is critical for real-time or latency-sensitive applications such as chatbots, speech generation, or UI-interactive systems.
---
## Features
- ✅ Supports **streamed token delivery** in configurable steps (e.g., `streaming_step=1`).
- ✅ Supports **cancellation** of generation using a flag (`cancel_flag=True`).
- ✅ Tracks **stream completion status** (`end_flag=True` when done).
- ✅ Integrated with a streaming-capable LLM interface (`generate_async`).
---
## Fields in `StreamGenerationTask`
| Field | Description |
|-------|-------------|
| `cancel_flag` | If `True`, the generation will be cancelled on the worker side. |
| `streaming_step` | Number of new tokens required before returning control to the controller. If set to `0`, the task is returned immediately if any new tokens are available. |
| `request_handle` | Internal handle for the streaming generation (used by the worker). |
| `end_flag` | Indicates whether generation is finished. |
| `output_str` / `output_tokens` / `logprobs` | Outputs after each generation step. |
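
For concreteness, here is a minimal sketch of how controller-side code might populate and read these fields. The prompt and the `streaming_step` value are illustrative only, and the import path follows the example script that ships alongside this README.

```python
from tensorrt_llm.scaffolding.contrib import StreamGenerationTask

# Build a streaming task directly from its dataclass fields (illustrative values).
task = StreamGenerationTask(
    input_str="What is 1 + 1?",  # prompt field inherited from GenerationTask
    streaming_step=20,           # hand control back roughly every 20 new tokens
)

# After each round trip through the worker, the controller can inspect:
#   task.output_tokens / task.output_str -> tokens and text generated so far
#   task.end_flag                        -> True once generation has finished
# and it can request cancellation before the next round:
task.cancel_flag = True
```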
---
## Usage in Controller/Worker
The Controller can utilize `StreamGenerationTask` to enable efficient streaming-based generation workflows:
- It sends tasks to the worker, which returns them when the number of newly generated tokens reaches the specified `streaming_step`.
- It can cancel long-running tasks by setting `task.cancel_flag = True` when the number of generated tokens exceeds a predefined threshold.
To support this behavior on the worker side, we provide `stream_generation_handler`, which you need to register with the worker in your project. The handler processes `StreamGenerationTask` instances step by step and updates fields such as `output_tokens` and `output_str`.
This design allows the controller and worker to coordinate generation in a token-efficient and responsive manner, ideal for real-time applications.
You can see more details in `stream_generation_controller.py` and `stream_generation_task.py` under `examples/scaffolding/contrib/AsyncGeneration`.
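Putting the pieces together, a rough sketch of this coordination might look like the following. The `ThresholdStreamController` class, the `"stream_generation"` tag, the model path placeholder, and the 20-token step / 200-token cancel threshold are all illustrative, and the `Controller` base-class import is an assumption; `stream_generation_controller.py` contains the real implementation.

```python
from enum import Enum

from tensorrt_llm.scaffolding import Controller, ScaffoldingLlm, TRTLLMWorker
from tensorrt_llm.scaffolding.contrib import (StreamGenerationTask,
                                              stream_generation_handler)


class ThresholdStreamController(Controller):
    """Streams tokens step by step and cancels once the output grows too long."""

    class WorkerTag(Enum):
        STREAM = "stream_generation"

    def process(self, tasks, **kwargs):
        task = StreamGenerationTask(input_str=tasks[0].input_str,
                                    streaming_step=20,
                                    worker_tag=self.WorkerTag.STREAM)
        while not task.end_flag:
            yield [task]  # the worker returns once >= streaming_step new tokens exist
            if task.output_tokens and len(task.output_tokens) > 200:
                task.cancel_flag = True  # the worker aborts the request on the next yield
        # The real controller also propagates the final output back to tasks[0];
        # that bookkeeping is omitted here for brevity.


worker = TRTLLMWorker.init_with_new_llm("<path_to_generation_model>",
                                        backend="pytorch",
                                        max_batch_size=32,
                                        max_num_tokens=4096)
# Register the handler so the worker knows how to execute StreamGenerationTask.
worker.register_task_handler(StreamGenerationTask, stream_generation_handler)

llm = ScaffoldingLlm(ThresholdStreamController(),
                     {ThresholdStreamController.WorkerTag.STREAM: worker})
results = llm.generate(["What is 1 + 1?"])
for result in results:
    print(result.output.output_str)  # result access follows the example script above

llm.shutdown()
worker.shutdown()
```

This mirrors the cancellation path exercised by `test_cancel` in `stream_generation_run.py`.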
## Notes
Remember to register the `stream_generation_handler` with the `TRTLLMWorker`.
## TODO
- Add error handling for failed `request_handle`.
- Support a retry or backoff mechanism if generation stalls.

View File

@ -1,3 +0,0 @@
from .stream_generation import StreamGenerationTask, stream_generation_handler
__all__ = ["stream_generation_handler", "StreamGenerationTask"]

View File

@ -126,18 +126,10 @@ class DynasorGenerationController(Controller):
probe_answers[-self.certainty_threshold:])
== self.certainty_threshold
and sum(probe_certain_count) == self.certainty_threshold):
tasks[0].result = probe_task.result
# If the current prompt indicates the chain-of-thought phase has ended, use one type of suffix.
if "</think>" in current_prompt:
tasks[0].output_str = (current_prompt + self.answer_suffix +
probe_answers[-1] + "}\n\\]")
return
else:
# Otherwise, use the suffix with marker to transition clearly.
tasks[0].output_str = (current_prompt +
self.answer_suffix_with_marker +
probe_answers[-1] + "}\n\\]")
return
suffix = self.answer_suffix if "</think>" in current_prompt else self.answer_suffix_with_marker
suffix += probe_answers[-1] + "}\n\\]"
current_prompt += suffix
break
# If not confident, do another round of generation
# Append the newly generated text from the proposer to the current prompt for the next iteration.
@ -145,7 +137,6 @@ class DynasorGenerationController(Controller):
# If the maximum token limit is reached without satisfying the certainty condition,
# output the accumulated prompt as the final output.
tasks[0].result = proposer_task.result
tasks[0].output_str = current_prompt
return

View File

@ -19,11 +19,9 @@ def add_param_if_not_none(params, key, candidate_values):
def combine_params_with_chat_task(worker, params: dict, task: ChatTask):
params["messages"] = task.messages
add_param_if_not_none(params, "max_tokens",
[task.max_tokens, worker.max_tokens])
add_param_if_not_none(params, "temperature",
[task.temperature, worker.temperature])
add_param_if_not_none(params, "top_p", [task.top_p, worker.top_p])
add_param_if_not_none(params, "max_tokens", [task.max_tokens])
add_param_if_not_none(params, "temperature", [task.temperature])
add_param_if_not_none(params, "top_p", [task.top_p])
add_param_if_not_none(params, "tools", [task.tools])

View File

@ -67,7 +67,7 @@ class NativeGenerationController(Controller):
for key, value in self.sampling_params.items():
if getattr(task, key) is None:
setattr(task, key, value)
task.streaming = self.streaming
task.streaming_output_flag = self.streaming
yield tasks

View File

@ -1,80 +1,67 @@
import asyncio
from typing import Mapping, Optional
from dataclasses import dataclass
from typing import Any, List, Mapping, Optional, Union
from tensorrt_llm.executor.result import GenerationResult
@dataclass
class ScaffoldingOutput:
text: str
token_ids: List[int]
class ScaffoldingResult:
def __init__(self, streaming_event: Optional[asyncio.Event] = None):
def __init__(self):
super().__init__()
self.aqueue = asyncio.Queue()
self.cur_output: GenerationResult = None
#self.cur_output: GenerationResult = None
self.outputs = []
# only support one output for now, so use an empty obj to init
self.outputs.append(ScaffoldingOutput("", []))
self._done = False
self.task_collections = None
self.streaming_event = streaming_event
def set_output(self, output: GenerationResult):
def set_output(self, output: Union[ScaffoldingOutput, Any]):
if isinstance(output, ScaffoldingOutput):
self.set_output_streaming(output)
# terminate
self.set_output_streaming(None)
def set_output_streaming(self, output: Union[ScaffoldingOutput, Any]):
self.aqueue.put_nowait(output)
self._done = True
async def set_output_async(self, output: GenerationResult):
await self.aqueue.put(output)
def set_task_collections(self, task_collections: Mapping[str,
"TaskCollection"]):
self.task_collections = task_collections
@property
def outputs(self):
return self.cur_output.outputs if self.cur_output else None
@property
def finished(self) -> bool:
return self.cur_output is not None and self.cur_output.finished
async def _aresult_step(self):
# TODO: error handling or raise exception?
response = await self.aqueue.get()
if response is None:
raise Exception("ScaffoldingLlm execution failed")
self._handle_response(response)
obj = await self.aqueue.get()
if obj is None:
self._done = True
else: # obj is ScaffoldingOutput
self.outputs[0] = obj
def result(self, timeout: Optional[float] = None) -> "ScaffoldingResult":
if not self.finished:
if not self._done:
loop = asyncio.get_event_loop()
asyncio.run_coroutine_threadsafe(self.aresult(), loop).result()
return self
async def aresult(self) -> "ScaffoldingResult":
while not self.finished:
while not self._done:
await self._aresult_step()
return self
def __await__(self):
return self.aresult().__await__()
def __iter__(self):
return self
def __next__(self):
if self._done and self.finished:
raise StopIteration
self._result_step()
return self
def __aiter__(self):
return self
async def __anext__(self):
if self.finished:
self.streaming_event.set() if self.streaming_event else None
if self._done and self.finished:
if self._done:
raise StopAsyncIteration
await self._aresult_step()
return self
def _handle_response(self, response: GenerationResult):
self.cur_output = response

View File

@ -34,7 +34,6 @@ class ScaffoldingLlm:
self.task_queue = asyncio.Queue()
self.main_loop_stop_event = asyncio.Event()
self.shutdown_event = asyncio.Event()
self.streaming_event = asyncio.Event()
if self.own_loop:
self._run_main_loop_thread()
else:
@ -82,10 +81,10 @@ class ScaffoldingLlm:
]
await asyncio.gather(*async_tasks)
for task in tasks:
if getattr(task, 'streaming', False):
await request.result.set_output_async(task.result)
self.streaming_event.clear()
await self.streaming_event.wait()
if task.streaming_output_flag:
for output in task.streaming_output_list:
request.result.set_output_streaming(output)
task.streaming_output_list = []
async def _handle_parallel_process(self,
tasks: ParallelProcess,
@ -172,7 +171,7 @@ class ScaffoldingLlm:
self.main_loop_thread.start()
def generate_async(self, prompt: str) -> ScaffoldingResult:
result = ScaffoldingResult(self.streaming_event)
result = ScaffoldingResult()
async def put_request():
try:

View File

@ -5,20 +5,36 @@ from typing import Any, Dict, List, Optional, Union
import torch
from tensorrt_llm.executor.result import GenerationResult, TokenLogprobs
from tensorrt_llm.executor.result import TokenLogprobs
from .result import ScaffoldingOutput
@dataclass
class Task:
# Reserve for custom input params.
custom_input_params: Optional[dict] = None
# Scaffolding delivers the task to the Worker by worker_tag.
worker_tag: str = field(default=None)
# For streaming output.
streaming_output_flag: bool = field(default=False)
streaming_output_list: list[Any] = field(default_factory=list)
# Reserve for custom input params.
custom_input_params: Optional[dict] = None
# Reserve for custom output params.
custom_output_params: Optional[dict] = None
@staticmethod
def create_from_prompt(prompt: str) -> "Task":
pass
def create_scaffolding_output(self) -> ScaffoldingOutput:
pass
def create_scaffolding_output_stream(self) -> List[ScaffoldingOutput]:
pass
class TaskStatus(Enum):
SUCCESS = "success"
@ -33,7 +49,7 @@ class GenerationTask(Task):
input_str: Optional[str] = None
skip_tokenizer: bool = False
skip_detokenizer: bool = False
streaming: bool = False
#streaming: bool = False
# sampling params for openai
# Ordered by official OpenAI API documentation
@ -63,48 +79,14 @@ class GenerationTask(Task):
worker_tag: Union[str, "Controller.WorkerTag"] = None
# result field
# link to TRTLLM's GenerationResult, for async update in streaming mode
_result: Optional[GenerationResult] = None
output_str: Optional[str] = None
output_tokens: Optional[List[int]] = None
# TODO: support openai API format context logits
context_logits: Optional[torch.Tensor] = None
# TODO: do not use TokenLogprobs, for general support
logprobs: Optional[TokenLogprobs] = None
customized_result_fields: Dict[str, Any] = field(default_factory=dict)
@property
def result(self) -> GenerationResult:
return self._result
@result.setter
def result(self, result: GenerationResult) -> None:
self._result = result
@property
def outputs(self) -> Optional[List[dict]]:
return self._result.outputs if self._result else None
@property
def output_tokens(self) -> List[int]:
return self._result.outputs[0].token_ids if self._result else None
@property
def output_str(self) -> Optional[str]:
return self._result.outputs[0].text if self._result else None
@output_str.setter
def output_str(self, output) -> Optional[str]:
assert self.result
self._result.outputs[0].text = output
@property
def cumulative_logprob(self) -> Optional[float]:
return self._result.outputs[
0].cumulative_logprob if self._result else None
@property
def logprobs(self) -> Optional[TokenLogprobs]:
return self._result.outputs[0].logprobs if self._result else None
@property
def context_logits(self) -> Optional[torch.Tensor]:
return self._result.context_logits if self._result else None
@staticmethod
def create_from_prompt(prompt: str) -> "GenerationTask":
task = GenerationTask()
@ -113,8 +95,8 @@ class GenerationTask(Task):
task.skip_detokenizer = False
return task
def create_scaffolding_output(self) -> GenerationResult:
return self._result
def create_scaffolding_output(self) -> ScaffoldingOutput:
return ScaffoldingOutput(self.output_str, self.output_tokens)
@dataclass
@ -148,3 +130,21 @@ class RewardTask(Task):
# input field
input_tokens: Optional[List[int]] = field(default=None)
input_str: Optional[str] = field(default=None)
@dataclass
class StreamGenerationTask(GenerationTask):
# input fields
# if this flag is set to True, the worker will cancel the generation
cancel_flag: Optional[bool] = field(default=False)
# the task is returned to the controller once at least streaming_step new tokens are available;
# if streaming_step is set to 0, the task is returned to the controller immediately
# with whatever new tokens have already been generated
streaming_step: Optional[int] = field(default=1)
# result fields
# the worker sets this handle and uses it to identify the same task across calls
request_handle: Any = field(default=None)
# the worker sets this field to True when the generation is finished
end_flag: bool = field(default=False)

View File

@ -1,4 +1,5 @@
import asyncio
import copy
from abc import ABC
from typing import Callable, Optional
@ -6,10 +7,11 @@ import openai
from transformers import AutoTokenizer
from tensorrt_llm import LLM
from tensorrt_llm.executor import GenerationExecutor
from tensorrt_llm.executor import GenerationExecutor, GenerationResult
from tensorrt_llm.llmapi.llm_args import KvCacheConfig, SchedulerConfig
from tensorrt_llm.sampling_params import SamplingParams
from .result import ScaffoldingOutput
from .task import GenerationTask, StreamGenerationTask, Task, TaskStatus
ExecutorCls = GenerationExecutor
@ -96,7 +98,7 @@ class OpenaiWorker(Worker):
def fill_generation_task_with_response(self, task: GenerationTask,
response: openai.Completion):
task.output_str = response.choices[0].text
task.logprobs = response.choices[0].logprobs
task.output_tokens = response.choices[0].token_ids
async def generation_handler(self, task: GenerationTask) -> TaskStatus:
params = self.convert_task_params(task)
@ -190,55 +192,70 @@ class TRTLLMWorker(Worker):
logprobs=task.num_logprobs)
return sampling_params
async def streaming_generate_helper(self, generate_result, step_at_least,
streaming_output_list):
step = 0
while not generate_result._done:
async_task = asyncio.create_task(generate_result._aresult_step())
if step_at_least and step >= step_at_least and not async_task.done(
):
async_task.cancel()
break
await async_task
step += 1
# do not put the last token to the streaming_output_list
if streaming_output_list is not None and not generate_result._done:
streaming_output_list.append(
ScaffoldingOutput(
generate_result.outputs[0].text,
copy.deepcopy(generate_result.outputs[0].token_ids)))
def fill_task_with_result(self, task: GenerationTask,
result: GenerationResult):
task.output_str = result.outputs[0].text
task.output_tokens = result.outputs[0].token_ids
task.context_logits = result.context_logits
task.logprobs = result.outputs[0].logprobs
async def generation_handler(self, task: GenerationTask) -> TaskStatus:
sampling_params = self.convert_task_params(task)
# If the task is streaming, we will return result directly for
# async iteration outside. Otherwise, we will wait.
if task.streaming:
if task.streaming_output_flag:
result = self.llm.generate_async(task.input_str,
sampling_params=sampling_params,
streaming=True)
await self.streaming_generate_helper(result, None,
task.streaming_output_list)
else:
result = await self.llm.generate_async(
task.input_str, sampling_params=sampling_params)
task.result = result
#task.result = result
self.fill_task_with_result(task, result)
# TODO: error handle
return TaskStatus.SUCCESS
async def stream_generation_handler(
self, task: StreamGenerationTask) -> TaskStatus:
async def get_step_or_more_tokens(task: StreamGenerationTask):
if task.cancel_flag:
task.end_flag = True
task.request_handle.abort()
return TaskStatus.SUCCESS
for _ in range(task.streaming_step):
await task.request_handle._aresult_step()
if task.request_handle._done:
break
while not task.request_handle._done:
async_task = asyncio.create_task(
task.request_handle._aresult_step())
if not async_task.done():
async_task.cancel()
break
if task.request_handle._done:
task.end_flag = True
if getattr(task, 'end_flag', False):
return TaskStatus.SUCCESS
sampling_params = self.convert_task_params(task)
if task.request_handle is None:
sampling_params = self.convert_task_params(task)
task.request_handle = self.llm.generate_async(
task.input_str, sampling_params=sampling_params, streaming=True)
task._result = task.request_handle
await get_step_or_more_tokens(task)
if task.cancel_flag:
task.end_flag = True
task.request_handle.abort()
return TaskStatus.SUCCESS
await self.streaming_generate_helper(
task.request_handle, task.streaming_step,
task.streaming_output_list if task.streaming_output_flag else None)
self.fill_task_with_result(task, task.request_handle)
if task.request_handle._done:
task.end_flag = True
return TaskStatus.SUCCESS
def shutdown(self):
if self.own_llm:

View File

@ -56,6 +56,6 @@ def test_scaffolding_benchmark():
assert len(results) == requests_num
assert len(requests_execution_time) == requests_num
assert results[0].cur_output == OUTPUT_STR
assert results[0].outputs[0].text == OUTPUT_STR
assert results[0].task_collections[
"bench_dummy_collection"].output_len == len(OUTPUT_STR)

View File

@ -49,8 +49,8 @@ def test_unbatched_scaffolding_sync(default_prompt, deepseek_distill_7b_path):
scaffolding_llm = create_scaffolding_llm_with_native_generation_controller(
deepseek_distill_7b_path)
result = scaffolding_llm.generate(default_prompt)
assert isinstance(result.output.output_str, str) and len(
result.output.output_str) > 0, "Output should be a non-empty string"
assert isinstance(result.outputs[0].text, str) and len(
result.outputs[0].text) > 0, "Output should be a non-empty string"
scaffolding_llm.shutdown(shutdown_workers=True)
@ -62,8 +62,8 @@ def test_batched_scaffolding_sync(default_prompt, deepseek_distill_7b_path):
results = scaffolding_llm.generate(prompts)
assert len(results) == batch_size
for result in results:
assert isinstance(result.output.output_str, str) and len(
result.output.output_str) > 0, "Output should be a non-empty string"
assert isinstance(result.outputs[0].text, str) and len(
result.outputs[0].text) > 0, "Output should be a non-empty string"
scaffolding_llm.shutdown(shutdown_workers=True)
@ -74,8 +74,8 @@ def test_async_scaffolding_generation(default_prompt, deepseek_distill_7b_path):
deepseek_distill_7b_path)
future = scaffolding_llm.generate_async(default_prompt)
result = await future.aresult()
assert isinstance(result.output.output_str, str) and len(
result.output.output_str) > 0, "Output should be a non-empty string"
assert isinstance(result.outputs[0].text, str) and len(
result.outputs[0].text) > 0, "Output should be a non-empty string"
scaffolding_llm.shutdown(shutdown_workers=True)
import asyncio
@ -86,6 +86,6 @@ def test_majority_vote(default_prompt, deepseek_distill_7b_path):
scaffolding_llm = create_scaffolding_llm_with_majority_vote_controller(
deepseek_distill_7b_path, samples_num=3)
result = scaffolding_llm.generate(default_prompt)
assert isinstance(result.output.output_str, str) and len(
result.output.output_str) > 0, "Output should be a non-empty string"
assert isinstance(result.outputs[0].text, str) and len(
result.outputs[0].text) > 0, "Output should be a non-empty string"
scaffolding_llm.shutdown(shutdown_workers=True)