mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
89 lines
2.5 KiB
Python
89 lines
2.5 KiB
Python
import asyncio
|
|
from dataclasses import dataclass
|
|
from typing import Mapping, Optional
|
|
|
|
from tensorrt_llm.executor.result import GenerationResult
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class ScaffoldingOutput:
|
|
|
|
def __init__(self):
|
|
self.output_str = None
|
|
|
|
|
|
class ScaffoldingResult:
|
|
|
|
def __init__(self, streaming_event: Optional[asyncio.Event] = None):
|
|
super().__init__()
|
|
self.aqueue = asyncio.Queue()
|
|
self.cur_output = None
|
|
self._done = False
|
|
self.task_collections = None
|
|
self.streaming_event = streaming_event
|
|
|
|
def set_output(self, output: GenerationResult):
|
|
self.aqueue.put_nowait(output)
|
|
self._done = True
|
|
|
|
async def set_output_async(self, output: GenerationResult):
|
|
await self.aqueue.put(output)
|
|
|
|
def set_task_collections(self, task_collections: Mapping[str,
|
|
"TaskCollection"]):
|
|
self.task_collections = task_collections
|
|
|
|
@property
|
|
def outputs(self):
|
|
return self.cur_output.outputs if self.cur_output else None
|
|
|
|
@property
|
|
def finished(self) -> bool:
|
|
return self.cur_output is not None and self.cur_output.finished
|
|
|
|
async def _aresult_step(self):
|
|
# TODO: error handling or raise exception?
|
|
response = await self.aqueue.get()
|
|
if response is None:
|
|
raise Exception("ScaffoldingLlm execution failed")
|
|
self._handle_response(response)
|
|
|
|
def result(self, timeout: Optional[float] = None) -> "ScaffoldingResult":
|
|
if not self.finished:
|
|
loop = asyncio.get_event_loop()
|
|
asyncio.run_coroutine_threadsafe(self.aresult(), loop).result()
|
|
return self
|
|
|
|
async def aresult(self) -> "ScaffoldingResult":
|
|
while not self.finished:
|
|
await self._aresult_step()
|
|
return self
|
|
|
|
def __await__(self):
|
|
return self.aresult().__await__()
|
|
|
|
def __iter__(self):
|
|
return self
|
|
|
|
def __next__(self):
|
|
if self._done and self.finished:
|
|
raise StopIteration
|
|
|
|
self._result_step()
|
|
return self
|
|
|
|
def __aiter__(self):
|
|
return self
|
|
|
|
async def __anext__(self):
|
|
if self.finished:
|
|
self.streaming_event.set() if self.streaming_event else None
|
|
if self._done and self.finished:
|
|
raise StopAsyncIteration
|
|
|
|
await self._aresult_step()
|
|
return self
|
|
|
|
def _handle_response(self, response: GenerationResult):
|
|
self.cur_output = response
|