TensorRT-LLMs/tests/unittest/scaffolding/test_task_collection.py
Zhenhuan Chen 992b273045
[https://nvbugs/5387375] fix(scaffolding): fix scaffolding aime test in test_e2e (#6140)
Signed-off-by: Zhenhuan Chen <chenzhh3671@gmail.com>
2025-07-18 10:34:37 +08:00

131 lines
3.5 KiB
Python

import copy
from enum import Enum
from typing import List
from tensorrt_llm.scaffolding import (Controller, ParallelProcess,
ScaffoldingLlm, Task, TaskCollection,
TaskStatus, Worker, with_task_collection)
class DummyTask(Task):
def __init__(self):
self.before_flag = False
self.after_flag = False
@staticmethod
def create_from_prompt(prompt: str) -> "DummyTask":
task = DummyTask()
return task
def create_scaffolding_output(self):
return None
def verify(self):
assert self.before_flag, "task.before_flag has not been set to True"
assert self.after_flag, "task.after_flag has not been set to True"
class DummyTaskCollection(TaskCollection):
def __init__(self):
super().__init__()
self.task_count = 0
def before_yield(self, tasks: List[Task]):
for task in tasks:
task.before_flag = True
def after_yield(self, tasks: List[Task]):
for task in tasks:
task.after_flag = True
self.task_count += 1
class DummyControllerBase(Controller):
class WorkerTag(Enum):
DUMMY = "dummy"
def __init__(self, expected_task_count: int):
super().__init__()
self.expected_task_count = expected_task_count
def generate(self, prompt: str, **kwargs):
task = DummyTask.create_from_prompt(prompt)
yield from self.process([task], **kwargs)
return task.create_scaffolding_output()
def verify(self):
assert self.task_collections[
"dummy"].task_count == self.expected_task_count, (
"task count is not as expected")
@with_task_collection("dummy", DummyTaskCollection)
# verify that we can have multiple task collections and they works separately
@with_task_collection("dummy2", DummyTaskCollection)
class DummyController(DummyControllerBase):
def process(self, tasks: List[Task], **kwargs):
for task in tasks:
task.worker_tag = self.WorkerTag.DUMMY
yield tasks
for task in tasks:
task.verify()
self.verify()
@with_task_collection("dummy", DummyTaskCollection)
class DummyParallelController(DummyControllerBase):
def __init__(self, controllers):
expected_task_count = sum(
[controller.expected_task_count for controller in controllers])
super().__init__(expected_task_count)
self.controllers = controllers
def process(self, tasks: List[Task], **kwargs):
tasks_list = [
copy.deepcopy(tasks) for _ in range(len(self.controllers))
]
kwargs_list = [kwargs for _ in range(len(self.controllers))]
yield ParallelProcess(self.controllers, tasks_list, kwargs_list)
for tasks in tasks_list:
for task in tasks:
task.verify()
tasks = tasks_list[0]
self.verify()
class DummyWorker(Worker):
async def dummy_handler(self, task: DummyTask):
return TaskStatus.SUCCESS
task_handlers = {DummyTask: dummy_handler}
def run(controller, expected_task_count):
worker = DummyWorker()
llm = ScaffoldingLlm(controller, {DummyController.WorkerTag.DUMMY: worker})
llm.generate("")
llm.shutdown()
def test_dummy_task_collection():
controller = DummyController(1)
run(controller, 1)
controller = DummyParallelController(
[DummyController(1), DummyController(1)])
run(controller, 2)