TensorRT-LLMs/tensorrt_llm/scaffolding/scaffolding_llm.py
WeiHaocheng 54d28718c7
feat: support benchmark on scaffolding (#3328) (#4286)
Signed-off-by: Fred Wei <20514172+WeiHaocheng@users.noreply.github.com>
2025-05-16 12:28:49 +08:00


import asyncio
import threading
import traceback
from collections import deque
from dataclasses import dataclass
from typing import Any, Generator, List, Mapping, Union

from .controller import Controller, ParallelProcess, ScaffoldingOutput
from .worker import Worker


@dataclass(frozen=True)
class ScaffoldingRequest:
    prompt: str
    kwargs: Mapping[str, Any]
    controller: Controller
    result: "ScaffoldingResult"


class ScaffoldingResult:
    """Asynchronous handle for the output of a single scaffolding request."""

    def __init__(self):
        self._done = False
        self.aqueue = asyncio.Queue()
        self.output = None
        self.task_collections = None

    def set_output(self, output: ScaffoldingOutput):
        self.aqueue.put_nowait(output)

    def set_task_collections(
            self, task_collections: Mapping[str, "TaskCollection"]):
        self.task_collections = task_collections

    async def aresult_step(self):
        # TODO: error handling or raise exception?
        self.output = await self.aqueue.get()
        if self.output is None:
            raise Exception("ScaffoldingLlm execution failed")
        self._done = True

    def result(self) -> "ScaffoldingResult":
        if not self._done:
            loop = asyncio.get_event_loop()
            asyncio.run_coroutine_threadsafe(self.aresult(), loop).result()
        return self

    async def aresult(self) -> "ScaffoldingResult":
        while not self._done:
            await self.aresult_step()
        return self

    def __await__(self):
        return self.aresult().__await__()


class ScaffoldingLlm:
    """Drives Controller-generated tasks through the mapped Workers on an
    asyncio event loop, running up to `max_parallel_requests` requests
    concurrently."""

    def __init__(
        self,
        prototype_controller: Controller,
        workers: Mapping[
            str, Worker],  # map of a Controller role (worker tag) to a Worker instance
    ):
        self.prototype_controller = prototype_controller
        self.workers = workers

        self.loop = self._get_loop()
        asyncio.set_event_loop(self.loop)
        self.task_queue = asyncio.Queue()
        self.main_loop_stop_event = asyncio.Event()
        self.shutdown_event = asyncio.Event()
        if self.own_loop:
            self._run_main_loop_thread()
        else:
            self._run_main_loop_coroutine()

        # For the top-level scheduler
        self.running_req_count = 0
        self.max_parallel_requests = 64
        self.pending_queue = deque()

        self.output_task_collection = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.shutdown()

    def _get_loop(self):
        # Reuse the caller's running loop if there is one; otherwise create
        # and own a new loop that runs in a dedicated thread.
        try:
            self.own_loop = False
            return asyncio.get_running_loop()
        except RuntimeError:
            self.own_loop = True
            return asyncio.new_event_loop()

    async def _main_loop_async_func(self):

        async def handle_controller_generator(gen: Generator):
            for obj in gen:
                if isinstance(obj, ParallelProcess):
                    await handle_parallel_process(obj)
                else:
                    task_list = obj
                    async_tasks = []
                    for task in task_list:
                        task_worker_tag = task.worker_tag
                        assert task_worker_tag in self.workers.keys()
                        worker = self.workers[task_worker_tag]
                        async_tasks.append(
                            asyncio.create_task(worker.run_task(task)))
                    await asyncio.gather(*async_tasks)

        async def handle_parallel_process(request: ParallelProcess):
            async_tasks = []
            for sub_gen in request.sub_gens:
                async_task = asyncio.create_task(
                    handle_controller_generator(sub_gen))
                async_tasks.append(async_task)
            await asyncio.gather(*async_tasks)

        async def handle_single_request(request: ScaffoldingRequest):

            # Wrap into a generator with no return value; the controller's
            # ScaffoldingOutput is delivered through the request's result.
            def controller_generator_wrapper(request: ScaffoldingRequest):
                scaffolding_output = yield from request.controller.generate(
                    request.prompt, **request.kwargs)
                if self.output_task_collection:
                    request.result.set_task_collections(
                        request.controller.task_collections)
                request.result.set_output(scaffolding_output)

            try:
                gen = controller_generator_wrapper(request)
                await handle_controller_generator(gen)
            except Exception as e:
                # Catch the exception and set the output so the user thread
                # does not hang waiting for a result.
                print('ScaffoldingLlm handle request exception:', str(e))
                traceback.print_exc()
                request.result.set_output(None)
                raise e
            finally:
                self.running_req_count -= 1
                maybe_schedule()

        def schedule_request(request: ScaffoldingRequest):
            asyncio.create_task(handle_single_request(request))
            self.running_req_count += 1

        def maybe_schedule(request: ScaffoldingRequest = None):
            if self.shutdown_event.is_set():
                return

            if request is not None:
                self.pending_queue.append(request)
            while self.running_req_count < self.max_parallel_requests and len(
                    self.pending_queue) > 0:
                first_request = self.pending_queue.popleft()
                schedule_request(first_request)

        async def handle_event():
            while True:
                item = await self.task_queue.get()

                if item is None:
                    return
                elif isinstance(item, ScaffoldingRequest):
                    maybe_schedule(item)
                else:
                    raise ValueError(
                        f'type of task_queue item ({type(item)}) is not supported'
                    )

        handle_event_task = asyncio.create_task(handle_event())
        await handle_event_task

        self.main_loop_stop_event.set()

    def _run_main_loop_coroutine(self):
        asyncio.run_coroutine_threadsafe(self._main_loop_async_func(),
                                         self.loop)

    def _run_main_loop_thread(self):

        def main_loop_thread():
            self.loop.run_until_complete(self._main_loop_async_func())

        self.main_loop_thread = threading.Thread(target=main_loop_thread)
        self.main_loop_thread.start()

    def generate_async(self, prompt: str) -> ScaffoldingResult:
        result = ScaffoldingResult()

        async def put_request():
            request = ScaffoldingRequest(
                prompt=prompt,
                kwargs={},
                result=result,
                controller=self.prototype_controller.clone())

            await self.task_queue.put(request)

        asyncio.run_coroutine_threadsafe(put_request(), self.loop)

        return result

    def generate(
        self, prompts: Union[str, List[str]]
    ) -> Union[ScaffoldingResult, List[ScaffoldingResult]]:
        unbatched = not isinstance(prompts, list)
        batched_prompts = [prompts] if unbatched else prompts

        scaffolding_results = []
        for prompt in batched_prompts:
            scaffolding_results.append(self.generate_async(prompt))

        for scaffolding_result in scaffolding_results:
            scaffolding_result.result()

        return scaffolding_results[0] if unbatched else scaffolding_results

    def enable_output_task_collection(self):
        self.output_task_collection = True

    def shutdown(self, shutdown_workers=False):

        # Named with a leading underscore so the helper does not shadow the
        # `shutdown_workers` flag.
        def _shutdown_workers():
            for worker in self.workers.values():
                worker.shutdown()

        # Make the main loop break out of handle_event() so its thread can exit.
        async def stop_task_on_loop():
            await self.task_queue.put(None)
            await self.main_loop_stop_event.wait()

        asyncio.run_coroutine_threadsafe(stop_task_on_loop(), self.loop)

        if self.own_loop:
            self.main_loop_thread.join()
        else:
            # If we don't own the loop, we can't ensure that "stop_task_on_loop"
            # has finished, so set the shutdown event to make sure the main loop
            # will not submit new tasks to the workers.
            self.shutdown_event.set()

        if shutdown_workers:
            _shutdown_workers()
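

# Usage sketch (illustrative only, kept as a comment so the module stays
# importable): it assumes a concrete Controller subclass `MyController` and a
# Worker instance `my_worker` registered under the worker tag that the
# controller's tasks reference ("generation" here); all three names are
# hypothetical stand-ins.
#
#     llm = ScaffoldingLlm(MyController(), workers={"generation": my_worker})
#     result = llm.generate("What is 1 + 1?")  # blocks until the request finishes
#     print(result.output)                     # the ScaffoldingOutput produced by the controller
#     llm.shutdown(shutdown_workers=True)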