feat(part 2): Enhance the integrated robustness of scaffolding with __init__.py #3305 (#3731)

Signed-off-by: fredw (generated by with_the_same_user script) <20514172+WeiHaocheng@users.noreply.github.com>
WeiHaocheng 2025-04-24 18:47:03 +08:00 committed by GitHub
parent ae34d60108
commit 3fc2a16920
20 changed files with 144 additions and 105 deletions

View File

@@ -9,12 +9,12 @@ from tensorrt_llm.scaffolding import (MajorityVoteController,
def parse_arguments():
parser = argparse.ArgumentParser()
# e.g. /home/scratch.trt_llm_data/llm-models/DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B
parser.add_argument(
'--generation_dir',
'--model_dir',
type=str,
default=
"/home/scratch.trt_llm_data/llm-models/DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B"
)
required=True,
help="Path to the directory containing the generation model")
# https://github.com/QwenLM/Qwen2.5-Math/blob/main/evaluation/data/aime24/test.jsonl
parser.add_argument('--jsonl_file', type=str, default='./test.jsonl')
parser.add_argument('--threshold', type=float, default=None)
@@ -36,7 +36,7 @@ def main():
args = parse_arguments()
workers = {}
llm_worker = TRTLLMWorker.init_with_new_llm(args.generation_dir,
llm_worker = TRTLLMWorker.init_with_new_llm(args.model_dir,
backend="pytorch",
max_batch_size=32,
max_num_tokens=4096,

View File

@@ -1,46 +1,10 @@
## Overview
`StreamGenerationTask` is an extension of `GenerationTask` designed for token-level streaming generation in asynchronous LLM workflows using TensorRT-LLM. It enables the controller to receive partial results during generation, which is critical for real-time or latency-sensitive applications such as chatbots, speech generation, or UI-interactive systems.
This example shows how to use the `StreamGenerationTask` and `stream_generation_handler` to enable efficient streaming-based generation workflows.
---
How to run the example?
```bash
python stream_generation_run.py
```
## Features
- ✅ Supports **streamed token delivery** by step (e.g., `streaming_step=1`).
- ✅ Supports **cancellation** of generation using a flag (`cancel_flag=True`).
- ✅ Tracks **stream completion status** (`end_flag=True` when done).
- ✅ Integrated with a streaming-capable LLM interface (`generate_async`).
---
## Fields in `StreamGenerationTask`
| Field | Description |
|-------|-------------|
| `cancel_flag` | If `True`, the generation will be cancelled on the worker side. |
| `streaming_step` | Number of new tokens required before returning control to the controller. If set to `0`, the task is returned immediately if any new tokens are available. |
| `request_handle` | Internal handle for the streaming generation (used by the worker). |
| `end_flag` | Indicates whether generation is finished. |
| `output_str` / `output_tokens` / `logprobs` | Outputs after each generation step. |
---
## Usage in Controller/Worker
The Controller can utilize `StreamGenerationTask` to enable efficient streaming-based generation workflows:
- It sends tasks to the worker, which returns them when the number of newly generated tokens reaches the specified `streaming_step`.
- It can cancel long-running tasks by setting `task.cancel_flag = True` when the number of generated tokens exceeds a predefined threshold.
To support this behavior on the worker side, you need to implement a `stream_generation_handler` and register it with the worker. This handler should process `StreamGenerationTask` instances step by step and update relevant fields such as `output_tokens` and `output_str`.
This design allows the controller and worker to coordinate generation in a token-efficient and responsive manner, ideal for real-time applications.
You can see more details in `stream_generation_controller.py` and `stream_generation_task.py`.
## Notes
Ensure the `worker.llm.generate_async(...)` method supports `streaming=True`.
## TODO
- Add error handling for failed `request_handle`.
- Support a retry or backoff mechanism if generation stalls.
See more detail on [tensorrt_llm/scaffolding/contrib/AsyncGeneration/README.md](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/scaffolding/contrib/AsyncGeneration/README.md).

View File

@@ -2,10 +2,8 @@ import copy
from enum import Enum
from typing import List
from stream_generation_task import StreamGenerationTask
from tensorrt_llm.scaffolding.controller import Controller
from tensorrt_llm.scaffolding.task import GenerationTask, Task
from tensorrt_llm.scaffolding import Controller, GenerationTask, Task
from tensorrt_llm.scaffolding.contrib import StreamGenerationTask
class NativeStreamGenerationController(Controller):

View File

@@ -1,20 +1,20 @@
import argparse
from stream_generation_controller import NativeStreamGenerationController
from stream_generation_task import (StreamGenerationTask,
stream_generation_handler)
from tensorrt_llm.scaffolding import ScaffoldingLlm, TRTLLMWorker
from tensorrt_llm.scaffolding.contrib import (StreamGenerationTask,
stream_generation_handler)
def parse_arguments():
parser = argparse.ArgumentParser()
# e.g. /home/scratch.trt_llm_data/llm-models/DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B
parser.add_argument(
'--generation_dir',
'--model_dir',
type=str,
default=
"/home/scratch.trt_llm_data/llm-models/DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B"
)
required=True,
help="Path to the directory containing the generation model")
parser.add_argument('--run_type', type=str, default='original')
args = parser.parse_args()
return args
@@ -81,7 +81,7 @@ def main():
"There exist real numbers $x$ and $y$, both greater than 1, such that $\\log_x\\left(y^x\\right)=\\log_y\\left(x^{4y}\\right)=10$. Find $xy$.",
"Find the largest possible real part of \\[(75+117i)z+\\frac{96+144i}{z}\\]where $z$ is a complex number with $|z|=4$.",
]
llm_worker = TRTLLMWorker.init_with_new_llm(args.generation_dir,
llm_worker = TRTLLMWorker.init_with_new_llm(args.model_dir,
backend="pytorch",
max_batch_size=32,
max_num_tokens=4096,

View File

@@ -0,0 +1,2 @@
latex2sympy2
word2number

View File

@@ -1,16 +1,18 @@
import argparse
from dynasor_controller import DynasorGenerationController
from tensorrt_llm.scaffolding import (MajorityVoteController, ScaffoldingLlm,
TRTLLMWorker)
from tensorrt_llm.scaffolding.contrib import DynasorGenerationController
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--generation_dir",
type=str,
default="./models/DeepSeek-R1-Distill-Qwen-7B")
# e.g. /home/scratch.trt_llm_data/llm-models/DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B
parser.add_argument(
'--model_dir',
type=str,
required=True,
help="Path to the directory containing the generation model")
parser.add_argument("--max_num_tokens", type=int, default=7000)
parser.add_argument("--majority_vote", action='store_true')
parser.add_argument('--sample_num', type=int, default=3)
@@ -18,9 +20,9 @@ def parse_arguments():
return args
def test_sync(prompts, proposer_worker, args):
def test(prompts, proposer_worker, args):
dynasor_generation_controller = DynasorGenerationController(
generation_dir=args.generation_dir, max_tokens=args.max_num_tokens)
generation_dir=args.model_dir, max_tokens=args.max_num_tokens)
# If majority voting is requested, wrap the controller in MajorityVoteController
if args.majority_vote:
@@ -65,11 +67,9 @@ def main():
]
llm_worker = TRTLLMWorker.init_with_new_llm(
args.generation_dir,
backend="pytorch",
max_num_tokens=args.max_num_tokens)
args.model_dir, backend="pytorch", max_num_tokens=args.max_num_tokens)
test_sync(prompts, llm_worker, args)
test(prompts, llm_worker, args)
if __name__ == "__main__":

View File

@@ -1,15 +0,0 @@
# Contrib Examples
We created this directory to store community-contributed examples.
Contributors can add examples of customized inference-time compute methods with custom Controller/Task/Worker components.
We will continue to move generic work from this directory back into the main code.
### How to create a new project
Just create a new directory and add your code there.
### How to make your Controller/Task/Worker reusable by other projects
Just add your Controller/Task/Worker to the `__init__.py` file of the contrib directory.

View File

@@ -1,8 +0,0 @@
from tensorrt_llm.scaffolding import * # noqa
from .Dynasor.dynasor_controller import DynasorGenerationController
__all__ = [
'NativeStreamGenerationController', 'StreamGenerationTask',
'stream_generation_handler', 'DynasorGenerationController'
]

View File

@@ -7,13 +7,12 @@ from tensorrt_llm.scaffolding import (NativeGenerationController,
def parse_arguments():
parser = argparse.ArgumentParser()
# e.g. /home/scratch.trt_llm_data/llm-models/DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B
parser.add_argument(
'--generation_dir',
'--model_dir',
type=str,
default=
"/home/scratch.trt_llm_data/llm-models/DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B"
)
parser.add_argument('--verifier_dir', type=str, default=None)
required=True,
help="Path to the directory containing the generation model")
parser.add_argument('--run_async', action='store_true')
args = parser.parse_args()
return args
@@ -68,7 +67,7 @@ def main():
"Find the largest possible real part of \\[(75+117i)z+\\frac{96+144i}{z}\\]where $z$ is a complex number with $|z|=4$.",
]
llm_worker = TRTLLMWorker.init_with_new_llm(args.generation_dir,
llm_worker = TRTLLMWorker.init_with_new_llm(args.model_dir,
backend="pytorch",
max_batch_size=32,
max_num_tokens=4096,

View File

@@ -8,10 +8,23 @@ from .task import GenerationTask, RewardTask, Task, TaskStatus
from .worker import OpenaiWorker, TRTLLMWorker, TRTOpenaiWorker, Worker
__all__ = [
"ScaffoldingLlm", "ScaffoldingOutput", "ParallelProcess", "Controller",
"NativeGenerationController", "NativeRewardController",
"MajorityVoteController", "BestOfNController", "Task", "GenerationTask",
"RewardTask", "Worker", "OpenaiWorker", "TRTOpenaiWorker", "TRTLLMWorker",
"TaskStatus", "extract_answer_from_boxed", "extract_answer_with_regex",
"get_digit_majority_vote_result"
"ScaffoldingLlm",
"ScaffoldingOutput",
"ParallelProcess",
"Controller",
"NativeGenerationController",
"NativeRewardController",
"MajorityVoteController",
"BestOfNController",
"Task",
"GenerationTask",
"RewardTask",
"Worker",
"OpenaiWorker",
"TRTOpenaiWorker",
"TRTLLMWorker",
"TaskStatus",
"extract_answer_from_boxed",
"extract_answer_with_regex",
"get_digit_majority_vote_result",
]

View File

@@ -0,0 +1,46 @@
## Overview
`StreamGenerationTask` is an extension of `GenerationTask` designed for token-level streaming generation in asynchronous LLM workflows using TensorRT-LLM. It enables the controller to receive partial results during generation, which is critical for real-time or latency-sensitive applications such as chatbots, speech generation, or UI-interactive systems.
---
## Features
- ✅ Supports **streamed token delivery** by step (e.g., `streaming_step=1`).
- ✅ Supports **cancellation** of generation using a flag (`cancel_flag=True`).
- ✅ Tracks **stream completion status** (`end_flag=True` when done).
- ✅ Integrated with a streaming-capable LLM interface (`generate_async`).
---
## Fields in `StreamGenerationTask`
| Field | Description |
|-------|-------------|
| `cancel_flag` | If `True`, the generation will be cancelled on the worker side. |
| `streaming_step` | Number of new tokens required before returning control to the controller. If set to `0`, the task is returned immediately if any new tokens are available. |
| `request_handle` | Internal handle for the streaming generation (used by the worker). |
| `end_flag` | Indicates whether generation is finished. |
| `output_str` / `output_tokens` / `logprobs` | Outputs after each generation step. |
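For orientation, here is a rough, illustrative stub of how these fields might sit on the task object. This is a sketch under the assumption that the task is a dataclass extending `GenerationTask`; the real definition lives in `stream_generation_task.py`.
```python
from dataclasses import dataclass
from typing import Any, Optional

from tensorrt_llm.scaffolding import GenerationTask


@dataclass
class StreamGenerationTaskSketch(GenerationTask):
    """Illustrative stub only; see stream_generation_task.py for the real class."""
    cancel_flag: bool = False             # controller sets True to cancel on the worker side
    streaming_step: int = 1               # new tokens required before control returns
    request_handle: Optional[Any] = None  # worker-internal handle for the stream
    end_flag: bool = False                # True once generation has finished
    # output_str / output_tokens / logprobs are inherited from GenerationTask
    # and refreshed after each generation step.
```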
---
## Usage in Controller/Worker
The Controller can utilize `StreamGenerationTask` to enable efficient streaming-based generation workflows:
- It sends tasks to the worker, which returns them when the number of newly generated tokens reaches the specified `streaming_step`.
- It can cancel long-running tasks by setting `task.cancel_flag = True` when the number of generated tokens exceeds a predefined threshold.
To support this behavior on the worker side, we provide a `stream_generation_handler` that you register with the worker in your project. This handler processes `StreamGenerationTask` instances step by step and updates relevant fields such as `output_tokens` and `output_str`.
This design allows the controller and worker to coordinate generation in a token-efficient and responsive manner, ideal for real-time applications.
You can see more details in `stream_generation_controller.py` and `stream_generation_task.py` under `examples/scaffolding/contrib/AsyncGeneration`.
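As a rough illustration of the controller side, here is a minimal sketch. It assumes `Controller.process` is a generator that yields task batches to the worker, that `GenerationTask` carries an `input_str` field, and it uses a hypothetical `max_output_tokens` budget as the cancellation criterion; it is not the shipped controller.
```python
from tensorrt_llm.scaffolding import Controller
from tensorrt_llm.scaffolding.contrib import StreamGenerationTask


class TokenBudgetStreamController(Controller):
    """Hypothetical controller: stream tokens, cancel once a token budget is hit."""

    def __init__(self, max_output_tokens: int = 256):
        super().__init__()
        self.max_output_tokens = max_output_tokens  # assumed cancellation threshold

    def process(self, tasks, **kwargs):
        task = StreamGenerationTask()
        task.input_str = tasks[0].input_str  # assumes GenerationTask carries input_str
        task.streaming_step = 1  # hand control back on every new token
        while not task.end_flag:
            yield [task]  # worker resumes the task once streaming_step new tokens arrive
            if task.output_tokens and len(task.output_tokens) > self.max_output_tokens:
                task.cancel_flag = True  # ask the worker to cancel the request
                yield [task]  # one more round trip so the worker sees the flag
                break
        tasks[0].output_str = task.output_str  # surface the (possibly partial) text
```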
## Notes
Remember to register the `stream_generation_handler` with the `TRTLLMWorker`.
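A sketch of that registration, assuming the worker exposes a `register_task_handler(task_cls, handler)` method (verify the exact API in your version):
```python
from tensorrt_llm.scaffolding import TRTLLMWorker
from tensorrt_llm.scaffolding.contrib import (StreamGenerationTask,
                                              stream_generation_handler)

# Assumed registration call: maps StreamGenerationTask instances to the
# streaming handler so the worker knows how to process them.
llm_worker = TRTLLMWorker.init_with_new_llm("path/to/generation_model",
                                            backend="pytorch")
llm_worker.register_task_handler(StreamGenerationTask, stream_generation_handler)
```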
## TODO
- Add error handling for failed `request_handle`.
- Support a retry or backoff mechanism if generation stalls.

View File

@@ -0,0 +1,3 @@
from .stream_generation import StreamGenerationTask, stream_generation_handler
__all__ = ["stream_generation_handler", "StreamGenerationTask"]

View File

@@ -0,0 +1,3 @@
from .dynasor_controller import DynasorGenerationController
__all__ = ["DynasorGenerationController"]

View File

@@ -1,11 +1,12 @@
from enum import Enum
from typing import List
from evaluator import equal_group
from transformers import AutoTokenizer
from tensorrt_llm.scaffolding import Controller, GenerationTask
from .evaluator import equal_group
class DynasorGenerationController(Controller):

View File

@@ -0,0 +1,21 @@
# Contrib Examples
We created this directory to store community-contributed projects.
Contributors can develop inference-time compute methods with custom Controller/Task/Worker components.
We will continue to move generic work from this directory back into the main code.
### How to create a new project?
Just create a new directory and add your code there.
### How to make your Controller/Task/Worker reusable by other projects?
Just add your Controller/Task/Worker to the `__init__.py` file of the scaffolding contrib package.
### How to show examples of your project?
Just add your example to the `examples/scaffolding/contrib/` directory.
In summary, code that you want other users or projects to import should be put in the `tensorrt_llm/scaffolding/contrib/` directory and exported from its `__init__.py` file. The code that runs the project and shows the results should be put in the `examples/scaffolding/contrib/` directory.
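For example, with the contrib `__init__.py` introduced below, a downstream project can import all the reusable pieces from one place:
```python
# Downstream import surface after this change: reusable contrib components
# come from tensorrt_llm.scaffolding.contrib, not from example directories.
from tensorrt_llm.scaffolding.contrib import (DynasorGenerationController,
                                              StreamGenerationTask,
                                              stream_generation_handler)
```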

View File

@@ -0,0 +1,12 @@
from tensorrt_llm.scaffolding import * # noqa
from .AsyncGeneration import StreamGenerationTask, stream_generation_handler
from .Dynasor import DynasorGenerationController
__all__ = [
# AsyncGeneration
"stream_generation_handler",
"StreamGenerationTask",
# Dynasor
"DynasorGenerationController",
]