Mirror of https://github.com/NVIDIA/TensorRT-LLM.git

Commit 3fc2a16920 (parent ae34d60108)
Signed-off-by: fredw (generated by with_the_same_user script) <20514172+WeiHaocheng@users.noreply.github.com>
@@ -9,12 +9,12 @@ from tensorrt_llm.scaffolding import (MajorityVoteController,

 def parse_arguments():
     parser = argparse.ArgumentParser()
     # e.g. /home/scratch.trt_llm_data/llm-models/DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B
     parser.add_argument(
-        '--generation_dir',
+        '--model_dir',
         type=str,
-        default=
-        "/home/scratch.trt_llm_data/llm-models/DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B"
-    )
+        required=True,
+        help="Path to the directory containing the generation model")
     # https://github.com/QwenLM/Qwen2.5-Math/blob/main/evaluation/data/aime24/test.jsonl
     parser.add_argument('--jsonl_file', type=str, default='./test.jsonl')
     parser.add_argument('--threshold', type=float, default=None)
@@ -36,7 +36,7 @@ def main():
     args = parse_arguments()
     workers = {}

-    llm_worker = TRTLLMWorker.init_with_new_llm(args.generation_dir,
+    llm_worker = TRTLLMWorker.init_with_new_llm(args.model_dir,
                                                 backend="pytorch",
                                                 max_batch_size=32,
                                                 max_num_tokens=4096,
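The hunks above replace the defaulted `--generation_dir` flag with a required `--model_dir`, so the run scripts no longer fall back to the internal scratch path. A hypothetical invocation (the script name is assumed for illustration; the model path shown is the old default):

```bash
# Script name assumed for illustration; pass your local model checkout.
python majority_vote_run.py \
    --model_dir /home/scratch.trt_llm_data/llm-models/DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B
```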
@@ -1,46 +1,10 @@
 ## Overview

-`StreamGenerationTask` is an extension of `GenerationTask` designed for token-level streaming generation in asynchronous LLM workflows using TensorRT-LLM. It enables the controller to receive partial results during generation, which is critical for real-time or latency-sensitive applications such as chatbots, speech generation, or UI-interactive systems.
+This example shows how to use the `StreamGenerationTask` and `stream_generation_handler` to enable efficient streaming-based generation workflows.

----
+How to run the example?

-## Features
+```bash
+python stream_generation_run.py
+```

-- ✅ Supports **streamed token delivery** by step (e.g., `streaming_step=1`).
-- ✅ Supports **cancellation** of generation using a flag (`cancel_flag=True`).
-- ✅ Tracks **stream completion status** (`end_flag=True` when done).
-- ✅ Integrated with a streaming-capable LLM interface (`generate_async`).
-
----
-
-## Fields in `StreamGenerationTask`
-
-| Field | Description |
-|-------|-------------|
-| `cancel_flag` | If `True`, the generation will be cancelled on the worker side. |
-| `streaming_step` | Number of new tokens required before returning control to the controller. If set to `0`, the task is returned immediately if any new tokens are available. |
-| `request_handle` | Internal handle for the streaming generation (used by the worker). |
-| `end_flag` | Indicates whether generation is finished. |
-| `output_str` / `output_tokens` / `logprobs` | Outputs after each generation step. |
-
----
-
-## Usage in Controller/Worker
-
-The Controller can utilize `StreamGenerationTask` to enable efficient streaming-based generation workflows:
-- It sends tasks to the worker, which returns them when the number of newly generated tokens reaches the specified `streaming_step`.
-- It can cancel long-running tasks by setting `task.cancel_flag = True` when the number of generated tokens exceeds a predefined threshold.
-
-To support this behavior on the worker side, you need to implement a `stream_generation_handler` and register it with the worker. This handler should process `StreamGenerationTask` instances step-by-step and update relevant fields such as `output_tokens` and `output_str`.
-
-This design allows the controller and worker to coordinate generation in a token-efficient and responsive manner, ideal for real-time applications.
-
-You can see more details in stream_generation_controller.py and stream_generation_task.py
-
-## Notes
-Ensure the `worker.llm.generate_async(...)` method supports `streaming=True`.
-
-## TODO
-
-- Add error handling for failed `request_handle`
-- Support retry or backoff mechanism if generation stalls
+See more detail in [tensorrt_llm/scaffolding/contrib/AsyncGeneration/README.md](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/scaffolding/contrib/AsyncGeneration/README.md).
@@ -2,10 +2,8 @@ import copy
 from enum import Enum
 from typing import List

-from stream_generation_task import StreamGenerationTask
-
-from tensorrt_llm.scaffolding.controller import Controller
-from tensorrt_llm.scaffolding.task import GenerationTask, Task
+from tensorrt_llm.scaffolding import Controller, GenerationTask, Task
+from tensorrt_llm.scaffolding.contrib import StreamGenerationTask


 class NativeStreamGenerationController(Controller):
@@ -1,20 +1,20 @@
 import argparse

 from stream_generation_controller import NativeStreamGenerationController
-from stream_generation_task import (StreamGenerationTask,
-                                    stream_generation_handler)

 from tensorrt_llm.scaffolding import ScaffoldingLlm, TRTLLMWorker
+from tensorrt_llm.scaffolding.contrib import (StreamGenerationTask,
+                                              stream_generation_handler)


 def parse_arguments():
     parser = argparse.ArgumentParser()
     # e.g. /home/scratch.trt_llm_data/llm-models/DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B
     parser.add_argument(
-        '--generation_dir',
+        '--model_dir',
         type=str,
-        default=
-        "/home/scratch.trt_llm_data/llm-models/DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B"
-    )
+        required=True,
+        help="Path to the directory containing the generation model")
     parser.add_argument('--run_type', type=str, default='original')
     args = parser.parse_args()
     return args
@@ -81,7 +81,7 @@ def main():
         "There exist real numbers $x$ and $y$, both greater than 1, such that $\\log_x\\left(y^x\\right)=\\log_y\\left(x^{4y}\\right)=10$. Find $xy$.",
         "Find the largest possible real part of \\[(75+117i)z+\\frac{96+144i}{z}\\]where $z$ is a complex number with $|z|=4$.",
     ]
-    llm_worker = TRTLLMWorker.init_with_new_llm(args.generation_dir,
+    llm_worker = TRTLLMWorker.init_with_new_llm(args.model_dir,
                                                 backend="pytorch",
                                                 max_batch_size=32,
                                                 max_num_tokens=4096,
examples/scaffolding/contrib/Dynasor/requirements.txt (new file, 2 lines)
@@ -0,0 +1,2 @@
+latex2sympy2
+word2number
@@ -1,16 +1,18 @@
 import argparse

-from dynasor_controller import DynasorGenerationController
-
 from tensorrt_llm.scaffolding import (MajorityVoteController, ScaffoldingLlm,
                                       TRTLLMWorker)
+from tensorrt_llm.scaffolding.contrib import DynasorGenerationController


 def parse_arguments():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--generation_dir",
-                        type=str,
-                        default="./models/DeepSeek-R1-Distill-Qwen-7B")
+    # e.g. /home/scratch.trt_llm_data/llm-models/DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B
+    parser.add_argument(
+        '--model_dir',
+        type=str,
+        required=True,
+        help="Path to the directory containing the generation model")
     parser.add_argument("--max_num_tokens", type=int, default=7000)
     parser.add_argument("--majority_vote", action='store_true')
     parser.add_argument('--sample_num', type=int, default=3)
@@ -18,9 +20,9 @@ def parse_arguments():
     return args


-def test_sync(prompts, proposer_worker, args):
+def test(prompts, proposer_worker, args):
     dynasor_generation_controller = DynasorGenerationController(
-        generation_dir=args.generation_dir, max_tokens=args.max_num_tokens)
+        generation_dir=args.model_dir, max_tokens=args.max_num_tokens)

     # If majority voting is requested, wrap the controller in MajorityVoteController
     if args.majority_vote:
@@ -65,11 +67,9 @@ def main():
     ]

     llm_worker = TRTLLMWorker.init_with_new_llm(
-        args.generation_dir,
-        backend="pytorch",
-        max_num_tokens=args.max_num_tokens)
+        args.model_dir, backend="pytorch", max_num_tokens=args.max_num_tokens)

-    test_sync(prompts, llm_worker, args)
+    test(prompts, llm_worker, args)


 if __name__ == "__main__":
@@ -1,15 +0,0 @@
-# Contrib Examples
-
-We create this directory to store the community contributed examples.
-
-Contributors can add examples of customize inference time compute methods with customize Controller/Task/Worker.
-
-We will continue to move some generic works on this directory back to the main code.
-
-### How to create a new project
-
-Just create a new directory and add your code there.
-
-### How to make your code include Controller/Task/Worker can be reused by other projects
-
-Just add your Controller/Task/Worker to the `__init__.py` file of contrib directory.
@@ -1,8 +0,0 @@
-from tensorrt_llm.scaffolding import *  # noqa
-
-from .Dynasor.dynasor_controller import DynasorGenerationController
-
-__all__ = [
-    'NativeStreamGenerationController', 'StreamGenerationTask',
-    'stream_generation_handler', 'DynasorGenerationController'
-]
@@ -7,13 +7,12 @@ from tensorrt_llm.scaffolding import (NativeGenerationController,

 def parse_arguments():
     parser = argparse.ArgumentParser()
     # e.g. /home/scratch.trt_llm_data/llm-models/DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B
     parser.add_argument(
-        '--generation_dir',
+        '--model_dir',
         type=str,
-        default=
-        "/home/scratch.trt_llm_data/llm-models/DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B"
-    )
-    parser.add_argument('--verifier_dir', type=str, default=None)
+        required=True,
+        help="Path to the directory containing the generation model")
     parser.add_argument('--run_async', action='store_true')
     args = parser.parse_args()
     return args
@@ -68,7 +67,7 @@ def main():
         "Find the largest possible real part of \\[(75+117i)z+\\frac{96+144i}{z}\\]where $z$ is a complex number with $|z|=4$.",
     ]

-    llm_worker = TRTLLMWorker.init_with_new_llm(args.generation_dir,
+    llm_worker = TRTLLMWorker.init_with_new_llm(args.model_dir,
                                                 backend="pytorch",
                                                 max_batch_size=32,
                                                 max_num_tokens=4096,
@@ -8,10 +8,23 @@ from .task import GenerationTask, RewardTask, Task, TaskStatus
 from .worker import OpenaiWorker, TRTLLMWorker, TRTOpenaiWorker, Worker

 __all__ = [
-    "ScaffoldingLlm", "ScaffoldingOutput", "ParallelProcess", "Controller",
-    "NativeGenerationController", "NativeRewardController",
-    "MajorityVoteController", "BestOfNController", "Task", "GenerationTask",
-    "RewardTask", "Worker", "OpenaiWorker", "TRTOpenaiWorker", "TRTLLMWorker",
-    "TaskStatus", "extract_answer_from_boxed", "extract_answer_with_regex",
-    "get_digit_majority_vote_result"
+    "ScaffoldingLlm",
+    "ScaffoldingOutput",
+    "ParallelProcess",
+    "Controller",
+    "NativeGenerationController",
+    "NativeRewardController",
+    "MajorityVoteController",
+    "BestOfNController",
+    "Task",
+    "GenerationTask",
+    "RewardTask",
+    "Worker",
+    "OpenaiWorker",
+    "TRTOpenaiWorker",
+    "TRTLLMWorker",
+    "TaskStatus",
+    "extract_answer_from_boxed",
+    "extract_answer_with_regex",
+    "get_digit_majority_vote_result",
 ]
tensorrt_llm/scaffolding/contrib/AsyncGeneration/README.md (new file, 46 lines)
@@ -0,0 +1,46 @@
+## Overview
+
+`StreamGenerationTask` is an extension of `GenerationTask` designed for token-level streaming generation in asynchronous LLM workflows using TensorRT-LLM. It enables the controller to receive partial results during generation, which is critical for real-time or latency-sensitive applications such as chatbots, speech generation, or UI-interactive systems.
+
+---
+
+## Features
+
+- ✅ Supports **streamed token delivery** by step (e.g., `streaming_step=1`).
+- ✅ Supports **cancellation** of generation using a flag (`cancel_flag=True`).
+- ✅ Tracks **stream completion status** (`end_flag=True` when done).
+- ✅ Integrates with a streaming-capable LLM interface (`generate_async`).
+
+---
+
+## Fields in `StreamGenerationTask`
+
+| Field | Description |
+|-------|-------------|
+| `cancel_flag` | If `True`, the generation will be cancelled on the worker side. |
+| `streaming_step` | Number of new tokens required before returning control to the controller. If set to `0`, the task is returned immediately whenever any new tokens are available. |
+| `request_handle` | Internal handle for the streaming generation (used by the worker). |
+| `end_flag` | Indicates whether generation is finished. |
+| `output_str` / `output_tokens` / `logprobs` | Outputs after each generation step. |
+
+---
+
+## Usage in Controller/Worker
+
+The Controller can utilize `StreamGenerationTask` to enable efficient streaming-based generation workflows:
+- It sends tasks to the worker, which returns them when the number of newly generated tokens reaches the specified `streaming_step`.
+- It can cancel long-running tasks by setting `task.cancel_flag = True` when the number of generated tokens exceeds a predefined threshold.
+
+To support this behavior on the worker side, we provide `stream_generation_handler`; you need to register it with the worker in your project. The handler processes `StreamGenerationTask` instances step-by-step and updates relevant fields such as `output_tokens` and `output_str`.
+
+This design allows the controller and worker to coordinate generation in a token-efficient and responsive manner, ideal for real-time applications.
+
+You can see more details in `stream_generation_controller.py` and `stream_generation_task.py` under `examples/scaffolding/contrib/AsyncGeneration`.
+
+## Notes
+Remember to register the `stream_generation_handler` with the `TRTLLMWorker`.
+
+## TODO
+
+- Add error handling for failed `request_handle`.
+- Support retry or backoff mechanism if generation stalls.
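To make the controller/worker protocol described in this README concrete, here is a minimal controller sketch assembled from the fields documented above. It is a hypothetical example, not code from this commit: `ThresholdStreamController`, its `token_threshold` parameter, and the `create_from_prompt` helper are assumptions, and the generator-style `process` follows the scaffolding controller convention of yielding task lists to the worker.

```python
from tensorrt_llm.scaffolding import Controller
from tensorrt_llm.scaffolding.contrib import StreamGenerationTask


class ThresholdStreamController(Controller):
    """Hypothetical controller: stream token-by-token, cancel past a budget."""

    def __init__(self, token_threshold: int = 1024):
        super().__init__()
        self.token_threshold = token_threshold

    def process(self, tasks, **kwargs):
        # Wrap the incoming prompt in a streaming task (helper assumed).
        task = StreamGenerationTask.create_from_prompt(tasks[0].input_str)
        task.streaming_step = 1  # hand control back after every new token

        while not task.end_flag:
            # Yield to the worker; it returns the task once `streaming_step`
            # new tokens are available (or generation finishes).
            yield [task]
            if task.output_tokens and len(task.output_tokens) > self.token_threshold:
                task.cancel_flag = True  # worker cancels on its next step

        tasks[0].output_str = task.output_str
```

The committed reference loop, including registering `stream_generation_handler` on the worker, lives in `stream_generation_controller.py` and `stream_generation_run.py` under `examples/scaffolding/contrib/AsyncGeneration`.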
@@ -0,0 +1,3 @@
+from .stream_generation import StreamGenerationTask, stream_generation_handler
+
+__all__ = ["stream_generation_handler", "StreamGenerationTask"]
tensorrt_llm/scaffolding/contrib/Dynasor/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
+from .dynasor_controller import DynasorGenerationController
+
+__all__ = ["DynasorGenerationController"]
@@ -1,11 +1,12 @@
 from enum import Enum
 from typing import List

-from evaluator import equal_group
 from transformers import AutoTokenizer

 from tensorrt_llm.scaffolding import Controller, GenerationTask

+from .evaluator import equal_group
+

 class DynasorGenerationController(Controller):
tensorrt_llm/scaffolding/contrib/README.md (new file, 21 lines)
@@ -0,0 +1,21 @@
+# Contrib Examples
+
+We created this directory to store community-contributed projects.
+
+Contributors can develop inference-time compute methods with custom Controller/Task/Worker classes.
+
+We will continue to move generic work from this directory back into the main code.
+
+### How to create a new project?
+
+Just create a new directory and add your code there.
+
+### How to make the Controller/Task/Worker in your code reusable by other projects?
+
+Just add your Controller/Task/Worker to the `__init__.py` file of the contrib package.
+
+### How to show examples of your project?
+
+Just add your example to the `examples/scaffolding/contrib/` directory.
+
+In summary, the part of the code you want other users or projects to import should go in the `tensorrt_llm/scaffolding/contrib/` directory and be added to its `__init__.py` file. The code that runs the project and shows the results should go in the `examples/scaffolding/contrib/` directory.
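As a concrete sketch of the workflow this README describes, a hypothetical contrib project might look like the following; `MyProject` and `MyProjectController` are invented names for illustration:

```python
# tensorrt_llm/scaffolding/contrib/MyProject/my_controller.py (hypothetical)
from tensorrt_llm.scaffolding import Controller


class MyProjectController(Controller):
    """Toy controller: prefix every prompt with a fixed instruction."""

    def process(self, tasks, **kwargs):
        for task in tasks:
            # GenerationTask carries the prompt in `input_str`.
            task.input_str = "Think step by step.\n" + task.input_str
        # Hand the modified tasks to the worker for generation.
        yield tasks
```

To make it reusable, re-export it from `tensorrt_llm/scaffolding/contrib/__init__.py` (`from .MyProject import MyProjectController`) and add it to `__all__`, mirroring the AsyncGeneration and Dynasor entries in the `__init__.py` shown below.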
tensorrt_llm/scaffolding/contrib/__init__.py (new file, 12 lines)
@@ -0,0 +1,12 @@
+from tensorrt_llm.scaffolding import *  # noqa
+
+from .AsyncGeneration import StreamGenerationTask, stream_generation_handler
+from .Dynasor import DynasorGenerationController
+
+__all__ = [
+    # AsyncGeneration
+    "stream_generation_handler",
+    "StreamGenerationTask",
+    # Dynasor
+    "DynasorGenerationController",
+]