Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[None][feat] Refactor scaffolding streaming feature and fix openai wo… (#8622)
Signed-off-by: Fred Wei <20514172+WeiHaocheng@users.noreply.github.com>
parent a4f75399b9
commit cc286687c4
@@ -1,10 +0,0 @@
This example shows how to use the `StreamGenerationTask` and `stream_generation_handler` to enable efficient streaming-based generation workflows.

How to run the example?

```bash
python stream_generation_run.py
```

See more details in [tensorrt_llm/scaffolding/contrib/AsyncGeneration/README.md](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/scaffolding/contrib/AsyncGeneration/README.md).
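The quick-start command above omits the script's required `--model_dir` flag. Based on the argument parser in the `stream_generation_run.py` hunk below, a fuller invocation would presumably look like the following sketch; the model path is only an illustrative placeholder:

```bash
# --model_dir is required; --run_type selects the demo path
# ("original", "step", or "cancel"; it defaults to "original").
python stream_generation_run.py \
    --model_dir /path/to/DeepSeek-R1-Distill-Qwen-7B \
    --run_type step
```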
@@ -1,104 +0,0 @@
import argparse

from stream_generation_controller import NativeStreamGenerationController

from tensorrt_llm.scaffolding import ScaffoldingLlm, TRTLLMWorker
from tensorrt_llm.scaffolding.contrib import (StreamGenerationTask,
                                              stream_generation_handler)


def parse_arguments():
    parser = argparse.ArgumentParser()
    # e.g. DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B
    parser.add_argument(
        '--model_dir',
        type=str,
        required=True,
        help="Path to the directory containing the generation model")
    parser.add_argument('--run_type', type=str, default='original')
    args = parser.parse_args()
    return args


def test(prompts, proposer_worker):
    prototype_controller = NativeStreamGenerationController(
        sampling_params={"temperature": 0.9})

    llm = ScaffoldingLlm(
        prototype_controller,
        {NativeStreamGenerationController.WorkerTag.STREAM: proposer_worker},
    )
    results = llm.generate(prompts)
    for result in results:
        print(result.output.output_str)
    print(f'test main shutting down...')
    llm.shutdown()
    print(f'test worker shutting down...')
    proposer_worker.shutdown()
    print(f'test main shut down done')


def test_step(prompts, proposer_worker):
    prototype_controller = NativeStreamGenerationController()
    prototype_controller.set_stream_step(20)

    llm = ScaffoldingLlm(
        prototype_controller,
        {NativeStreamGenerationController.WorkerTag.STREAM: proposer_worker},
    )
    results = llm.generate(prompts)
    for result in results:
        print(result.output.output_str)
    print(f'test step main shutting down...')
    llm.shutdown()
    print(f'test step worker shutting down...')
    proposer_worker.shutdown()
    print(f'test step main shut down done')


def test_cancel(prompts, proposer_worker):
    prototype_controller = NativeStreamGenerationController()
    prototype_controller.set_output_threshold(200)

    llm = ScaffoldingLlm(
        prototype_controller,
        {NativeStreamGenerationController.WorkerTag.STREAM: proposer_worker},
    )
    results = llm.generate(prompts)
    for result in results:
        print(result.output.output_str)
    print(f'test cancel main shutting down...')
    llm.shutdown()
    print(f'test cancel worker shutting down...')
    proposer_worker.shutdown()
    print(f'test cancel main shut down done')


def main():
    args = parse_arguments()

    prompts = [
        "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\r\n\r\n",
        "There exist real numbers $x$ and $y$, both greater than 1, such that $\\log_x\\left(y^x\\right)=\\log_y\\left(x^{4y}\\right)=10$. Find $xy$.",
        "Find the largest possible real part of \\[(75+117i)z+\\frac{96+144i}{z}\\]where $z$ is a complex number with $|z|=4$.",
    ]
    llm_worker = TRTLLMWorker.init_with_new_llm(
        args.model_dir,
        backend="pytorch",
        max_batch_size=32,
        max_num_tokens=4096,
    )

    print(f'main llm worker init done')
    llm_worker.register_task_handler(StreamGenerationTask,
                                     stream_generation_handler)
    if args.run_type == 'original':
        test(prompts, llm_worker)
    elif args.run_type == 'step':
        test_step(prompts, llm_worker)
    elif args.run_type == 'cancel':
        test_cancel(prompts, llm_worker)


if __name__ == "__main__":
    main()
@@ -65,15 +65,14 @@ def main():
    if args.streaming:

        async def task(prompt: str):
            i = 0
            step = 0
            async for result in llm.generate_async(prompt):
                i += 1
                print(">>>", i, result)
                async for output in result.cur_output:
                    print(">>>", i, len(output.outputs[0].token_ids), "\n",
                          output.outputs[0].text)
            print(f">>> final output {len(result.outputs[0].token_ids)}\n",
                  result.outputs[0].text)
                step += 1
                tokens_num = len(
                    result.outputs[0].token_ids
                ) if result.outputs[0].token_ids is not None else 0
                print(">>>", step, tokens_num, "\n", result.outputs[0].text)
            print(f">>> final output {tokens_num}\n", result.outputs[0].text)

        # Need to provide LLM's event loop to get results in the middle of the whole process.
        asyncio.run_coroutine_threadsafe(task(prompts[0]), llm.loop).result()
@@ -4,8 +4,8 @@ import asyncio
from openai import AsyncOpenAI

from tensorrt_llm.scaffolding import OpenaiWorker, ScaffoldingLlm
from tensorrt_llm.scaffolding.contrib import (ChatTask, MCPController,
                                              MCPWorker, chat_handler)
from tensorrt_llm.scaffolding.contrib.mcp import (ChatTask, MCPController,
                                                  MCPWorker, chat_handler)


def parse_arguments():
@@ -28,7 +28,7 @@ def parse_arguments():
from openai import AsyncOpenAI

from tensorrt_llm.scaffolding import OpenaiWorker, ScaffoldingLlm
from tensorrt_llm.scaffolding.contrib import MCPController, MCPWorker
from tensorrt_llm.scaffolding.contrib.mcp import MCPController, MCPWorker


async def main():
@@ -41,7 +41,7 @@ async def main():
    ]
    API_KEY = args.API_KEY
    urls = [
        "http://0.0.0.0:8080/sse", "http://0.0.0.0:8081/sse",
        #"http://0.0.0.0:8080/sse", "http://0.0.0.0:8081/sse",
        "http://0.0.0.0:8082/sse"
    ]
    print(f"API_KEY {API_KEY}")
@@ -61,7 +61,7 @@ async def main():

    future = llm.generate_async(prompts[0])
    result = await future.aresult()
    print(f"\nresult is {result.output.output_str}\n")
    print(f"\nresult is {result.outputs[0].text}\n")

    print(f'main shutting down...')
    llm.shutdown()
@@ -53,14 +53,12 @@ def test_async(prompt, proposer_worker):
        prototype_controller,
        {NativeGenerationController.WorkerTag.GENERATION: proposer_worker},
    )
    i = 0

    step = 0
    async for result in llm.generate_async(prompt):
        i += 1
        print(">>>", i, result)
        async for output in result.cur_output:
            print(">>>", i, len(output.outputs[0].token_ids), "\n",
                  output.outputs[0].text)
        step += 1
        print(">>>", step, len(result.outputs[0].token_ids), "\n",
              result.outputs[0].text)
    print(f">>> final output {len(result.outputs[0].token_ids)}\n",
          result.outputs[0].text)
@@ -104,7 +104,7 @@ def main():
    prompt = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\r\n\r\n"

    result = llm.generate(prompt)
    extracted_answer = extract_answer_from_boxed(result.output.output_str)
    extracted_answer = extract_answer_from_boxed(result.outputs[0].text)
    print(f'extracted_answer={extracted_answer}')

    llm.shutdown(shutdown_workers=True)
@@ -25,6 +25,7 @@ __all__ = [
    "GenerationTask",
    "StreamGenerationTask",
    "RewardTask",
    "StreamGenerationTask",
    "Worker",
    "OpenaiWorker",
    "TRTOpenaiWorker",
@@ -1,46 +0,0 @@
## Overview

`StreamGenerationTask` is an extension of `GenerationTask` designed for token-level streaming generation in asynchronous LLM workflows using TensorRT-LLM. It enables the controller to receive partial results during generation, which is critical for real-time or latency-sensitive applications such as chatbots, speech generation, or UI-interactive systems.

---

## Features

- ✅ Supports **streamed token delivery** by step (e.g., `streaming_step=1`).
- ✅ Supports **cancellation** of generation using a flag (`cancel_flag=True`).
- ✅ Tracks **stream completion status** (`end_flag=True` when done).
- ✅ Integrated with a streaming-capable LLM interface (`generate_async`).

---

## Fields in `StreamGenerationTask`

| Field | Description |
|-------|-------------|
| `cancel_flag` | If `True`, the generation will be cancelled on the worker side. |
| `streaming_step` | Number of new tokens required before returning control to the controller. If set to `0`, the task is returned immediately if any new tokens are available. |
| `request_handle` | Internal handle for the streaming generation (used by the worker). |
| `end_flag` | Indicates whether generation is finished. |
| `output_str` / `output_tokens` / `logprobs` | Outputs after each generation step. |

---

## Usage in Controller/Worker

The Controller can use `StreamGenerationTask` to build efficient streaming-based generation workflows (a sketch follows this section):
- It sends tasks to the worker, which returns them once the number of newly generated tokens reaches the specified `streaming_step`.
- It can cancel long-running tasks by setting `task.cancel_flag = True` when the number of generated tokens exceeds a predefined threshold.

To support this behavior on the worker side, we provide `stream_generation_handler`, which you need to register with the worker in your project. This handler processes `StreamGenerationTask` instances step by step and updates result fields such as `output_tokens` and `output_str`.

This design allows the controller and worker to coordinate generation in a token-efficient and responsive manner, ideal for real-time applications.
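For illustration only (this sketch is not in the original README), a controller built on these fields might drive the streaming loop roughly as follows. The surrounding `Controller` subclass, its `WorkerTag.STREAM` tag, and the 200-token cancel threshold are assumptions borrowed from the example scripts elsewhere in this commit:

```python
from tensorrt_llm.scaffolding.contrib import StreamGenerationTask


# Hypothetical Controller.process generator: yield the task repeatedly,
# read the partial output after each streaming step, and cancel once a
# token budget is exceeded.
def process(self, tasks, **kwargs):
    task = StreamGenerationTask(input_str=tasks[0].input_str,
                                worker_tag=self.WorkerTag.STREAM,
                                streaming_step=20)
    while not task.end_flag:
        yield [task]                    # worker runs stream_generation_handler
        print(task.output_str)          # partial output accumulated so far
        if task.output_tokens and len(task.output_tokens) > 200:
            task.cancel_flag = True     # ask the worker to abort the request
    tasks[0].output_str = task.output_str  # surface the final text
```

On the next yield after `cancel_flag` is set, the worker aborts the underlying request and sets `end_flag`, which ends the loop.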
You can see more details in `stream_generation_controller.py` and `stream_generation_task.py` under `examples/scaffolding/contrib/AsyncGeneration`.

## Notes
Remember to register the `stream_generation_handler` with the `TRTLLMWorker`, as shown below.
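The registration call itself, as used by the example script deleted earlier in this diff (`llm_worker` is the `TRTLLMWorker` instance):

```python
llm_worker.register_task_handler(StreamGenerationTask,
                                 stream_generation_handler)
```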
## TODO

- Add error handling for failed `request_handle`.
- Support a retry or backoff mechanism if generation stalls.
@@ -1,3 +0,0 @@
from .stream_generation import StreamGenerationTask, stream_generation_handler

__all__ = ["stream_generation_handler", "StreamGenerationTask"]
@@ -126,18 +126,10 @@ class DynasorGenerationController(Controller):
                    probe_answers[-self.certainty_threshold:])
                    == self.certainty_threshold
                    and sum(probe_certain_count) == self.certainty_threshold):
                tasks[0].result = probe_task.result
                # If the current prompt indicates the chain-of-thought phase has ended, use one type of suffix.
                if "</think>" in current_prompt:
                    tasks[0].output_str = (current_prompt + self.answer_suffix +
                                           probe_answers[-1] + "}\n\\]")
                    return
                else:
                    # Otherwise, use the suffix with marker to transition clearly.
                    tasks[0].output_str = (current_prompt +
                                           self.answer_suffix_with_marker +
                                           probe_answers[-1] + "}\n\\]")
                    return
                suffix = self.answer_suffix if "</think>" in current_prompt else self.answer_suffix_with_marker
                suffix += probe_answers[-1] + "}\n\\]"
                current_prompt += suffix
                break

            # If not confident, do another round of generation
            # Append the newly generated text from the proposer to the current prompt for the next iteration.
@@ -145,7 +137,6 @@ class DynasorGenerationController(Controller):

        # If the maximum token limit is reached without satisfying the certainty condition,
        # output the accumulated prompt as the final output.
        tasks[0].result = proposer_task.result
        tasks[0].output_str = current_prompt
        return
@@ -19,11 +19,9 @@ def add_param_if_not_none(params, key, candidate_values):
def combine_params_with_chat_task(worker, params: dict, task: ChatTask):
    params["messages"] = task.messages

    add_param_if_not_none(params, "max_tokens",
                          [task.max_tokens, worker.max_tokens])
    add_param_if_not_none(params, "temperature",
                          [task.temperature, worker.temperature])
    add_param_if_not_none(params, "top_p", [task.top_p, worker.top_p])
    add_param_if_not_none(params, "max_tokens", [task.max_tokens])
    add_param_if_not_none(params, "temperature", [task.temperature])
    add_param_if_not_none(params, "top_p", [task.top_p])

    add_param_if_not_none(params, "tools", [task.tools])
@@ -67,7 +67,7 @@ class NativeGenerationController(Controller):
            for key, value in self.sampling_params.items():
                if getattr(task, key) is None:
                    setattr(task, key, value)
            task.streaming = self.streaming
            task.streaming_output_flag = self.streaming

        yield tasks
@@ -1,80 +1,67 @@
import asyncio
from typing import Mapping, Optional
from dataclasses import dataclass
from typing import Any, List, Mapping, Optional, Union

from tensorrt_llm.executor.result import GenerationResult

@dataclass
class ScaffoldingOutput:
    text: str
    token_ids: List[int]


class ScaffoldingResult:

    def __init__(self, streaming_event: Optional[asyncio.Event] = None):
    def __init__(self):
        super().__init__()
        self.aqueue = asyncio.Queue()
        self.cur_output: GenerationResult = None
        #self.cur_output: GenerationResult = None
        self.outputs = []
        # only support one output for now, so use an empty obj to init
        self.outputs.append(ScaffoldingOutput("", []))
        self._done = False
        self.task_collections = None
        self.streaming_event = streaming_event

    def set_output(self, output: GenerationResult):
    def set_output(self, output: Union[ScaffoldingOutput, Any]):
        if isinstance(output, ScaffoldingOutput):
            self.set_output_streaming(output)
        # terminate
        self.set_output_streaming(None)

    def set_output_streaming(self, output: Union[ScaffoldingOutput, Any]):
        self.aqueue.put_nowait(output)
        self._done = True

    async def set_output_async(self, output: GenerationResult):
        await self.aqueue.put(output)

    def set_task_collections(self, task_collections: Mapping[str,
                                                             "TaskCollection"]):
        self.task_collections = task_collections

    @property
    def outputs(self):
        return self.cur_output.outputs if self.cur_output else None

    @property
    def finished(self) -> bool:
        return self.cur_output is not None and self.cur_output.finished

    async def _aresult_step(self):
        # TODO: error handling or raise exception?
        response = await self.aqueue.get()
        if response is None:
            raise Exception("ScaffoldingLlm execution failed")
        self._handle_response(response)
        obj = await self.aqueue.get()
        if obj is None:
            self._done = True
        else:  # obj is ScaffoldingOutput
            self.outputs[0] = obj

    def result(self, timeout: Optional[float] = None) -> "ScaffoldingResult":
        if not self.finished:
        if not self._done:
            loop = asyncio.get_event_loop()
            asyncio.run_coroutine_threadsafe(self.aresult(), loop).result()
        return self

    async def aresult(self) -> "ScaffoldingResult":
        while not self.finished:
        while not self._done:
            await self._aresult_step()
        return self

    def __await__(self):
        return self.aresult().__await__()

    def __iter__(self):
        return self

    def __next__(self):
        if self._done and self.finished:
            raise StopIteration

        self._result_step()
        return self

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.finished:
            self.streaming_event.set() if self.streaming_event else None
        if self._done and self.finished:
        if self._done:
            raise StopAsyncIteration

        await self._aresult_step()
        return self

    def _handle_response(self, response: GenerationResult):
        self.cur_output = response
@@ -34,7 +34,6 @@ class ScaffoldingLlm:
        self.task_queue = asyncio.Queue()
        self.main_loop_stop_event = asyncio.Event()
        self.shutdown_event = asyncio.Event()
        self.streaming_event = asyncio.Event()
        if self.own_loop:
            self._run_main_loop_thread()
        else:
@@ -82,10 +81,10 @@ class ScaffoldingLlm:
        ]
        await asyncio.gather(*async_tasks)
        for task in tasks:
            if getattr(task, 'streaming', False):
                await request.result.set_output_async(task.result)
                self.streaming_event.clear()
                await self.streaming_event.wait()
            if task.streaming_output_flag:
                for output in task.streaming_output_list:
                    request.result.set_output_streaming(output)
                task.streaming_output_list = []

    async def _handle_parallel_process(self,
                                       tasks: ParallelProcess,
@@ -172,7 +171,7 @@ class ScaffoldingLlm:
        self.main_loop_thread.start()

    def generate_async(self, prompt: str) -> ScaffoldingResult:
        result = ScaffoldingResult(self.streaming_event)
        result = ScaffoldingResult()

        async def put_request():
            try:
@@ -5,20 +5,36 @@ from typing import Any, Dict, List, Optional, Union

import torch

from tensorrt_llm.executor.result import GenerationResult, TokenLogprobs
from tensorrt_llm.executor.result import TokenLogprobs

from .result import ScaffoldingOutput


@dataclass
class Task:
    # Reserve for custom input params.
    custom_input_params: Optional[dict] = None

    # Scaffolding delivers the task to the Worker by worker_tag.
    worker_tag: str = field(default=None)

    # For streaming output.
    streaming_output_flag: bool = field(default=False)
    streaming_output_list: list[Any] = field(default_factory=list)

    # Reserve for custom input params.
    custom_input_params: Optional[dict] = None

    # Reserve for custom output params.
    custom_output_params: Optional[dict] = None

    @staticmethod
    def create_from_prompt(prompt: str) -> "Task":
        pass

    def create_scaffolding_output(self) -> ScaffoldingOutput:
        pass

    def create_scaffolding_output_stream(self) -> List[ScaffoldingOutput]:
        pass


class TaskStatus(Enum):
    SUCCESS = "success"
@@ -33,7 +49,7 @@ class GenerationTask(Task):
    input_str: Optional[str] = None
    skip_tokenizer: bool = False
    skip_detokenizer: bool = False
    streaming: bool = False
    #streaming: bool = False

    # sampling params for openai
    # Ordered by official OpenAI API documentation
@@ -63,48 +79,14 @@ class GenerationTask(Task):
    worker_tag: Union[str, "Controller.WorkerTag"] = None

    # result field
    # link to TRTLLM's GenerationResult, for async update in streaming mode
    _result: Optional[GenerationResult] = None
    output_str: Optional[str] = None
    output_tokens: Optional[List[int]] = None
    # TODO: support openai API format context logits
    context_logits: Optional[torch.Tensor] = None
    # TODO: do not use TokenLogprobs for general support
    logprobs: Optional[TokenLogprobs] = None
    customized_result_fields: Dict[str, Any] = field(default_factory=dict)

    @property
    def result(self) -> GenerationResult:
        return self._result

    @result.setter
    def result(self, result: GenerationResult) -> None:
        self._result = result

    @property
    def outputs(self) -> Optional[List[dict]]:
        return self._result.outputs if self._result else None

    @property
    def output_tokens(self) -> List[int]:
        return self._result.outputs[0].token_ids if self._result else None

    @property
    def output_str(self) -> Optional[str]:
        return self._result.outputs[0].text if self._result else None

    @output_str.setter
    def output_str(self, output) -> Optional[str]:
        assert self.result
        self._result.outputs[0].text = output

    @property
    def cumulative_logprob(self) -> Optional[float]:
        return self._result.outputs[
            0].cumulative_logprob if self._result else None

    @property
    def logprobs(self) -> Optional[TokenLogprobs]:
        return self._result.outputs[0].logprobs if self._result else None

    @property
    def context_logits(self) -> Optional[torch.Tensor]:
        return self._result.context_logits if self._result else None

    @staticmethod
    def create_from_prompt(prompt: str) -> "GenerationTask":
        task = GenerationTask()
@@ -113,8 +95,8 @@ class GenerationTask(Task):
        task.skip_detokenizer = False
        return task

    def create_scaffolding_output(self) -> GenerationResult:
        return self._result
    def create_scaffolding_output(self) -> ScaffoldingOutput:
        return ScaffoldingOutput(self.output_str, self.output_tokens)


@dataclass
@@ -148,3 +130,21 @@ class RewardTask(Task):
    # input field
    input_tokens: Optional[List[int]] = field(default=None)
    input_str: Optional[str] = field(default=None)


@dataclass
class StreamGenerationTask(GenerationTask):
    # input field
    # if the flag is set to True, the worker will cancel the generation work
    cancel_flag: Optional[bool] = field(default=False)
    # the task will be returned to the controller with at least streaming_step new tokens;
    # if streaming_step is set to 0, the task is returned to the controller immediately
    # with whatever new tokens have already been generated
    streaming_step: Optional[int] = field(default=1)

    # result field
    # the worker sets this field and identifies the same task by it
    request_handle: Any = field(default=None)
    # the worker sets this field to True when the generation is finished
    end_flag: bool = field(default=False)
@@ -1,4 +1,5 @@
import asyncio
import copy
from abc import ABC
from typing import Callable, Optional
@@ -6,10 +7,11 @@ import openai
from transformers import AutoTokenizer

from tensorrt_llm import LLM
from tensorrt_llm.executor import GenerationExecutor
from tensorrt_llm.executor import GenerationExecutor, GenerationResult
from tensorrt_llm.llmapi.llm_args import KvCacheConfig, SchedulerConfig
from tensorrt_llm.sampling_params import SamplingParams

from .result import ScaffoldingOutput
from .task import GenerationTask, StreamGenerationTask, Task, TaskStatus

ExecutorCls = GenerationExecutor
@@ -96,7 +98,7 @@ class OpenaiWorker(Worker):
    def fill_generation_task_with_response(self, task: GenerationTask,
                                           response: openai.Completion):
        task.output_str = response.choices[0].text
        task.logprobs = response.choices[0].logprobs
        task.output_tokens = response.choices[0].token_ids

    async def generation_handler(self, task: GenerationTask) -> TaskStatus:
        params = self.convert_task_params(task)
@@ -190,55 +192,70 @@ class TRTLLMWorker(Worker):
                                         logprobs=task.num_logprobs)
        return sampling_params

    async def streaming_generate_helper(self, generate_result, step_at_least,
                                        streaming_output_list):
        step = 0
        while not generate_result._done:
            async_task = asyncio.create_task(generate_result._aresult_step())
            if step_at_least and step >= step_at_least and not async_task.done(
            ):
                async_task.cancel()
                break
            await async_task
            step += 1
            # do not put the last token to the streaming_output_list
            if streaming_output_list is not None and not generate_result._done:
                streaming_output_list.append(
                    ScaffoldingOutput(
                        generate_result.outputs[0].text,
                        copy.deepcopy(generate_result.outputs[0].token_ids)))

    def fill_task_with_result(self, task: GenerationTask,
                              result: GenerationResult):
        task.output_str = result.outputs[0].text
        task.output_tokens = result.outputs[0].token_ids
        task.context_logits = result.context_logits
        task.logprobs = result.outputs[0].logprobs

    async def generation_handler(self, task: GenerationTask) -> TaskStatus:
        sampling_params = self.convert_task_params(task)

        # If the task is streaming, we will return result directly for
        # async iteration outside. Otherwise, we will wait.
        if task.streaming:
        if task.streaming_output_flag:
            result = self.llm.generate_async(task.input_str,
                                             sampling_params=sampling_params,
                                             streaming=True)
            await self.streaming_generate_helper(result, None,
                                                 task.streaming_output_list)
        else:
            result = await self.llm.generate_async(
                task.input_str, sampling_params=sampling_params)
        task.result = result
        #task.result = result
        self.fill_task_with_result(task, result)

        # TODO: error handle
        return TaskStatus.SUCCESS

    async def stream_generation_handler(
            self, task: StreamGenerationTask) -> TaskStatus:

        async def get_step_or_more_tokens(task: StreamGenerationTask):
            if task.cancel_flag:
                task.end_flag = True
                task.request_handle.abort()
                return TaskStatus.SUCCESS

            for _ in range(task.streaming_step):
                await task.request_handle._aresult_step()
                if task.request_handle._done:
                    break

            while not task.request_handle._done:
                async_task = asyncio.create_task(
                    task.request_handle._aresult_step())
                if not async_task.done():
                    async_task.cancel()
                    break

            if task.request_handle._done:
                task.end_flag = True

        if getattr(task, 'end_flag', False):
            return TaskStatus.SUCCESS
        sampling_params = self.convert_task_params(task)
        if task.request_handle is None:
            sampling_params = self.convert_task_params(task)
            task.request_handle = self.llm.generate_async(
                task.input_str, sampling_params=sampling_params, streaming=True)
        task._result = task.request_handle
        await get_step_or_more_tokens(task)

        if task.cancel_flag:
            task.end_flag = True
            task.request_handle.abort()
            return TaskStatus.SUCCESS

        await self.streaming_generate_helper(
            task.request_handle, task.streaming_step,
            task.streaming_output_queue if task.streaming_output_flag else None)

        self.fill_task_with_result(task, task.request_handle)

        if task.request_handle._done:
            task.end_flag = True
        return TaskStatus.SUCCESS

    def shutdown(self):
        if self.own_llm:
@@ -56,6 +56,6 @@ def test_scaffolding_benchmark():

    assert len(results) == requests_num
    assert len(requests_execution_time) == requests_num
    assert results[0].cur_output == OUTPUT_STR
    assert results[0].outputs[0].text == OUTPUT_STR
    assert results[0].task_collections[
        "bench_dummy_collection"].output_len == len(OUTPUT_STR)
@@ -49,8 +49,8 @@ def test_unbatched_scaffolding_sync(default_prompt, deepseek_distill_7b_path):
    scaffolding_llm = create_scaffolding_llm_with_native_generation_controller(
        deepseek_distill_7b_path)
    result = scaffolding_llm.generate(default_prompt)
    assert isinstance(result.output.output_str, str) and len(
        result.output.output_str) > 0, "Output should be a non-empty string"
    assert isinstance(result.outputs[0].text, str) and len(
        result.outputs[0].text) > 0, "Output should be a non-empty string"
    scaffolding_llm.shutdown(shutdown_workers=True)
@@ -62,8 +62,8 @@ def test_batched_scaffolding_sync(default_prompt, deepseek_distill_7b_path):
    results = scaffolding_llm.generate(prompts)
    assert len(results) == batch_size
    for result in results:
        assert isinstance(result.output.output_str, str) and len(
            result.output.output_str) > 0, "Output should be a non-empty string"
        assert isinstance(result.outputs[0].text, str) and len(
            result.outputs[0].text) > 0, "Output should be a non-empty string"
    scaffolding_llm.shutdown(shutdown_workers=True)
@@ -74,8 +74,8 @@ def test_async_scaffolding_generation(default_prompt, deepseek_distill_7b_path):
        deepseek_distill_7b_path)
    future = scaffolding_llm.generate_async(default_prompt)
    result = await future.aresult()
    assert isinstance(result.output.output_str, str) and len(
        result.output.output_str) > 0, "Output should be a non-empty string"
    assert isinstance(result.outputs[0].text, str) and len(
        result.outputs[0].text) > 0, "Output should be a non-empty string"
    scaffolding_llm.shutdown(shutdown_workers=True)

    import asyncio
@@ -86,6 +86,6 @@ def test_majority_vote(default_prompt, deepseek_distill_7b_path):
    scaffolding_llm = create_scaffolding_llm_with_majority_vote_controller(
        deepseek_distill_7b_path, samples_num=3)
    result = scaffolding_llm.generate(default_prompt)
    assert isinstance(result.output.output_str, str) and len(
        result.output.output_str) > 0, "Output should be a non-empty string"
    assert isinstance(result.outputs[0].text, str) and len(
        result.outputs[0].text) > 0, "Output should be a non-empty string"
    scaffolding_llm.shutdown(shutdown_workers=True)