feat: Make scaffolding Controller more generic #3408 (#3416)

Signed-off-by: fredw (generated by with_the_same_user script) <20514172+WeiHaocheng@users.noreply.github.com>
WeiHaocheng 2025-04-12 21:35:38 +08:00 committed by GitHub
parent 012fb9a1c4
commit c6081abb0e
6 changed files with 272 additions and 72 deletions

View File

@@ -1,6 +1,6 @@
from .controller import (BestOfNController, Controller, MajorityVoteController,
NativeGenerationController, NativeRewardController,
ScaffoldingOutput)
ParallelProcess, ScaffoldingOutput)
from .math_utils import (extract_answer_from_boxed, extract_answer_with_regex,
get_digit_majority_vote_result)
from .scaffolding_llm import ScaffoldingLlm
@@ -8,7 +8,7 @@ from .task import GenerationTask, RewardTask, Task, TaskStatus
from .worker import OpenaiWorker, TRTLLMWorker, TRTOpenaiWorker, Worker
__all__ = [
"ScaffoldingLlm", "ScaffoldingOutput", "Controller",
"ScaffoldingLlm", "ScaffoldingOutput", "ParallelProcess", "Controller",
"NativeGenerationController", "NativeRewardController",
"MajorityVoteController", "BestOfNController", "Task", "GenerationTask",
"RewardTask", "Worker", "OpenaiWorker", "TRTOpenaiWorker", "TRTLLMWorker",

View File

@@ -1,7 +1,8 @@
import copy
from abc import ABC
from dataclasses import dataclass
from enum import Enum
from typing import List, Tuple
from typing import Any, List, Mapping, Tuple
from tensorrt_llm.scaffolding.math_utils import get_digit_majority_vote_result
from tensorrt_llm.scaffolding.task import (GenerationTask, RewardTask,
@@ -32,6 +33,13 @@ class Controller(ABC):
raise NotImplementedError
@dataclass(frozen=True)
class ParallelProcess:
controllers: List[Controller]
tasks_list: List[List[Task]]
kwargs_list: List[Mapping[str, Any]]
# Controller that runs multiple generation tasks.
class NativeGenerationController(Controller):
@@ -40,22 +48,16 @@ class NativeGenerationController(Controller):
def __init__(self, custom_sampling_params: dict = None):
super().__init__()
self.custom_sampling_params = copy.deepcopy(custom_sampling_params)
self.custom_sampling_params = copy.deepcopy(
custom_sampling_params) if custom_sampling_params else None
def process(self, tasks: List[Task], **kwargs):
for task in tasks:
if not isinstance(task, GenerationTask):
raise ValueError(
"NativeGenerationController requires exactly one GenerationTask"
)
for task in tasks:
task.worker_tag = self.WorkerTag.GENERATION
if kwargs.get("custom_sampling_params"):
task.custom_sampling_params = kwargs.get(
"custom_sampling_params")
elif self.custom_sampling_params:
task.custom_sampling_params = self.custom_sampling_params
if self.custom_sampling_params:
for key, value in self.custom_sampling_params.items():
if hasattr(task, key) and getattr(task, key) is None:
setattr(task, key, value)
yield tasks
@@ -67,11 +69,6 @@ class NativeRewardController(Controller):
REWARD = "reward"
def process(self, tasks: List[Task], **kwargs):
for task in tasks:
if not isinstance(task, RewardTask):
raise ValueError(
"NativeRewardController requires exactly one RewardTask")
for task in tasks:
task.worker_tag = self.WorkerTag.REWARD
@@ -95,20 +92,27 @@ class MajorityVoteController(Controller):
return MajorityVoteController(generation_controller,
self.default_sample_num)
def process(self, tasks: List[Task], **kwargs):
assert len(tasks) == 1 and isinstance(tasks[0], GenerationTask), \
"MajorityVoteController requires exactly one GenerationTask"
def process(self,
tasks: List[Task],
sample_num: int = 1,
generation_kwargs: dict = {},
majority_vote_kwargs: dict = {}):
sample_num = max(sample_num, self.default_sample_num)
generation_controllers = [
self.generation_controller.clone() for _ in range(sample_num)
]
tasks_list = [copy.deepcopy(tasks) for _ in range(sample_num)]
generation_kwargs_list = [
copy.deepcopy(generation_kwargs) for _ in range(sample_num)
]
sample_num = kwargs.get("sample_num", self.default_sample_num)
generation_tasks = [copy.deepcopy(tasks[0]) for _ in range(sample_num)]
yield from self.generation_controller.process(
generation_tasks, **kwargs.get("generation_kwargs", {}))
yield ParallelProcess(generation_controllers, tasks_list,
generation_kwargs_list)
candidates = [task.output_str for task in generation_tasks]
result = self.majority_vote(candidates,
**kwargs.get("majority_vote_kwargs", {}))
candidates = [tasks[0].output_str for tasks in tasks_list]
result = self.majority_vote(candidates, **majority_vote_kwargs)
assert (isinstance(result, str))
assert isinstance(result, str), "majority_vote failed"
# The task returned by majority vote does not have output_tokens and logits.
tasks[0].output_str = result
@@ -134,31 +138,51 @@ class BestOfNController(Controller):
return BestOfNController(generation_controller, reward_controller,
self.default_sample_num)
def process(self, tasks: List[Task], **kwargs):
assert len(tasks) == 1 and isinstance(tasks[0], GenerationTask), \
"BestOfNController requires exactly one GenerationTask"
sample_num = kwargs.get("sample_num", self.default_sample_num)
generation_tasks = [tasks[0].deepcopy() for _ in range(sample_num)]
yield from self.generation_controller.process(
generation_tasks, **kwargs.get("generation_kwargs"))
reward_tasks = [
RewardTask.create_from_generation_task(generation_task)
for generation_task in generation_tasks
def process(self,
tasks: List[Task],
sample_num: int = 1,
generation_kwargs: dict = {},
reward_kwargs: dict = {},
select_best_kwargs: dict = {}):
sample_num = max(sample_num, self.default_sample_num)
generation_controllers = [
self.generation_controller.clone() for _ in range(sample_num)
]
yield from self.reward_controller.process(reward_tasks,
**kwargs.get("reward_kwargs"))
self.generation_tasks_list = [tasks for _ in range(sample_num)]
generation_kwargs_list = [generation_kwargs for _ in range(sample_num)]
yield ParallelProcess(generation_controllers,
self.generation_tasks_list,
generation_kwargs_list)
# Some best-of-N algorithms create sample_num reward task lists while others create just one.
# We keep this as generic as possible.
self.reward_tasks_list = self.create_reward_tasks(
self.generation_tasks_list)
reward_parallel_num = len(self.reward_tasks_list)
reward_controllers = [
self.reward_controller.clone() for _ in range(reward_parallel_num)
]
reward_kwargs_list = [reward_kwargs for _ in range(reward_parallel_num)]
yield ParallelProcess(reward_controllers, self.reward_tasks_list,
reward_kwargs_list)
# May be used by upper-layer controllers.
self.best_generation_task, self.best_reward_task = (self.select_best(
generation_tasks, reward_tasks, **kwargs.get("select_best_kwargs")))
tasks[0] = self.best_generation_task
self.best_generation_task, self.best_reward_task = self.select_best(
self.generation_tasks_list, self.reward_tasks_list,
**select_best_kwargs)
tasks = self.best_generation_task
def select_best(self, generation_tasks: List[GenerationTask],
reward_tasks: List[RewardTask],
**kwargs) -> Tuple[GenerationTask, RewardTask]:
def select_best(self, generation_tasks: List[List[Task]],
reward_tasks: List[List[Task]],
**kwargs) -> Tuple[List[Task], List[Task]]:
assert len(generation_tasks[0]) == 1 and isinstance(generation_tasks[0][0], GenerationTask), \
"Should not use default select_best implementation for BestOfNController"
assert len(reward_tasks[0]) == 1 and isinstance(reward_tasks[0][0], RewardTask), \
"Should not use default select_best implementation for BestOfNController"
# select the best generation task and reward task
max_reward_value_index = reward_tasks.index(
max(reward_tasks, key=lambda x: x.reward_value))
max(reward_tasks, key=lambda x: x[0].reward_value))
return generation_tasks[max_reward_value_index], reward_tasks[
max_reward_value_index]
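
For context, here is a minimal hypothetical sketch (not part of this commit) of how a custom controller can use the new ParallelProcess protocol. FanOutController and its width parameter are invented purely for illustration, and the sketch assumes Controller exposes a clone() method, as the built-in controllers in this file do:

import copy
from typing import List

from tensorrt_llm.scaffolding import Controller, ParallelProcess, Task


class FanOutController(Controller):
    # Hypothetical controller: fans copies of the incoming tasks out to
    # clones of a sub-controller instead of looping over them sequentially.
    def __init__(self, sub_controller: Controller, width: int = 2):
        super().__init__()
        self.sub_controller = sub_controller
        self.width = width

    def process(self, tasks: List[Task], **kwargs):
        controllers = [self.sub_controller.clone() for _ in range(self.width)]
        tasks_list = [copy.deepcopy(tasks) for _ in range(self.width)]
        kwargs_list = [copy.deepcopy(kwargs) for _ in range(self.width)]
        # The ScaffoldingLlm main loop drives each (controller, tasks, kwargs)
        # triple concurrently and resumes this generator once all of them finish.
        yield ParallelProcess(controllers, tasks_list, kwargs_list)
        # After the yield, the copies in tasks_list are completed and the
        # controller can aggregate them, e.g. copy one result back into tasks[0].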

View File

@@ -1,10 +1,11 @@
import asyncio
import threading
import traceback
from collections import deque
from dataclasses import dataclass
from typing import Any, List, Mapping, Union
from typing import Any, Generator, List, Mapping, Union
from .controller import Controller, ScaffoldingOutput
from .controller import Controller, ParallelProcess, ScaffoldingOutput
from .worker import Worker
@@ -27,8 +28,10 @@ class ScaffoldingResult:
self.aqueue.put_nowait(output)
async def aresult_step(self):
# TODO: error handling?
# TODO: error handling or raise exception?
self.output = await self.aqueue.get()
if self.output is None:
raise Exception("ScaffoldingLlm execution failed")
self._done = True
def result(self) -> "ScaffoldingResult":
@@ -90,11 +93,12 @@ class ScaffoldingLlm:
async def _main_loop_async_func(self):
async def handle_single_request(request: ScaffoldingRequest):
gen = request.controller.generate(request.prompt, **request.kwargs)
try:
while True:
task_list = next(gen)
async def handle_controller_generator(gen: Generator):
for obj in gen:
if isinstance(obj, ParallelProcess):
await handle_parallel_process(obj)
else:
task_list = obj
async_tasks = []
for task in task_list:
task_worker_tag = task.worker_tag
@@ -103,13 +107,37 @@ class ScaffoldingLlm:
async_tasks.append(
asyncio.create_task(worker.run_task(task)))
await asyncio.gather(*async_tasks)
# TODO: if we need use results?
except StopIteration as e:
scaffolding_output = e.value
async def handle_parallel_process(request: ParallelProcess):
async_tasks = []
for controller, tasks, kwargs in zip(request.controllers,
request.tasks_list,
request.kwargs_list):
gen = controller.process(tasks, **kwargs)
async_task = asyncio.create_task(
handle_controller_generator(gen))
async_tasks.append(async_task)
await asyncio.gather(*async_tasks)
async def handle_single_request(request: ScaffoldingRequest):
# wrap into a generator with no return value
def controller_generator_wrapper(request: ScaffoldingRequest):
scaffolding_output = yield from request.controller.generate(
request.prompt, **request.kwargs)
request.result.set_output(scaffolding_output)
self.running_req_count -= 1
maybe_schedule()
try:
gen = controller_generator_wrapper(request)
await handle_controller_generator(gen)
except Exception as e:
# Catch the exception and set the output so the user thread does not hang
print('ScaffoldingLlm handle request exception:', str(e))
traceback.print_exc()
request.result.set_output(None)
raise e
finally:
self.running_req_count -= 1
maybe_schedule()
def schedule_request(request: ScaffoldingRequest):
asyncio.create_task(handle_single_request(request))
@@ -156,7 +184,6 @@ class ScaffoldingLlm:
self.main_loop_thread.start()
def generate_async(self, prompt: str) -> ScaffoldingResult:
result = ScaffoldingResult()
async def put_request():
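
The controller_generator_wrapper above relies on Python's generator return-value protocol: yield from forwards every yielded object (task lists or ParallelProcess instances) to handle_controller_generator while evaluating to the inner generator's return value. A small self-contained sketch of that mechanism, independent of this codebase:

def inner():
    # Stands in for Controller.generate(): yields work, then returns a result.
    yield "task batch 1"
    yield "task batch 2"
    return "final output"  # carried in StopIteration.value


def outer():
    # Stands in for controller_generator_wrapper(): re-yields the work and
    # captures the return value instead of returning one of its own.
    result = yield from inner()
    print("captured:", result)


for item in outer():
    print("driver saw:", item)
# driver saw: task batch 1
# driver saw: task batch 2
# captured: final output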

View File

@@ -1,4 +1,4 @@
from abc import ABC, abstractmethod
from abc import ABC
from copy import deepcopy
from typing import Callable, List, Optional, Union
@@ -30,9 +30,8 @@ class Worker(ABC):
task_handlers = {}
@abstractmethod
def shutdown(self):
raise NotImplementedError
pass
def __enter__(self):
return self

View File

@@ -0,0 +1,150 @@
import asyncio
import copy
import time
from enum import Enum
from typing import List
from tensorrt_llm.scaffolding import (Controller, ParallelProcess,
ScaffoldingLlm, ScaffoldingOutput, Task,
TaskStatus, Worker)
class DummyTask(Task):
def __init__(self, turn: int):
self.turn = turn
self.numbers = []
@staticmethod
def create_from_prompt(prompt: str) -> "DummyTask":
task = DummyTask(2)
return task
def create_scaffolding_output(self) -> "ScaffoldingOutput":
self.verify()
return ScaffoldingOutput()
def verify(self):
for i in range(len(self.numbers)):
assert self.numbers[i] == i, "task.numbers[i] should be i"
class DummyControllerBase(Controller):
def generate(self, prompt: str, **kwargs) -> ScaffoldingOutput:
task = DummyTask.create_from_prompt(prompt)
yield from self.process([task], **kwargs)
return task.create_scaffolding_output()
# Controller that yields task.turn times for each task
class DummyController(DummyControllerBase):
class WorkerTag(Enum):
DUMMY = "dummy"
def process(self, tasks: List[Task], **kwargs):
yield_tasks = tasks
while len(yield_tasks) > 0:
new_tasks = []
for task in yield_tasks:
if len(task.numbers) < task.turn:
task.worker_tag = self.WorkerTag.DUMMY
new_tasks.append(task)
yield_tasks = new_tasks
if len(yield_tasks) > 0:
yield yield_tasks
# Flag to enable parallel processing.
# It lets us compare the performance of parallel and sequential processing.
ENABLE_PARALLEL_PROCESS = True
class DummyParallelController(DummyControllerBase):
def __init__(self, controllers):
self.controllers = controllers
def process(self, tasks: List[Task], **kwargs):
global ENABLE_PARALLEL_PROCESS
if ENABLE_PARALLEL_PROCESS:
tasks_list = [
copy.deepcopy(tasks) for _ in range(len(self.controllers))
]
kwargs_list = [kwargs for _ in range(len(self.controllers))]
#yield from parallel_process_helper(self.controllers, tasks_list, kwargs_list)
yield ParallelProcess(self.controllers, tasks_list, kwargs_list)
tasks = tasks_list[0]
else:
original_tasks = copy.deepcopy(tasks)
for controller in self.controllers:
tasks = copy.deepcopy(original_tasks)
yield from controller.process(tasks, **kwargs)
class DummyWorker(Worker):
async def dummy_handler(self, task: DummyTask):
await asyncio.sleep(1)
task.numbers.append(len(task.numbers))
return TaskStatus.SUCCESS
task_handlers = {DummyTask: dummy_handler}
def parallel_process_helper_run_and_verify(controllers):
# Build a DummyParallelController over the given controllers and compare parallel vs. sequential execution time.
parallel_controller = DummyParallelController(controllers)
worker = DummyWorker()
llm = ScaffoldingLlm(parallel_controller,
{DummyController.WorkerTag.DUMMY: worker})
global ENABLE_PARALLEL_PROCESS
ENABLE_PARALLEL_PROCESS = True
start_time = time.time()
llm.generate("")
end_time = time.time()
print('Parallel process time:', end_time - start_time)
ENABLE_PARALLEL_PROCESS = False
start_time = time.time()
llm.generate("")
end_time = time.time()
print('Sequential process time:', end_time - start_time)
llm.shutdown()
def test_parallel_process_helper():
NUM_CONTROLLERS = 3
controllers = []
for _ in range(NUM_CONTROLLERS):
controller = DummyController()
controllers.append(controller)
parallel_process_helper_run_and_verify(controllers)
def test_parallel_process_helper_with_two_level():
NUM_CONTROLLERS_LEVEL_1 = 2
NUM_CONTROLLERS_LEVEL_2 = 2
controllers_level_1 = []
for _ in range(NUM_CONTROLLERS_LEVEL_1):
controller = DummyController()
controllers_level_1.append(controller)
parallel_controller = DummyParallelController(controllers_level_1)
controllers_level_2 = [parallel_controller]
for _ in range(NUM_CONTROLLERS_LEVEL_2):
controller = DummyController()
controllers_level_2.append(controller)
parallel_process_helper_run_and_verify(controllers_level_2)
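
As a rough sanity check on the printed timings for the three-controller test (an estimate, not something the test asserts): DummyWorker sleeps one second per handled task and DummyTask is created with turn=2, so each DummyController needs about 2 × 1 s = 2 s of worker time. The ParallelProcess run drives the three controllers concurrently and should finish in roughly 2 s, while the sequential run should take roughly 3 × 2 s = 6 s.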

View File

@@ -1,7 +1,7 @@
# isort: off
# isort: on
# autoflake: skip_file
from scaffolding.test_worker import create_trtllm_worker
from scaffolding.test_worker import (create_trtllm_worker,
deepseek_distill_7b_path, default_prompt)
from tensorrt_llm.scaffolding import (MajorityVoteController,
NativeGenerationController,