TensorRT-LLM/tests/unittest/scaffolding/test_bench.py

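"""Unit test for the scaffolding benchmark helper.

Drives async_scaffolding_benchmark against a ScaffoldingLlm backed by a dummy
worker and task collection, so the benchmark plumbing can be exercised without
running a real model.
"""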
import asyncio
from typing import List

from tensorrt_llm.scaffolding import (GenerationTask,
                                      NativeGenerationController,
                                      ScaffoldingBenchRequest, ScaffoldingLlm,
                                      Task, TaskCollection, TaskStatus, Worker,
                                      async_scaffolding_benchmark)
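
# Canned answer returned by the dummy worker for every generation task.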
OUTPUT_STR = "Yes."
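

# Worker stub: fulfills generation tasks with the canned answer instead of
# calling a real model, so the benchmark completes instantly.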
class DummyWorker(Worker):

    async def dummy_generation_handler(self, task: GenerationTask):
        task.result = OUTPUT_STR
        return TaskStatus.SUCCESS

    # Route GenerationTask instances to the dummy handler.
    task_handlers = {GenerationTask: dummy_generation_handler}
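

# TaskCollection stub that records the length of the first task's result
# each time a batch of tasks comes back from the workers.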
class DummyTaskCollection(TaskCollection):

    def __init__(self):
        super().__init__()
        self.output_len = 0

    def before_yield(self, tasks: List[Task]):
        pass

    def after_yield(self, tasks: List[Task]):
        self.output_len = len(tasks[0].result)


def test_scaffolding_benchmark():
    # Register the collection under the name used in the assertions below.
    task_collection_types = {"bench_dummy_collection": DummyTaskCollection}
    prototype_controller = NativeGenerationController()
    dummy_worker = DummyWorker()
    # Map the controller's generation worker tag to the dummy worker.
    workers = {NativeGenerationController.WorkerTag.GENERATION: dummy_worker}
    scaffolding_llm = ScaffoldingLlm(prototype_controller, workers)
    requests_num = 100
    # Issue 100 identical requests with at most `concurrency` in flight.
    requests = [
        ScaffoldingBenchRequest(prompt="Is today a nice day?")
        for _ in range(requests_num)
    ]
    concurrency = 10
    results, requests_execution_time, total_time = asyncio.run(
        async_scaffolding_benchmark(scaffolding_llm, task_collection_types,
                                    requests, concurrency))
    scaffolding_llm.shutdown()
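
    # Every request should yield one result and one per-request timing entry,
    # and the dummy collection should have observed the canned output's length.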
    assert len(results) == requests_num
    assert len(requests_execution_time) == requests_num
    assert results[0].outputs[0].text == OUTPUT_STR
    assert results[0].task_collections[
        "bench_dummy_collection"].output_len == len(OUTPUT_STR)