feat: Enhance the integrated robustness of scaffolding with __init__.py #3305 (#3312)

Signed-off-by: fredw (generated by with_the_same_user script) <20514172+WeiHaocheng@users.noreply.github.com>
WeiHaocheng 2025-04-09 21:13:47 +08:00 committed by GitHub
parent c069abc7d8
commit 6eee15900e
8 changed files with 59 additions and 61 deletions

View File

@@ -1,11 +1,10 @@
 import argparse
 import json

-from tensorrt_llm.scaffolding.controller import (MajorityVoteController,
-                                                 NativeGenerationController)
-from tensorrt_llm.scaffolding.math_utils import *
-from tensorrt_llm.scaffolding.scaffolding_llm import ScaffoldingLlm
-from tensorrt_llm.scaffolding.worker import TRTLLMWorker
+from tensorrt_llm.scaffolding import (MajorityVoteController,
+                                      NativeGenerationController,
+                                      ScaffoldingLlm, TRTLLMWorker,
+                                      extract_answer_from_boxed)


 def parse_arguments():
@@ -77,7 +76,7 @@ def main():
         ref_answer = int(test_case["answer"])
         result.result()
         output = result.output
-        extracted_answer = extract_answer(output.output_str)
+        extracted_answer = extract_answer_from_boxed(output.output_str)
         try:
             # print(f"[QUESTION]:\n{prompt}\n\n[OUTPUT]\n\n{output.output_str}\n\n")
             answer = int(extracted_answer)

View File

@@ -1,7 +1,15 @@
-## Contrib Examples
+# Contrib Examples
+
+This directory stores community-contributed examples.
+Contributors can add examples of custom inference-time compute methods built from their own Controller/Task/Worker classes.
+We will continue to move generic work from this directory back into the main code.
+
+### How to create a new project
+Just create a new directory and add your code there.
+
+### How to make your Controller/Task/Worker reusable by other projects
+Just add your Controller/Task/Worker to the `__init__.py` file of the contrib directory, as sketched below.
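
For illustration, registering a project's classes in the contrib `__init__.py` could look like the sketch below. The `my_project` module and `MyProjectController` name are hypothetical; `Controller` itself is among the base classes this commit re-exports from `tensorrt_llm.scaffolding`:

```python
# contrib/__init__.py -- hypothetical addition
from tensorrt_llm.scaffolding import *  # noqa: F401,F403  (base classes, as in this commit)

# Re-export a contributed Controller subclass so other projects can
# import it directly from the contrib package:
from .my_project.controller import MyProjectController  # noqa: F401  (hypothetical module)

__all__ = ["MyProjectController"]
```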

View File

@@ -0,0 +1,3 @@
+from tensorrt_llm.scaffolding import *  # noqa
+
+__all__ = []
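
Worth noting about this pattern: the star import binds every public `tensorrt_llm.scaffolding` name into the contrib package's namespace, while the empty `__all__` keeps a second star import from re-exporting them. A minimal sketch of the resulting behavior, with `contrib` standing in for the package's actual import path:

```python
# Explicit imports resolve through the contrib package:
from contrib import ScaffoldingLlm  # works -- bound by the star import above

# A wildcard import, however, binds nothing,
# because contrib's __all__ = [] suppresses star re-export:
from contrib import *  # no scaffolding names come through
```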

View File

@@ -1,9 +1,8 @@
 import argparse
 import asyncio

-from tensorrt_llm.scaffolding.controller import NativeGenerationController
-from tensorrt_llm.scaffolding.scaffolding_llm import ScaffoldingLlm
-from tensorrt_llm.scaffolding.worker import TRTLLMWorker
+from tensorrt_llm.scaffolding import (NativeGenerationController,
+                                      ScaffoldingLlm, TRTLLMWorker)


 def parse_arguments():

View File

@ -0,0 +1,17 @@
from .controller import (BestOfNController, Controller, MajorityVoteController,
NativeGenerationController, NativeRewardController,
ScaffoldingOutput)
from .math_utils import (extract_answer_from_boxed, extract_answer_with_regex,
get_digit_majority_vote_result)
from .scaffolding_llm import ScaffoldingLlm
from .task import GenerationTask, RewardTask, Task, TaskStatus
from .worker import OpenaiWorker, TRTLLMWorker, TRTOpenaiWorker, Worker
__all__ = [
"ScaffoldingLlm", "ScaffoldingOutput", "Controller",
"NativeGenerationController", "NativeRewardController",
"MajorityVoteController", "BestOfNController", "Task", "GenerationTask",
"RewardTask", "Worker", "OpenaiWorker", "TRTOpenaiWorker", "TRTLLMWorker",
"TaskStatus", "extract_answer_from_boxed", "extract_answer_with_regex",
"get_digit_majority_vote_result"
]
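
The payoff, visible in the example and test diffs in this commit, is that callers import from the package root instead of tracking the submodule layout:

```python
# Before this commit: one import per submodule
# from tensorrt_llm.scaffolding.controller import NativeGenerationController
# from tensorrt_llm.scaffolding.scaffolding_llm import ScaffoldingLlm
# from tensorrt_llm.scaffolding.worker import TRTLLMWorker

# After: a single import from the package root, backed by the __all__ above
from tensorrt_llm.scaffolding import (NativeGenerationController,
                                      ScaffoldingLlm, TRTLLMWorker)
```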

View File

@@ -3,15 +3,17 @@ from collections import Counter
 from typing import List


-def extract_answer(string: str,
-                   extract_from_boxed: bool = True,
-                   extract_regex: str = r"The final answer is (.+)$"):
+def extract_answer_with_regex(string: str,
+                              extract_regex: str = r"The final answer is (.+)$"
+                              ):
+    match = re.search(extract_regex, string)
+    if match:
+        return match.group(1)
+    return None
+
+
+def extract_answer_from_boxed(string: str):
     """Extract Answer String from \\boxed expression or based on regex"""
-    if not extract_from_boxed:
-        match = re.search(extract_regex, string)
-        if match:
-            return match.group(1)
-        return None
     if "\\boxed" not in string:
         return None
@@ -72,12 +74,13 @@
 def get_digit_majority_vote_result(results: List[str]) -> str:

     def is_digit(result: str):
-        extracted_answer = extract_answer(result)
+        extracted_answer = extract_answer_from_boxed(result)
         if extracted_answer is None:
             return False
         return extracted_answer.isdigit()

-    vote_result = get_majority_result(results,
-                                      result_extractor=extract_answer,
-                                      result_validator=is_digit)[0]
+    vote_result = get_majority_result(
+        results,
+        result_extractor=extract_answer_from_boxed,
+        result_validator=is_digit)[0]
     return vote_result if vote_result else results[0]
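
A short usage sketch of the two helpers after the split. The regex behavior follows directly from the code above; the `\boxed{...}` result is inferred from the function's name and docstring, since its full body is not shown in this hunk:

```python
from tensorrt_llm.scaffolding import (extract_answer_from_boxed,
                                      extract_answer_with_regex)

# Matches the default pattern r"The final answer is (.+)$":
extract_answer_with_regex("Step 1... The final answer is 42")  # -> "42"
extract_answer_with_regex("no concluding sentence")            # -> None

# Returns None early when no \boxed expression is present:
extract_answer_from_boxed("no boxed answer")     # -> None
extract_answer_from_boxed("we get \\boxed{42}")  # -> "42" (inferred)
```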

View File

@@ -3,9 +3,9 @@
 from scaffolding.test_worker import create_trtllm_worker

-from tensorrt_llm.scaffolding.controller import (MajorityVoteController,
-                                                 NativeGenerationController)
-from tensorrt_llm.scaffolding.scaffolding_llm import ScaffoldingLlm
+from tensorrt_llm.scaffolding import (MajorityVoteController,
+                                      NativeGenerationController,
+                                      ScaffoldingLlm)


 def create_scaffolding_llm_with_native_generation_controller(

View File

@@ -8,12 +8,11 @@ import asyncio
 import os
 import sys

 import openai
 import pytest
 from llmapi.apps.openai_server import RemoteOpenAIServer

-from tensorrt_llm.scaffolding.task import GenerationTask, TaskStatus
-from tensorrt_llm.scaffolding.worker import TRTLLMWorker, TRTOpenaiWorker
+from tensorrt_llm.scaffolding import (GenerationTask, TaskStatus, TRTLLMWorker,
+                                      TRTOpenaiWorker)

 sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
 from llmapi.test_llm import get_model_path
@@ -61,41 +60,11 @@ def async_client(server: RemoteOpenAIServer):
     return server.get_async_client()


-@pytest.mark.asyncio(loop_scope="module")
-async def test_single_completion(async_client: openai.OpenAI, model_name):
-    completion = await async_client.completions.create(
-        model=model_name,
-        prompt="Hello, my name is",
-        max_tokens=5,
-        temperature=0.0,
-    )
-    choice = completion.choices[0]
-    assert len(choice.text) >= 5
-    assert choice.finish_reason == "length"
-    assert completion.id is not None
-    assert completion.choices is not None and len(completion.choices) == 1
-
-    completion_tokens = 5
-    prompt_tokens = 6
-    assert completion.usage == openai.types.CompletionUsage(
-        completion_tokens=completion_tokens,
-        prompt_tokens=prompt_tokens,
-        total_tokens=prompt_tokens + completion_tokens)
-
-    # test using token IDs
-    completion = await async_client.completions.create(
-        model=model_name,
-        prompt=[0, 0, 0, 0, 0],
-        max_tokens=5,
-        temperature=0.0,
-    )
-    assert len(completion.choices[0].text) >= 1


 def create_trtoai_worker(model_name, async_client):
     return TRTOpenaiWorker(
         async_client=async_client,
         model=model_name,
         max_tokens=5,
     )
@@ -111,8 +80,8 @@ async def test_trtoai_worker_generation(default_prompt, model_name,

 def create_trtllm_worker(model_path):
     return TRTLLMWorker.init_with_new_llm(str(model_path),
                                           backend="pytorch",
-                                          max_batch_size=32,
-                                          max_num_tokens=4096,
+                                          max_batch_size=1,
+                                          max_num_tokens=5,
                                           temperature=0.9)