TensorRT-LLMs/tensorrt_llm/scaffolding/benchmark.py
WeiHaocheng 54d28718c7
feat: support benchmark on scaffolding (#3328) (#4286)
Signed-off-by: Fred Wei <20514172+WeiHaocheng@users.noreply.github.com>
2025-05-16 12:28:49 +08:00

102 lines
3.5 KiB
Python

import asyncio
import time
from typing import List, Mapping, Tuple, Type
from pydantic import BaseModel
from tensorrt_llm.scaffolding.scaffolding_llm import (ScaffoldingLlm,
ScaffoldingResult)
from tensorrt_llm.scaffolding.task_collection import (TaskCollection,
with_task_collection)
class ScaffoldingBenchRequest(BaseModel):
    """Input record for one benchmark request."""

    # Prompt text carried by this request.
    prompt: str
async def enqueue_requests(input_queue, requests):
    """Feed every request into ``input_queue`` and terminate the stream.

    Args:
        input_queue: Queue-like object exposing an awaitable ``put``.
        requests: Iterable of request objects, forwarded in iteration order.

    After the last request a single ``None`` is enqueued as the
    end-of-stream marker.
    """
    pending = iter(requests)
    for item in pending:
        await input_queue.put(item)

    # Sentinel: signals that no further requests will arrive.
    end_of_stream = None
    await input_queue.put(end_of_stream)
async def process_request(scaffolding_llm, request, output_queue, semaphore):
    """Run one request through the scaffolding LLM and report its latency.

    Args:
        scaffolding_llm: Object exposing ``generate_async(prompt)``; the
            returned handle must provide an awaitable ``aresult()``.
        request: Object with a ``prompt`` attribute.
        output_queue: Queue-like object with an awaitable ``put``; receives
            a ``(result, elapsed_seconds)`` tuple once the request is done.
        semaphore: Async context manager bounding how many requests run
            at the same time.
    """
    async with semaphore:
        # perf_counter() is monotonic, so the measured duration cannot be
        # distorted (or go negative) when the system wall clock is adjusted,
        # which time.time() does not guarantee.
        started_at = time.perf_counter()
        result = scaffolding_llm.generate_async(request.prompt)
        await result.aresult()
        request_execution_time = time.perf_counter() - started_at
        await output_queue.put((result, request_execution_time))
async def run_scaffolding_llm(scaffolding_llm, input_queue, output_queue,
                              concurrency):
    """Dispatch queued requests to workers and close the output stream.

    Requests are read from ``input_queue`` until a ``None`` sentinel is
    seen. Each request is handed to ``process_request`` in its own task;
    ``concurrency`` sizes the semaphore that limits how many run at once.

    Args:
        scaffolding_llm: Forwarded unchanged to ``process_request``.
        input_queue: Queue-like source of requests, terminated by ``None``.
        output_queue: Queue-like sink; workers write to it and a final
            ``None`` is written here when this coroutine finishes.
        concurrency: Initial value of the shared ``asyncio.Semaphore``.

    Raises:
        BaseException: The first exception raised by any worker task, after
            all workers have finished and the sentinel has been written.
    """
    semaphore = asyncio.Semaphore(concurrency)
    # Keep every task: dropping finished tasks before gather() would lose
    # their exceptions and silently hide failed requests.
    tasks = []
    try:
        while True:
            request = await input_queue.get()
            if request is None:
                break
            tasks.append(
                asyncio.create_task(
                    process_request(scaffolding_llm, request, output_queue,
                                    semaphore)))
        # Collect failures instead of aborting on the first one so that every
        # successful result reaches the output queue before the sentinel.
        outcomes = await asyncio.gather(*tasks, return_exceptions=True)
    finally:
        # Only reached with unfinished tasks when this coroutine itself is
        # being torn down (e.g. cancelled); do not leave orphans behind.
        for task in tasks:
            if not task.done():
                task.cancel()
        # Always terminate the stream; otherwise a consumer waiting for the
        # sentinel would block forever after a failure.
        await output_queue.put(None)

    for outcome in outcomes:
        if isinstance(outcome, BaseException):
            raise outcome
def wrapper_prototype_controller_with_task_collection(scaffolding_llm,
                                                      task_collection_types):
    """Attach task collections to the LLM's prototype controller.

    For every ``name -> type`` entry in ``task_collection_types``:

    * a fresh instance of the type is stored under ``name`` in
      ``scaffolding_llm.prototype_controller.task_collections``;
    * ``with_task_collection(name, type)`` is applied to the controller
      class, chaining each application onto the previous result.

    Finally ``scaffolding_llm.enable_output_task_collection()`` is called.

    NOTE(review): the chained decorated type is not stored or returned here,
    so any lasting effect presumably comes from the decorator acting on the
    class it receives. Confirm against ``with_task_collection``.
    """
    decorated_type = type(scaffolding_llm.prototype_controller)

    for collection_name, collection_cls in task_collection_types.items():
        collection = collection_cls()
        scaffolding_llm.prototype_controller.task_collections[
            collection_name] = collection

        decorate = with_task_collection(collection_name, collection_cls)
        decorated_type = decorate(decorated_type)

    scaffolding_llm.enable_output_task_collection()
def _task_failed(task):
    """Return True if ``task`` has finished by cancellation or exception."""
    if not task.done():
        return False
    return task.cancelled() or task.exception() is not None


async def async_scaffolding_benchmark(
        scaffolding_llm: ScaffoldingLlm,
        task_collection_types: Mapping[str, Type[TaskCollection]],
        requests: List[ScaffoldingBenchRequest],
        concurrency: int) -> Tuple[List[ScaffoldingResult], List[float], float]:
    """Benchmark ``scaffolding_llm`` over ``requests`` with bounded concurrency.

    Args:
        scaffolding_llm: LLM under test; its prototype controller is first
            set up via ``wrapper_prototype_controller_with_task_collection``.
        task_collection_types: Mapping of task-collection name to type,
            passed to that setup helper.
        requests: Requests to run.
        concurrency: Concurrency limit passed to ``run_scaffolding_llm``.

    Returns:
        A tuple ``(results, requests_execution_time, total_time)``: the
        per-request results in completion order, the matching per-request
        durations in seconds, and the overall duration in seconds from just
        before the background tasks are created until the output stream ends.

    Raises:
        Exception: Whatever the enqueue or run task raised, if either failed.
    """
    wrapper_prototype_controller_with_task_collection(scaffolding_llm,
                                                      task_collection_types)

    input_queue = asyncio.Queue()
    output_queue = asyncio.Queue()
    results = []
    requests_execution_time = []

    # Monotonic clock: immune to wall-clock adjustments during the run.
    start_time = time.perf_counter()

    enqueue_task = asyncio.create_task(enqueue_requests(input_queue, requests))
    run_scaffolding_llm_task = asyncio.create_task(
        run_scaffolding_llm(scaffolding_llm, input_queue, output_queue,
                            concurrency))

    try:
        while True:
            try:
                item = await asyncio.wait_for(output_queue.get(), timeout=1.0)
            except asyncio.TimeoutError:
                # Nothing arrived within the poll interval. If a producer
                # has died, the end-of-stream sentinel may never come, so
                # stop waiting and let the awaits below surface the error
                # instead of spinning forever.
                if output_queue.empty() and (
                        run_scaffolding_llm_task.done()
                        or _task_failed(enqueue_task)):
                    break
                continue

            if item is None:
                break
            result, request_execution_time = item
            results.append(result)
            requests_execution_time.append(request_execution_time)

        total_time = time.perf_counter() - start_time

        # Awaiting (rather than .result()) re-raises task failures without
        # risking InvalidStateError on a task that is not finished yet.
        # If the enqueue task failed, the run task may never finish, so
        # only wait for it when enqueueing succeeded or it is already done.
        await enqueue_task
        await run_scaffolding_llm_task
    finally:
        # On any early exit, do not leave background tasks running.
        for task in (enqueue_task, run_scaffolding_llm_task):
            if not task.done():
                task.cancel()

    return results, requests_execution_time, total_time