Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00

Signed-off-by: fredw (generated by with_the_same_user script) <20514172+WeiHaocheng@users.noreply.github.com>

This commit is contained in:
parent 012fb9a1c4
commit c6081abb0e
@@ -1,6 +1,6 @@
 from .controller import (BestOfNController, Controller, MajorityVoteController,
                          NativeGenerationController, NativeRewardController,
-                         ScaffoldingOutput)
+                         ParallelProcess, ScaffoldingOutput)
 from .math_utils import (extract_answer_from_boxed, extract_answer_with_regex,
                          get_digit_majority_vote_result)
 from .scaffolding_llm import ScaffoldingLlm
@@ -8,7 +8,7 @@ from .task import GenerationTask, RewardTask, Task, TaskStatus
 from .worker import OpenaiWorker, TRTLLMWorker, TRTOpenaiWorker, Worker

 __all__ = [
-    "ScaffoldingLlm", "ScaffoldingOutput", "Controller",
+    "ScaffoldingLlm", "ScaffoldingOutput", "ParallelProcess", "Controller",
     "NativeGenerationController", "NativeRewardController",
     "MajorityVoteController", "BestOfNController", "Task", "GenerationTask",
     "RewardTask", "Worker", "OpenaiWorker", "TRTOpenaiWorker", "TRTLLMWorker",
@@ -1,7 +1,8 @@
 import copy
 from abc import ABC
+from dataclasses import dataclass
 from enum import Enum
-from typing import List, Tuple
+from typing import Any, List, Mapping, Tuple

 from tensorrt_llm.scaffolding.math_utils import get_digit_majority_vote_result
 from tensorrt_llm.scaffolding.task import (GenerationTask, RewardTask,
@@ -32,6 +33,13 @@ class Controller(ABC):
         raise NotImplementedError


+@dataclass(frozen=True)
+class ParallelProcess:
+    controllers: List[Controller]
+    tasks_list: List[List[Task]]
+    kwargs_list: List[Mapping[str, Any]]
+
+
 # Controller runs multiple generation tasks.
 class NativeGenerationController(Controller):
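
For context, ParallelProcess is the piece of this change everything else builds on: instead of driving sub-work sequentially with yield from, a controller can now yield a single ParallelProcess carrying several (controller, tasks, kwargs) branches and let the runtime execute them concurrently. A minimal sketch of a fan-out controller (the class name here is hypothetical, not part of this commit):

    import copy
    from typing import List

    from tensorrt_llm.scaffolding import Controller, ParallelProcess, Task


    class FanOutController(Controller):
        # Hypothetical example: run each sub-controller on its own copy of the tasks.

        def __init__(self, sub_controllers: List[Controller]):
            super().__init__()
            self.sub_controllers = sub_controllers

        def process(self, tasks: List[Task], **kwargs):
            tasks_list = [copy.deepcopy(tasks) for _ in self.sub_controllers]
            kwargs_list = [kwargs for _ in self.sub_controllers]
            # ScaffoldingLlm drives every branch concurrently, then resumes here.
            yield ParallelProcess(self.sub_controllers, tasks_list, kwargs_list)
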
@@ -40,22 +48,16 @@ class NativeGenerationController(Controller):

     def __init__(self, custom_sampling_params: dict = None):
         super().__init__()
-        self.custom_sampling_params = copy.deepcopy(custom_sampling_params)
+        self.custom_sampling_params = copy.deepcopy(
+            custom_sampling_params) if custom_sampling_params else None

     def process(self, tasks: List[Task], **kwargs):
-        for task in tasks:
-            if not isinstance(task, GenerationTask):
-                raise ValueError(
-                    "NativeGenerationController requires exactly one GenerationTask"
-                )
-
         for task in tasks:
             task.worker_tag = self.WorkerTag.GENERATION
-            if kwargs.get("custom_sampling_params"):
-                task.custom_sampling_params = kwargs.get(
-                    "custom_sampling_params")
-            elif self.custom_sampling_params:
-                task.custom_sampling_params = self.custom_sampling_params
+            if self.custom_sampling_params:
+                for key, value in self.custom_sampling_params.items():
+                    if hasattr(task, key) and getattr(task, key) is None:
+                        setattr(task, key, value)

         yield tasks
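
The sampling-parameter handling above also changes shape: the controller no longer attaches the whole custom_sampling_params dict to the task; it copies individual entries onto matching task attributes that are still None. A short usage sketch (the key names are assumptions for illustration; they only take effect if GenerationTask actually has attributes with those names):

    from tensorrt_llm.scaffolding import NativeGenerationController

    # Assumed keys: applied per attribute, and only where the task's value is None.
    controller = NativeGenerationController(
        custom_sampling_params={"temperature": 0.9, "max_tokens": 2048})
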
@@ -67,11 +69,6 @@ class NativeRewardController(Controller):
         REWARD = "reward"

     def process(self, tasks: List[Task], **kwargs):
-        for task in tasks:
-            if not isinstance(task, RewardTask):
-                raise ValueError(
-                    "NativeRewardController requires exactly one RewardTask")
-
         for task in tasks:
             task.worker_tag = self.WorkerTag.REWARD
@@ -95,20 +92,27 @@ class MajorityVoteController(Controller):
         return MajorityVoteController(generation_controller,
                                       self.default_sample_num)

-    def process(self, tasks: List[Task], **kwargs):
-        assert len(tasks) == 1 and isinstance(tasks[0], GenerationTask), \
-            "MajorityVoteController requires exactly one GenerationTask"
+    def process(self,
+                tasks: List[Task],
+                sample_num: int = 1,
+                generation_kwargs: dict = {},
+                majority_vote_kwargs: dict = {}):
+        sample_num = max(sample_num, self.default_sample_num)
+        generation_controllers = [
+            self.generation_controller.clone() for _ in range(sample_num)
+        ]
+        tasks_list = [copy.deepcopy(tasks) for _ in range(sample_num)]
+        generation_kwargs_list = [
+            copy.deepcopy(generation_kwargs) for _ in range(sample_num)
+        ]

-        sample_num = kwargs.get("sample_num", self.default_sample_num)
-        generation_tasks = [copy.deepcopy(tasks[0]) for _ in range(sample_num)]
-        yield from self.generation_controller.process(
-            generation_tasks, **kwargs.get("generation_kwargs", {}))
+        yield ParallelProcess(generation_controllers, tasks_list,
+                              generation_kwargs_list)

-        candidates = [task.output_str for task in generation_tasks]
-        result = self.majority_vote(candidates,
-                                    **kwargs.get("majority_vote_kwargs", {}))
+        candidates = [tasks[0].output_str for tasks in tasks_list]
+        result = self.majority_vote(candidates, **majority_vote_kwargs)

-        assert (isinstance(result, str))
+        assert isinstance(result, str), "majority_vote failed"
         # The task returned by majority vote does not have output_tokens and logits.
         tasks[0].output_str = result
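
MajorityVoteController now clones its generation controller once per sample and yields a single ParallelProcess, so the sample_num branches run concurrently and the vote is taken over tasks_list[i][0].output_str afterwards. A hedged sketch of the caller-side shape (GenerationTask.create_from_prompt is assumed from the surrounding scaffolding module; the constructor argument order follows the clone() call shown above):

    from tensorrt_llm.scaffolding import (GenerationTask, MajorityVoteController,
                                          NativeGenerationController)

    controller = MajorityVoteController(NativeGenerationController(), 3)
    task = GenerationTask.create_from_prompt("What is 7 * 6?")  # assumed helper

    # Explicit keywords replace the old packed kwargs dict
    # ({"sample_num": ..., "generation_kwargs": ..., "majority_vote_kwargs": ...}).
    gen = controller.process([task], sample_num=8,
                             generation_kwargs={},
                             majority_vote_kwargs={})
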
@@ -134,31 +138,51 @@ class BestOfNController(Controller):
         return BestOfNController(generation_controller, reward_controller,
                                  self.default_sample_num)

-    def process(self, tasks: List[Task], **kwargs):
-        assert len(tasks) == 1 and isinstance(tasks[0], GenerationTask), \
-            "BestOfNController requires exactly one GenerationTask"
-
-        sample_num = kwargs.get("sample_num", self.default_sample_num)
-        generation_tasks = [tasks[0].deepcopy() for _ in range(sample_num)]
-        yield from self.generation_controller.process(
-            generation_tasks, **kwargs.get("generation_kwargs"))
-
-        reward_tasks = [
-            RewardTask.create_from_generation_task(generation_task)
-            for generation_task in generation_tasks
+    def process(self,
+                tasks: List[Task],
+                sample_num: int = 1,
+                generation_kwargs: dict = {},
+                reward_kwargs: dict = {},
+                select_best_kwargs: dict = {}):
+        sample_num = max(sample_num, self.default_sample_num)
+        generation_controllers = [
+            self.generation_controller.clone() for _ in range(sample_num)
         ]
-        yield from self.reward_controller.process(reward_tasks,
-                                                  **kwargs.get("reward_kwargs"))
+        self.generation_tasks_list = [tasks for _ in range(sample_num)]
+        generation_kwargs_list = [generation_kwargs for _ in range(sample_num)]

+        yield ParallelProcess(generation_controllers,
+                              self.generation_tasks_list,
+                              generation_kwargs_list)
+
+        # Some best of N algorithms create sample_num reward task lists while some just create one.
+        # We maintain generic here as much as possible.
+        self.reward_tasks_list = self.create_reward_tasks(
+            self.generation_tasks_list)
+        reward_paraller_num = len(self.reward_tasks_list)
+        reward_controllers = [
+            self.reward_controller.clone() for _ in range(reward_paraller_num)
+        ]
+        reward_kwargs_list = [reward_kwargs for _ in range(reward_paraller_num)]
+
+        yield ParallelProcess(reward_controllers, self.reward_tasks_list,
+                              reward_kwargs_list)

         # may used for upper layer controllers
-        self.best_generation_task, self.best_reward_task = (self.select_best(
-            generation_tasks, reward_tasks, **kwargs.get("select_best_kwargs")))
-        tasks[0] = self.best_generation_task
+        self.best_generation_task, self.best_reward_task = self.select_best(
+            self.generation_tasks_list, self.reward_tasks_list,
+            **select_best_kwargs)
+        tasks = self.best_generation_task

-    def select_best(self, generation_tasks: List[GenerationTask],
-                    reward_tasks: List[RewardTask],
-                    **kwargs) -> Tuple[GenerationTask, RewardTask]:
+    def select_best(self, generation_tasks: List[List[Task]],
+                    reward_tasks: List[List[Task]],
+                    **kwargs) -> Tuple[List[Task], List[Task]]:
+        assert len(generation_tasks[0]) == 1 and isinstance(generation_tasks[0][0], GenerationTask), \
+            "Should not use default select_best implementation for BestOfNController"
+        assert len(reward_tasks[0]) == 1 and isinstance(reward_tasks[0][0], RewardTask), \
+            "Should not use default select_best implementation for BestOfNController"
         # select the best generation task and reward task
         max_reward_value_index = reward_tasks.index(
-            max(reward_tasks, key=lambda x: x.reward_value))
+            max(reward_tasks, key=lambda x: x[0].reward_value))
         return generation_tasks[max_reward_value_index], reward_tasks[
             max_reward_value_index]
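
BestOfNController now follows the same pattern twice (one ParallelProcess for generation, one for reward scoring), and the reward fan-out is sized by whatever create_reward_tasks returns; only its call site is visible in this hunk. A hedged sketch of an override that builds one reward-task list per generation branch, reusing the pre-existing RewardTask.create_from_generation_task helper seen in the removed code:

    from tensorrt_llm.scaffolding import BestOfNController, RewardTask


    class SampleWiseBestOfN(BestOfNController):
        # Sketch only: the base create_reward_tasks implementation is not shown here.
        def create_reward_tasks(self, generation_tasks_list):
            return [[RewardTask.create_from_generation_task(tasks[0])]
                    for tasks in generation_tasks_list]
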
@@ -1,10 +1,11 @@
 import asyncio
 import threading
+import traceback
 from collections import deque
 from dataclasses import dataclass
-from typing import Any, List, Mapping, Union
+from typing import Any, Generator, List, Mapping, Union

-from .controller import Controller, ScaffoldingOutput
+from .controller import Controller, ParallelProcess, ScaffoldingOutput
 from .worker import Worker

@@ -27,8 +28,10 @@ class ScaffoldingResult:
         self.aqueue.put_nowait(output)

     async def aresult_step(self):
-        # TODO: error handling?
+        # TODO: error handling or raise exception?
         self.output = await self.aqueue.get()
+        if self.output is None:
+            raise Exception("ScaffoldingLlm execution failed")
         self._done = True

     def result(self) -> "ScaffoldingResult":
@@ -90,11 +93,12 @@ class ScaffoldingLlm:

     async def _main_loop_async_func(self):

-        async def handle_single_request(request: ScaffoldingRequest):
-            gen = request.controller.generate(request.prompt, **request.kwargs)
-            try:
-                while True:
-                    task_list = next(gen)
+        async def handle_controller_generator(gen: Generator):
+            for obj in gen:
+                if isinstance(obj, ParallelProcess):
+                    await handle_parallel_process(obj)
+                else:
+                    task_list = obj
                     async_tasks = []
                     for task in task_list:
                         task_worker_tag = task.worker_tag
@@ -103,13 +107,37 @@ class ScaffoldingLlm:
                         async_tasks.append(
                             asyncio.create_task(worker.run_task(task)))
                     await asyncio.gather(*async_tasks)
-            # TODO: if we need use results?
-            except StopIteration as e:
-                scaffolding_output = e.value

+        async def handle_parallel_process(request: ParallelProcess):
+            async_tasks = []
+            for controller, tasks, kwargs in zip(request.controllers,
+                                                 request.tasks_list,
+                                                 request.kwargs_list):
+                gen = controller.process(tasks, **kwargs)
+                async_task = asyncio.create_task(
+                    handle_controller_generator(gen))
+                async_tasks.append(async_task)
+            await asyncio.gather(*async_tasks)
+
+        async def handle_single_request(request: ScaffoldingRequest):
+            # warp to a generator without return value
+            def controller_generator_wrapper(request: ScaffoldingRequest):
+                scaffolding_output = yield from request.controller.generate(
+                    request.prompt, **request.kwargs)
                 request.result.set_output(scaffolding_output)

-                self.running_req_count -= 1
-                maybe_schedule()
+            try:
+                gen = controller_generator_wrapper(request)
+                await handle_controller_generator(gen)
+            except Exception as e:
+                # Catch the exception and set output to avoid the user thread to be hang
+                print('scaffoldingLlm handle request exception:', str(e))
+                traceback.print_exc()
+                request.result.set_output(None)
+                raise e
+            finally:
+                self.running_req_count -= 1
+                maybe_schedule()
+
         def schedule_request(request: ScaffoldingRequest):
             asyncio.create_task(handle_single_request(request))
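
Taken together, these two helpers define the execution protocol: a controller's process() generator yields either a plain list of tasks (dispatched to workers) or a ParallelProcess, whose branches are each driven as their own generator on concurrent asyncio tasks. A deliberately simplified, synchronous stand-in for that loop can be handy when reasoning about a controller outside ScaffoldingLlm (the drive() name is hypothetical and the worker step is stubbed out):

    from tensorrt_llm.scaffolding import ParallelProcess


    def drive(controller, tasks, **kwargs):
        # Synchronous sketch: the real loop awaits worker.run_task() per task and
        # runs ParallelProcess branches concurrently instead of in a plain loop.
        for item in controller.process(tasks, **kwargs):
            if isinstance(item, ParallelProcess):
                for sub_controller, sub_tasks, sub_kwargs in zip(
                        item.controllers, item.tasks_list, item.kwargs_list):
                    drive(sub_controller, sub_tasks, **sub_kwargs)
            else:
                for task in item:
                    pass  # a worker would fill in the task's result here
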
@@ -156,7 +184,6 @@ class ScaffoldingLlm:
         self.main_loop_thread.start()

     def generate_async(self, prompt: str) -> ScaffoldingResult:
-
         result = ScaffoldingResult()

         async def put_request():
@@ -1,4 +1,4 @@
-from abc import ABC, abstractmethod
+from abc import ABC
 from copy import deepcopy
 from typing import Callable, List, Optional, Union

@@ -30,9 +30,8 @@ class Worker(ABC):

     task_handlers = {}

-    @abstractmethod
     def shutdown(self):
-        raise NotImplementedError
+        pass

     def __enter__(self):
         return self
tests/unittest/scaffolding/test_parallel_process.py (new file, 150 lines)
@@ -0,0 +1,150 @@
import asyncio
import copy
import time
from enum import Enum
from typing import List

from tensorrt_llm.scaffolding import (Controller, ParallelProcess,
                                      ScaffoldingLlm, ScaffoldingOutput, Task,
                                      TaskStatus, Worker)


class DummyTask(Task):

    def __init__(self, turn: int):
        self.turn = turn
        self.numbers = []

    @staticmethod
    def create_from_prompt(prompt: str) -> "DummyTask":
        task = DummyTask(2)
        return task

    def create_scaffolding_output(self) -> "ScaffoldingOutput":
        self.verify()
        return ScaffoldingOutput()

    def verify(self):
        for i in range(len(self.numbers)):
            assert self.numbers[i] == i, "task.numbers[i] should be i"


class DummyControllerBase(Controller):

    def generate(self, prompt: str, **kwargs) -> ScaffoldingOutput:
        task = DummyTask.create_from_prompt(prompt)
        yield from self.process([task], **kwargs)
        return task.create_scaffolding_output()


# Controller that yields task.turn times for each task
class DummyController(DummyControllerBase):

    class WorkerTag(Enum):
        DUMMY = "dummy"

    def process(self, tasks: List[Task], **kwargs):
        yield_tasks = tasks
        while len(yield_tasks) > 0:
            new_tasks = []
            for task in yield_tasks:
                if len(task.numbers) < task.turn:
                    task.worker_tag = self.WorkerTag.DUMMY
                    new_tasks.append(task)
            yield_tasks = new_tasks

            if len(yield_tasks) > 0:
                yield yield_tasks


# The flag to enable parallel process
# We can use this flag to compare the performance of parallel process and sequence process
ENABLE_PARALLEL_PROCESS = True


class DummyParallelController(DummyControllerBase):

    def __init__(self, controllers):
        self.controllers = controllers

    def process(self, tasks: List[Task], **kwargs):
        global ENABLE_PARALLEL_PROCESS
        if ENABLE_PARALLEL_PROCESS:
            tasks_list = [
                copy.deepcopy(tasks) for _ in range(len(self.controllers))
            ]

            kwargs_list = [kwargs for _ in range(len(self.controllers))]
            #yield from parallel_process_helper(self.controllers, tasks_list, kwargs_list)
            yield ParallelProcess(self.controllers, tasks_list, kwargs_list)

            tasks = tasks_list[0]
        else:
            original_tasks = copy.deepcopy(tasks)
            for controller in self.controllers:
                tasks = copy.deepcopy(original_tasks)
                yield from controller.process(tasks, **kwargs)


class DummyWorker(Worker):

    async def dummy_handler(self, task: DummyTask):
        await asyncio.sleep(1)
        task.numbers.append(len(task.numbers))
        return TaskStatus.SUCCESS

    task_handlers = {DummyTask: dummy_handler}


def parallel_process_helper_run_and_verify(controllers):
    # Obtain the generator from parallel_process_helper.
    parallel_controller = DummyParallelController(controllers)
    worker = DummyWorker()
    llm = ScaffoldingLlm(parallel_controller,
                         {DummyController.WorkerTag.DUMMY: worker})

    global ENABLE_PARALLEL_PROCESS
    ENABLE_PARALLEL_PROCESS = True
    start_time = time.time()
    llm.generate("")
    end_time = time.time()
    print('Parallel process time:', end_time - start_time)

    ENABLE_PARALLEL_PROCESS = False
    start_time = time.time()
    llm.generate("")
    end_time = time.time()
    print('Sequence process time:', end_time - start_time)

    llm.shutdown()


def test_parallel_process_helper():
    NUM_CONTROLLERS = 3
    controllers = []

    for _ in range(NUM_CONTROLLERS):
        controller = DummyController()
        controllers.append(controller)

    parallel_process_helper_run_and_verify(controllers)


def test_parallel_process_helper_with_two_level():
    NUM_CONTROLLERS_LEVEL_1 = 2
    NUM_CONTROLLERS_LEVEL_2 = 2

    controllers_level_1 = []

    for _ in range(NUM_CONTROLLERS_LEVEL_1):
        controller = DummyController()
        controllers_level_1.append(controller)

    parallel_controller = DummyParallelController(controllers_level_1)
    controllers_level_2 = [parallel_controller]

    for _ in range(NUM_CONTROLLERS_LEVEL_2):
        controller = DummyController()
        controllers_level_2.append(controller)

    parallel_process_helper_run_and_verify(controllers_level_2)
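
As a rough sanity check on the timing prints above: DummyWorker sleeps one second per handled task and DummyTask uses turn=2, so each branch needs about two seconds of worker time; with three controllers the parallel run should land near two seconds and the sequential run near six (plus overhead). A back-of-the-envelope sketch of that arithmetic:

    # Estimates for the three-controller case in test_parallel_process_helper,
    # ignoring scheduling overhead.
    turns_per_task = 2        # DummyTask(2)
    seconds_per_turn = 1      # asyncio.sleep(1) in DummyWorker.dummy_handler
    branches = 3              # NUM_CONTROLLERS

    parallel_estimate = turns_per_task * seconds_per_turn                # ~2 s
    sequential_estimate = branches * turns_per_task * seconds_per_turn   # ~6 s
    print(parallel_estimate, sequential_estimate)
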
@@ -1,7 +1,7 @@
 # isort: off
 # isort: on
 # autoflake: skip_file

-from scaffolding.test_worker import create_trtllm_worker
+from scaffolding.test_worker import (create_trtllm_worker,
+                                     deepseek_distill_7b_path, default_prompt)

 from tensorrt_llm.scaffolding import (MajorityVoteController,
                                       NativeGenerationController,