TensorRT-LLMs/tensorrt_llm/scaffolding/result.py
Zhenhuan Chen 992b273045
[https://nvbugs/5387375] fix(scaffolding): fix scaffolding aime test in test_e2e (#6140)
Signed-off-by: Zhenhuan Chen <chenzhh3671@gmail.com>
2025-07-18 10:34:37 +08:00

81 lines
2.3 KiB
Python

import asyncio
from typing import Mapping, Optional
from tensorrt_llm.executor.result import GenerationResult
class ScaffoldingResult:
    """Handle for the output(s) of one scaffolding request.

    A producer enqueues ``GenerationResult`` objects with :meth:`set_output`
    or :meth:`set_output_async`. A consumer dequeues them through
    :meth:`aresult`, :meth:`result`, or by iterating this object with
    ``for`` / ``async for``. Every dequeued item becomes ``cur_output``.
    Dequeuing ``None`` makes the consumer raise instead.
    """

    def __init__(self, streaming_event: Optional[asyncio.Event] = None):
        super().__init__()
        # FIFO of outputs posted by the producer and not yet consumed.
        self.aqueue = asyncio.Queue()
        # Latest consumed output; None until the first item is dequeued.
        self.cur_output: Optional[GenerationResult] = None
        # Set by set_output(). Iteration stops only once this is set AND the
        # current output reports finished.
        self._done = False
        self.task_collections = None
        # Optional event that __anext__ sets once the current output is
        # finished.
        self.streaming_event = streaming_event

    def set_output(self, output: GenerationResult):
        """Enqueue ``output`` without awaiting and flag the producer as done.

        Passing ``None`` signals failure: the consumer raises when it
        dequeues it.
        """
        self.aqueue.put_nowait(output)
        self._done = True

    async def set_output_async(self, output: GenerationResult):
        """Enqueue ``output``, waiting for queue space if needed.

        Unlike :meth:`set_output`, this does not flag the producer as done,
        so iteration continues past this item. It is presumably used for
        intermediate streaming outputs; confirm against callers.
        """
        await self.aqueue.put(output)

    def set_task_collections(self, task_collections: Mapping[str,
                                                             "TaskCollection"]):
        """Attach the request's task collections, keyed by name."""
        self.task_collections = task_collections

    @property
    def outputs(self):
        """``outputs`` of the current output, or None before the first one."""
        return self.cur_output.outputs if self.cur_output else None

    @property
    def finished(self) -> bool:
        """True once an output has been consumed and it reports finished."""
        return self.cur_output is not None and self.cur_output.finished

    async def _aresult_step(self):
        """Dequeue one output and make it current.

        Raises:
            Exception: if the dequeued item is ``None``, i.e. the producer
                signalled failure.
        """
        # TODO: error handling or raise exception?
        response = await self.aqueue.get()
        if response is None:
            raise Exception("ScaffoldingLlm execution failed")
        self._handle_response(response)

    @staticmethod
    def _run_sync(coro, timeout: Optional[float] = None):
        """Run ``coro`` on the event loop and block until it completes.

        NOTE(review): this assumes the loop returned by
        ``asyncio.get_event_loop()`` is running in another thread. Calling
        this from that loop's own thread, or with no running loop, would
        block.

        Raises:
            concurrent.futures.TimeoutError: if ``timeout`` seconds elapse
                first. The pending coroutine is cancelled in that case.
        """
        loop = asyncio.get_event_loop()
        future = asyncio.run_coroutine_threadsafe(coro, loop)
        try:
            return future.result(timeout)
        finally:
            # No-op once the future has completed. On timeout or interrupt it
            # cancels the abandoned coroutine, so a stale consumer cannot
            # dequeue outputs meant for a later call.
            future.cancel()

    def _result_step(self, timeout: Optional[float] = None):
        """Synchronously consume one output (blocking counterpart of
        :meth:`_aresult_step`)."""
        self._run_sync(self._aresult_step(), timeout)

    def result(self, timeout: Optional[float] = None) -> "ScaffoldingResult":
        """Block until the result is finished and return ``self``.

        Args:
            timeout: maximum seconds to wait; ``None`` waits indefinitely.
        """
        if not self.finished:
            self._run_sync(self.aresult(), timeout)
        return self

    async def aresult(self) -> "ScaffoldingResult":
        """Consume outputs until one reports finished, then return ``self``."""
        while not self.finished:
            await self._aresult_step()
        return self

    def __await__(self):
        return self.aresult().__await__()

    def __iter__(self):
        return self

    def __next__(self):
        """Consume the next output synchronously and yield ``self``."""
        if self._done and self.finished:
            raise StopIteration
        self._result_step()
        return self

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Consume the next output asynchronously and yield ``self``."""
        if self.finished:
            if self.streaming_event is not None:
                self.streaming_event.set()
            if self._done:
                raise StopAsyncIteration
        await self._aresult_step()
        return self

    def _handle_response(self, response: GenerationResult):
        """Make ``response`` the current output."""
        self.cur_output = response