Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[None][feat] Move StreamGeneration to scaffolding main directory (#8347)
Signed-off-by: Dong Cao <docao@nvidia.com>
This commit is contained in:
parent 72d65d079a
commit 62cea877b1
@@ -42,9 +42,8 @@ class NativeStreamGenerationController(Controller):
                     "custom_sampling_params")
             elif self.custom_sampling_params:
                 task.custom_sampling_params = self.custom_sampling_params
-            stream_task = StreamGenerationTask()
-            stream_task.__dict__ = copy.deepcopy(task.__dict__)
-            stream_task.streaming_step = self.stream_step
+            stream_task = StreamGenerationTask.create_from_generation_task(
+                task, self.stream_step)
             stream_tasks.append(stream_task)
 
         lst = list(range(len(stream_tasks)))
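With the factory in place, a streaming controller only has to build stream tasks and re-yield the unfinished ones. The sketch below is illustrative rather than the controller shipped in this commit; it assumes the usual scaffolding convention that a Controller's process generator yields task batches to its worker and resumes once the worker has processed them.

# Illustrative controller-side loop (assumed convention: each yield hands the listed
# tasks to the worker, which appends at least stream_step new tokens or finishes).
def process(self, tasks, **kwargs):
    stream_tasks = [
        StreamGenerationTask.create_from_generation_task(task, self.stream_step)
        for task in tasks
    ]
    while not all(t.end_flag for t in stream_tasks):
        # Hand back only the tasks the worker has not finished yet.
        yield [t for t in stream_tasks if not t.end_flag]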
@@ -6,7 +6,8 @@ from .controller import (BestOfNController, Controller, MajorityVoteController,
 from .math_utils import (extract_answer_from_boxed, extract_answer_with_regex,
                          get_digit_majority_vote_result)
 from .scaffolding_llm import ScaffoldingLlm
-from .task import GenerationTask, RewardTask, Task, TaskStatus
+from .task import (GenerationTask, RewardTask, StreamGenerationTask, Task,
+                   TaskStatus)
 from .task_collection import (GenerationTokenCounter, TaskCollection,
                               with_task_collection)
 from .worker import OpenaiWorker, TRTLLMWorker, TRTOpenaiWorker, Worker
@@ -22,6 +23,7 @@ __all__ = [
     "BestOfNController",
     "Task",
     "GenerationTask",
+    "StreamGenerationTask",
     "RewardTask",
     "Worker",
     "OpenaiWorker",
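Because StreamGenerationTask is now re-exported from the package __init__, user code can import it next to the other task types. A minimal sketch, assuming the exports above land in the tensorrt_llm.scaffolding package and that the GenerationTask fields all carry defaults (as the dataclass definitions later in this diff suggest):

from tensorrt_llm.scaffolding import StreamGenerationTask

stream_task = StreamGenerationTask()      # dataclass defaults: streaming_step=1, end_flag=False
stream_task.input_str = "What is 1 + 1?"  # prompt field inherited from GenerationTask
print(stream_task.streaming_step, stream_task.end_flag)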
@@ -1,4 +1,5 @@
 import asyncio
+import copy
 from dataclasses import dataclass, field
 from typing import Any, Optional
@@ -22,6 +23,15 @@ class StreamGenerationTask(GenerationTask):
     # worker set this field to True when the generation is finished
     end_flag: bool = field(default=False)
 
+    @staticmethod
+    def create_from_generation_task(
+            task: GenerationTask,
+            streaming_step: int) -> "StreamGenerationTask":
+        stream_task = StreamGenerationTask()
+        stream_task.__dict__ = copy.deepcopy(task.__dict__)
+        stream_task.streaming_step = streaming_step
+        return stream_task
+
 
 async def stream_generation_handler(worker,
                                     task: StreamGenerationTask) -> TaskStatus:
@@ -230,15 +230,16 @@ class MajorityVoteController(Controller):
             yield ParallelProcess(generation_controllers, tasks_list,
                                   generation_kwargs_list)
 
-        candidates = [tasks[0].output_str for tasks in tasks_list]
         majority_index, majority_answer = self.majority_vote(
-            candidates, **majority_vote_kwargs)
+            tasks_list, **majority_vote_kwargs)
 
         assert isinstance(majority_answer, str), "majority_vote failed"
         # The task returned by majority vote does not have output_tokens and logits.
         tasks[0].result = tasks_list[majority_index][0].result
 
-    def majority_vote(self, candidates: List[str], **kwargs) -> Tuple[int, str]:
+    def majority_vote(self, candidates_tasks: List[List[Task]],
+                      **kwargs) -> Tuple[int, str]:
+        candidates = [tasks[0].output_str for tasks in candidates_tasks]
         return get_digit_majority_vote_result(candidates)
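majority_vote now receives the full per-candidate task lists instead of pre-extracted strings, so an override can inspect more than output_str (token counts, logprobs, customized_result_fields). A hedged sketch of a subclass under the new signature; the Counter-based vote is illustrative only, the library default still delegates to get_digit_majority_vote_result:

from collections import Counter
from typing import List, Tuple


class CountingMajorityVoteController(MajorityVoteController):

    def majority_vote(self, candidates_tasks: List[List[Task]],
                      **kwargs) -> Tuple[int, str]:
        # Pick the most frequent answer string and return (index, answer),
        # matching the contract used by tasks_list[majority_index] above.
        candidates = [tasks[0].output_str for tasks in candidates_tasks]
        answer, _ = Counter(candidates).most_common(1)[0]
        return candidates.index(answer), answer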
@@ -175,13 +175,19 @@ class ScaffoldingLlm:
         result = ScaffoldingResult(self.streaming_event)
 
         async def put_request():
-            request = ScaffoldingRequest(
-                prompt=prompt,
-                kwargs={},
-                result=result,
-                controller=self.prototype_controller.clone())
-
-            await self.task_queue.put(request)
+            try:
+                request = ScaffoldingRequest(
+                    prompt=prompt,
+                    kwargs={},
+                    result=result,
+                    controller=self.prototype_controller.clone())
+            except Exception as e:
+                self.task_queue.put(None)
+                print(
+                    f"Error: build ScaffoldingRequest failed: {e} \n {traceback.format_exc()}"
+                )
+            else:
+                await self.task_queue.put(request)
 
         asyncio.run_coroutine_threadsafe(put_request(), self.loop)
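The put_request change swaps a bare construct-and-enqueue for try/except/else: the request is enqueued only when construction succeeds, and a None sentinel is enqueued otherwise so the consumer is not left waiting. A self-contained illustration of that pattern (the build_request callable and the sentinel handling are assumptions for this sketch, not part of ScaffoldingLlm):

import asyncio
import traceback


async def put_request(queue: asyncio.Queue, build_request):
    try:
        request = build_request()
    except Exception as e:
        await queue.put(None)  # sentinel: unblock the consumer even though the build failed
        print(f"Error: build request failed: {e}\n{traceback.format_exc()}")
    else:
        await queue.put(request)  # runs only when no exception was raised


async def main():
    queue = asyncio.Queue()
    await put_request(queue, lambda: {"prompt": "hello"})
    print(await queue.get())


asyncio.run(main())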
@@ -208,7 +214,7 @@ class ScaffoldingLlm:
 
     def shutdown(self, shutdown_workers=False):
 
-        def shutdown_workers():
+        def shutdown_workers_func():
             for worker in self.workers.values():
                 worker.shutdown()
 
@@ -228,4 +234,4 @@ class ScaffoldingLlm:
         self.shutdown_event.set()
 
         if shutdown_workers:
-            shutdown_workers()
+            shutdown_workers_func()
@@ -1,10 +1,11 @@
+import copy
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 
-from tensorrt_llm.executor.result import GenerationResult
+from tensorrt_llm.executor.result import GenerationResult, TokenLogprobs
 
 
 @dataclass
@@ -64,6 +65,7 @@ class GenerationTask(Task):
     # result field
     # link to TRTLLM's GenerationResult, for async update in streaming mode
     _result: Optional[GenerationResult] = None
+    customized_result_fields: Dict[str, Any] = field(default_factory=dict)
 
     @property
     def result(self) -> GenerationResult:
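customized_result_fields gives every GenerationTask a free-form dict for outputs that have no dedicated property. A small sketch, assuming GenerationTask can be constructed with its defaults; the keys shown are purely illustrative:

task = GenerationTask()
task.customized_result_fields["reward"] = 0.87
task.customized_result_fields["judge_note"] = "answer matches reference"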
@@ -96,7 +98,7 @@ class GenerationTask(Task):
             0].cumulative_logprob if self._result else None
 
     @property
-    def logprobs(self) -> Optional[List[float]]:
+    def logprobs(self) -> Optional[TokenLogprobs]:
         return self._result.outputs[0].logprobs if self._result else None
 
     @property
@@ -115,6 +117,32 @@ class GenerationTask(Task):
         return self._result
 
 
+@dataclass
+class StreamGenerationTask(GenerationTask):
+    # input field
+    # if the flag is set to True, the worker will cancel the generation work
+    cancel_flag: Optional[bool] = field(default=False)
+    # the task will be returned to the controller with at least new streaming_step tokens
+    # if the streaming_step is set to 0,
+    # the task will be returned to the controller immediately with
+    # new tokens that have already been generated.
+    streaming_step: Optional[int] = field(default=1)
+
+    #result field
+    # worker set this field and identify the same task by this field
+    request_handle: Any = field(default=None)
+    # worker set this field to True when the generation is finished
+    end_flag: bool = field(default=False)
+
+    @staticmethod
+    def create_from_generation_task(task: GenerationTask,
+                                    streaming_step) -> "StreamGenerationTask":
+        stream_task = StreamGenerationTask()
+        stream_task.__dict__ = copy.deepcopy(task.__dict__)
+        stream_task.streaming_step = streaming_step
+        return stream_task
+
+
 @dataclass
 class RewardTask(Task):
     # input field
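create_from_generation_task deep-copies the source task's __dict__, so every GenerationTask field (prompt, sampling params, customized_result_fields) carries over before the streaming step is set. A small usage sketch under the same default-constructibility assumption as above:

base = GenerationTask()
base.input_str = "Prove that 2 + 2 = 4."

stream = StreamGenerationTask.create_from_generation_task(base, streaming_step=8)
assert stream.input_str == base.input_str   # state deep-copied from the source task
assert stream.streaming_step == 8
assert stream.end_flag is False             # the worker flips this once generation finishes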
@@ -1,15 +1,16 @@
 import asyncio
 from abc import ABC
-from typing import Callable
+from typing import Callable, Optional
 
 import openai
 from transformers import AutoTokenizer
 
 from tensorrt_llm import LLM
 from tensorrt_llm.executor import GenerationExecutor
-from tensorrt_llm.llmapi.llm_args import KvCacheConfig
+from tensorrt_llm.llmapi.llm_args import KvCacheConfig, SchedulerConfig
 from tensorrt_llm.sampling_params import SamplingParams
 
-from .task import GenerationTask, Task, TaskStatus
+from .task import GenerationTask, StreamGenerationTask, Task, TaskStatus
 
 ExecutorCls = GenerationExecutor
@@ -150,6 +151,7 @@ class TRTLLMWorker(Worker):
             max_num_tokens: int = 4096,
             kv_cache_free_gpu_memory_fraction: float = 0.9,
             disable_overlap_scheduler: bool = False,
+            scheduler_config: Optional[SchedulerConfig] = None,
     ):
         kv_cache_config = KvCacheConfig(
             free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction, )
@@ -168,7 +170,8 @@ class TRTLLMWorker(Worker):
             disable_overlap_scheduler=disable_overlap_scheduler,
             kv_cache_config=kv_cache_config,
             max_batch_size=max_batch_size,
-            max_num_tokens=max_num_tokens)
+            max_num_tokens=max_num_tokens,
+            scheduler_config=scheduler_config)
 
         worker = cls(llm, tokenizer)
         worker.own_llm = True
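The new scheduler_config argument is forwarded to the underlying LLM, so callers can tune request scheduling when building the worker. A hedged sketch; the factory name init_with_new_llm and the model path are assumptions based on the parameter list visible in these hunks, not something this diff confirms:

from tensorrt_llm.llmapi.llm_args import SchedulerConfig
from tensorrt_llm.scaffolding import TRTLLMWorker

worker = TRTLLMWorker.init_with_new_llm(     # factory name assumed for illustration
    "Qwen/Qwen2.5-7B-Instruct",              # illustrative model path
    max_num_tokens=4096,
    kv_cache_free_gpu_memory_fraction=0.9,
    scheduler_config=SchedulerConfig(),      # defaults; forwarded to the LLM constructor
)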
@@ -201,8 +204,44 @@ class TRTLLMWorker(Worker):
             # TODO: error handle
         return TaskStatus.SUCCESS
 
+    async def stream_generation_handler(
+            self, task: StreamGenerationTask) -> TaskStatus:
+
+        async def get_step_or_more_tokens(task: StreamGenerationTask):
+            if task.cancel_flag:
+                task.end_flag = True
+                task.request_handle.abort()
+                return TaskStatus.SUCCESS
+
+            for _ in range(task.streaming_step):
+                await task.request_handle._aresult_step()
+                if task.request_handle._done:
+                    break
+
+            while not task.request_handle._done:
+                async_task = asyncio.create_task(
+                    task.request_handle._aresult_step())
+                if not async_task.done():
+                    async_task.cancel()
+                    break
+
+            if task.request_handle._done:
+                task.end_flag = True
+
+        if getattr(task, 'end_flag', False):
+            return TaskStatus.SUCCESS
+        if task.request_handle is None:
+            sampling_params = self.convert_task_params(task)
+            task.request_handle = self.llm.generate_async(
+                task.input_str, sampling_params=sampling_params, streaming=True)
+            task._result = task.request_handle
+        await get_step_or_more_tokens(task)
+
     def shutdown(self):
         if self.own_llm:
             self.llm.shutdown()
 
-    task_handlers = {GenerationTask: generation_handler}
+    task_handlers = {
+        GenerationTask: generation_handler,
+        StreamGenerationTask: stream_generation_handler
+    }
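Registering StreamGenerationTask in task_handlers is what lets the worker route stream tasks to the new handler. A minimal sketch of the kind of type-based dispatch this table supports; dispatch_task is an illustrative helper, not the Worker base class method:

async def dispatch_task(worker, task):
    # task_handlers maps a concrete Task subclass to an unbound handler function
    # defined in the worker class body, so the worker is passed explicitly.
    handler = worker.task_handlers.get(type(task))
    if handler is None:
        raise NotImplementedError(f"no handler registered for {type(task).__name__}")
    return await handler(worker, task)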