[None][feat] Move StreamGeneration to scaffolding main directory (#8347)

Signed-off-by: Dong Cao <docao@nvidia.com>
Author: Cao Dong, 2025-10-14 17:16:04 +08:00 (committed by GitHub)
parent 72d65d079a
commit 62cea877b1
7 changed files with 109 additions and 24 deletions


@@ -42,9 +42,8 @@ class NativeStreamGenerationController(Controller):
                     "custom_sampling_params")
             elif self.custom_sampling_params:
                 task.custom_sampling_params = self.custom_sampling_params
-            stream_task = StreamGenerationTask()
-            stream_task.__dict__ = copy.deepcopy(task.__dict__)
-            stream_task.streaming_step = self.stream_step
+            stream_task = StreamGenerationTask.create_from_generation_task(
+                task, self.stream_step)
             stream_tasks.append(stream_task)
         lst = list(range(len(stream_tasks)))
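With the factory in place, a controller only needs the source task and a step size to go streaming. A minimal sketch of the call-site pattern (the surrounding loop is assumed context, not part of the diff):

    # inside a controller that fans tasks out for streaming generation
    stream_tasks = [
        StreamGenerationTask.create_from_generation_task(task, self.stream_step)
        for task in tasks
    ]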


@@ -6,7 +6,8 @@ from .controller import (BestOfNController, Controller, MajorityVoteController,
 from .math_utils import (extract_answer_from_boxed, extract_answer_with_regex,
                          get_digit_majority_vote_result)
 from .scaffolding_llm import ScaffoldingLlm
-from .task import GenerationTask, RewardTask, Task, TaskStatus
+from .task import (GenerationTask, RewardTask, StreamGenerationTask, Task,
+                   TaskStatus)
 from .task_collection import (GenerationTokenCounter, TaskCollection,
                               with_task_collection)
 from .worker import OpenaiWorker, TRTLLMWorker, TRTOpenaiWorker, Worker
@@ -22,6 +23,7 @@ __all__ = [
     "BestOfNController",
     "Task",
     "GenerationTask",
+    "StreamGenerationTask",
     "RewardTask",
     "Worker",
     "OpenaiWorker",


@@ -1,4 +1,5 @@
 import asyncio
+import copy
 from dataclasses import dataclass, field
 from typing import Any, Optional
@@ -22,6 +23,15 @@ class StreamGenerationTask(GenerationTask):
     # worker set this field to True when the generation is finished
     end_flag: bool = field(default=False)
 
+    @staticmethod
+    def create_from_generation_task(
+            task: GenerationTask,
+            streaming_step: int) -> "StreamGenerationTask":
+        stream_task = StreamGenerationTask()
+        stream_task.__dict__ = copy.deepcopy(task.__dict__)
+        stream_task.streaming_step = streaming_step
+        return stream_task
+
 
 async def stream_generation_handler(worker,
                                     task: StreamGenerationTask) -> TaskStatus:
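The factory's one subtlety is `copy.deepcopy(task.__dict__)`: the stream task starts as a faithful clone of the source task, and later mutations stay isolated from it. A self-contained illustration of that copy behavior (StubTask is a stand-in, not the library's GenerationTask):

    import copy

    class StubTask:
        def __init__(self):
            self.input_str = "2+2="
            self.custom_sampling_params = {"temperature": 0.0}

    src = StubTask()
    clone = StubTask()
    clone.__dict__ = copy.deepcopy(src.__dict__)  # duplicate every field
    clone.custom_sampling_params["temperature"] = 1.0
    assert src.custom_sampling_params["temperature"] == 0.0  # source untouched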


@@ -230,15 +230,16 @@ class MajorityVoteController(Controller):
         yield ParallelProcess(generation_controllers, tasks_list,
                               generation_kwargs_list)
 
-        candidates = [tasks[0].output_str for tasks in tasks_list]
         majority_index, majority_answer = self.majority_vote(
-            candidates, **majority_vote_kwargs)
+            tasks_list, **majority_vote_kwargs)
 
         assert isinstance(majority_answer, str), "majority_vote failed"
 
         # The task returned by majority vote does not have output_tokens and logits.
         tasks[0].result = tasks_list[majority_index][0].result
 
-    def majority_vote(self, candidates: List[str], **kwargs) -> Tuple[int, str]:
+    def majority_vote(self, candidates_tasks: List[List[Task]],
+                      **kwargs) -> Tuple[int, str]:
+        candidates = [tasks[0].output_str for tasks in candidates_tasks]
         return get_digit_majority_vote_result(candidates)
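Moving the `output_str` extraction into `majority_vote` means overrides now receive whole tasks rather than pre-flattened strings, so they can use token-level signals. A hedged sketch of what that enables (the subclass and its exp-logprob weighting are illustrative, not part of the diff):

    import math

    class ConfidenceWeightedMajorityVote(MajorityVoteController):

        def majority_vote(self, candidates_tasks, **kwargs):
            weights = {}    # answer -> accumulated sequence-probability weight
            first_idx = {}  # answer -> index of the first task that produced it
            for i, tasks in enumerate(candidates_tasks):
                answer = extract_answer_from_boxed(tasks[0].output_str)
                weights[answer] = weights.get(answer, 0.0) + math.exp(
                    tasks[0].cumulative_logprob or 0.0)
                first_idx.setdefault(answer, i)
            best = max(weights, key=weights.get)
            return first_idx[best], best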


@@ -175,13 +175,19 @@ class ScaffoldingLlm:
         result = ScaffoldingResult(self.streaming_event)
 
         async def put_request():
-            request = ScaffoldingRequest(
-                prompt=prompt,
-                kwargs={},
-                result=result,
-                controller=self.prototype_controller.clone())
-
-            await self.task_queue.put(request)
+            try:
+                request = ScaffoldingRequest(
+                    prompt=prompt,
+                    kwargs={},
+                    result=result,
+                    controller=self.prototype_controller.clone())
+            except Exception as e:
+                self.task_queue.put(None)
+                print(
+                    f"Error: build ScaffoldingRequest failed: {e} \n {traceback.format_exc()}"
+                )
+            else:
+                await self.task_queue.put(request)
 
         asyncio.run_coroutine_threadsafe(put_request(), self.loop)
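The `else:` branch keeps the happy path out of the `try:`, and the `None` pushed onto the queue acts as a failure sentinel so the consumer is not left waiting for a request that will never arrive. A hypothetical consumer loop showing the sentinel convention (the real dispatcher in this class may differ):

    async def dispatch_loop(task_queue):
        while True:
            request = await task_queue.get()
            if request is None:  # sentinel: request construction failed
                break
            await handle_request(request)  # hypothetical per-request handler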
@@ -208,7 +214,7 @@ class ScaffoldingLlm:
     def shutdown(self, shutdown_workers=False):
 
-        def shutdown_workers():
+        def shutdown_workers_func():
             for worker in self.workers.values():
                 worker.shutdown()
@@ -228,4 +234,4 @@ class ScaffoldingLlm:
             self.shutdown_event.set()
 
         if shutdown_workers:
-            shutdown_workers()
+            shutdown_workers_func()
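The rename matters because the inner function previously shared its name with the boolean parameter, so the `def` shadowed the flag for the rest of the method body. A self-contained reproduction of the pitfall the rename avoids:

    def shutdown(shutdown_workers=False):
        def shutdown_workers():      # shadows the bool parameter
            print("shutting down workers")
        if shutdown_workers:         # always truthy: it is the function now
            shutdown_workers()

    shutdown()  # prints even though the flag defaulted to False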


@@ -1,10 +1,11 @@
+import copy
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 
-from tensorrt_llm.executor.result import GenerationResult
+from tensorrt_llm.executor.result import GenerationResult, TokenLogprobs
 
 
 @dataclass
@@ -64,6 +65,7 @@ class GenerationTask(Task):
     # result field
     # link to TRTLLM's GenerationResult, for async update in streaming mode
     _result: Optional[GenerationResult] = None
+    customized_result_fields: Dict[str, Any] = field(default_factory=dict)
 
     @property
     def result(self) -> GenerationResult:
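The new `customized_result_fields` dict gives controllers and workers a free-form side channel for extra per-task outputs without subclassing `GenerationTask`. A hedged sketch (the keys below are invented for illustration; the library defines none):

    task = GenerationTask()
    task.customized_result_fields["judge_score"] = 0.87
    task.customized_result_fields["draft_tokens"] = [101, 2054]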
@@ -96,7 +98,7 @@
             0].cumulative_logprob if self._result else None
 
     @property
-    def logprobs(self) -> Optional[List[float]]:
+    def logprobs(self) -> Optional[TokenLogprobs]:
         return self._result.outputs[0].logprobs if self._result else None
 
     @property
@@ -115,6 +117,32 @@
         return self._result
 
 
+@dataclass
+class StreamGenerationTask(GenerationTask):
+    # input field
+    # if the flag is set to True, the worker will cancel the generation work
+    cancel_flag: Optional[bool] = field(default=False)
+    # the task will be returned to the controller with at least streaming_step
+    # new tokens; if streaming_step is set to 0,
+    # the task will be returned to the controller immediately with
+    # new tokens that have already been generated.
+    streaming_step: Optional[int] = field(default=1)
+
+    # result field
+    # worker set this field and identify the same task by this field
+    request_handle: Any = field(default=None)
+    # worker set this field to True when the generation is finished
+    end_flag: bool = field(default=False)
+
+    @staticmethod
+    def create_from_generation_task(task: GenerationTask,
+                                    streaming_step) -> "StreamGenerationTask":
+        stream_task = StreamGenerationTask()
+        stream_task.__dict__ = copy.deepcopy(task.__dict__)
+        stream_task.streaming_step = streaming_step
+        return stream_task
+
+
 @dataclass
 class RewardTask(Task):
     # input field
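Reading the field comments together gives the intended round trip: the controller re-submits the same task, the worker appends at least `streaming_step` new tokens per submission and eventually flips `end_flag`, and `cancel_flag` lets the controller abort early. A hedged controller-side sketch (the yield-a-task-list convention is borrowed from NativeStreamGenerationController; `should_stop` is hypothetical):

    def process(self, tasks, **kwargs):
        stream_task = StreamGenerationTask.create_from_generation_task(
            tasks[0], streaming_step=4)
        while not stream_task.end_flag:
            yield [stream_task]              # worker adds >= 4 new tokens
            print("partial:", stream_task.output_str)
            if self.should_stop(stream_task):
                stream_task.cancel_flag = True  # worker aborts on next submission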


@@ -1,15 +1,16 @@
 import asyncio
 from abc import ABC
-from typing import Callable
+from typing import Callable, Optional
 
 import openai
 from transformers import AutoTokenizer
 
 from tensorrt_llm import LLM
 from tensorrt_llm.executor import GenerationExecutor
-from tensorrt_llm.llmapi.llm_args import KvCacheConfig
+from tensorrt_llm.llmapi.llm_args import KvCacheConfig, SchedulerConfig
 from tensorrt_llm.sampling_params import SamplingParams
 
-from .task import GenerationTask, Task, TaskStatus
+from .task import GenerationTask, StreamGenerationTask, Task, TaskStatus
 
 ExecutorCls = GenerationExecutor
@@ -150,6 +151,7 @@
         max_num_tokens: int = 4096,
         kv_cache_free_gpu_memory_fraction: float = 0.9,
         disable_overlap_scheduler: bool = False,
+        scheduler_config: Optional[SchedulerConfig] = None,
     ):
         kv_cache_config = KvCacheConfig(
             free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction, )
@@ -168,7 +170,8 @@
             disable_overlap_scheduler=disable_overlap_scheduler,
             kv_cache_config=kv_cache_config,
             max_batch_size=max_batch_size,
-            max_num_tokens=max_num_tokens)
+            max_num_tokens=max_num_tokens,
+            scheduler_config=scheduler_config)
 
         worker = cls(llm, tokenizer)
         worker.own_llm = True
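The new `scheduler_config` is passed straight through to the underlying `LLM`, so callers can now pick a capacity scheduling policy when the worker builds its own engine. A hedged sketch: the enclosing classmethod's name sits outside the hunk, so `init_with_new_llm` and the `CapacitySchedulerPolicy` import location are assumptions here:

    from tensorrt_llm.llmapi.llm_args import (CapacitySchedulerPolicy,
                                              SchedulerConfig)

    worker = TRTLLMWorker.init_with_new_llm(
        "/path/to/model",  # placeholder model path
        max_batch_size=32,
        scheduler_config=SchedulerConfig(
            capacity_scheduler_policy=CapacitySchedulerPolicy.MAX_UTILIZATION))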
@@ -201,8 +204,44 @@
         # TODO: error handle
         return TaskStatus.SUCCESS
 
+    async def stream_generation_handler(
+            self, task: StreamGenerationTask) -> TaskStatus:
+
+        async def get_step_or_more_tokens(task: StreamGenerationTask):
+            if task.cancel_flag:
+                task.end_flag = True
+                task.request_handle.abort()
+                return TaskStatus.SUCCESS
+
+            for _ in range(task.streaming_step):
+                await task.request_handle._aresult_step()
+                if task.request_handle._done:
+                    break
+
+            while not task.request_handle._done:
+                async_task = asyncio.create_task(
+                    task.request_handle._aresult_step())
+                if not async_task.done():
+                    async_task.cancel()
+                    break
+
+            if task.request_handle._done:
+                task.end_flag = True
+
+        if getattr(task, 'end_flag', False):
+            return TaskStatus.SUCCESS
+
+        if task.request_handle is None:
+            sampling_params = self.convert_task_params(task)
+            task.request_handle = self.llm.generate_async(
+                task.input_str, sampling_params=sampling_params, streaming=True)
+            task._result = task.request_handle
+
+        await get_step_or_more_tokens(task)
+
     def shutdown(self):
         if self.own_llm:
             self.llm.shutdown()
 
-    task_handlers = {GenerationTask: generation_handler}
+    task_handlers = {
+        GenerationTask: generation_handler,
+        StreamGenerationTask: stream_generation_handler
+    }
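End to end, the streaming contract implemented above is: dispatch looks the handler up in `task_handlers` by task type, and the controller keeps re-submitting the same `StreamGenerationTask` until the worker flips `end_flag`, receiving at least `streaming_step` fresh tokens per call. A hedged sketch of that loop (`run_task` is a stand-in for the scaffolding dispatch, not a library function):

    async def run_task(worker, task):
        handler = worker.task_handlers[type(task)]  # plain function; worker becomes self
        return await handler(worker, task)

    async def stream_once(worker, generation_task):
        task = StreamGenerationTask.create_from_generation_task(
            generation_task, streaming_step=4)
        while not task.end_flag:
            await run_task(worker, task)  # each call appends >= streaming_step tokens
            print("so far:", task.output_str)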