mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
Signed-off-by: fredw (generated by with_the_same_user script) <20514172+WeiHaocheng@users.noreply.github.com>
This commit is contained in:
parent
c069abc7d8
commit
6eee15900e
@ -1,11 +1,10 @@
|
||||
import argparse
|
||||
import json
|
||||
|
||||
from tensorrt_llm.scaffolding.controller import (MajorityVoteController,
|
||||
NativeGenerationController)
|
||||
from tensorrt_llm.scaffolding.math_utils import *
|
||||
from tensorrt_llm.scaffolding.scaffolding_llm import ScaffoldingLlm
|
||||
from tensorrt_llm.scaffolding.worker import TRTLLMWorker
|
||||
from tensorrt_llm.scaffolding import (MajorityVoteController,
|
||||
NativeGenerationController,
|
||||
ScaffoldingLlm, TRTLLMWorker,
|
||||
extract_answer_from_boxed)
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
@ -77,7 +76,7 @@ def main():
|
||||
ref_answer = int(test_case["answer"])
|
||||
result.result()
|
||||
output = result.output
|
||||
extracted_answer = extract_answer(output.output_str)
|
||||
extracted_answer = extract_answer_from_boxed(output.output_str)
|
||||
try:
|
||||
# print(f"[QUESTION]:\n{prompt}\n\n[OUTPUT]\n\n{output.output_str}\n\n")
|
||||
answer = int(extracted_answer)
|
||||
|
||||
@ -1,7 +1,15 @@
|
||||
## Contrib Examples
|
||||
# Contrib Examples
|
||||
|
||||
We create this directory to store the community contributed examples.
|
||||
|
||||
Contributors can add examples of customized inference-time compute methods with customized Controller/Task/Worker components.
|
||||
|
||||
We will continue to move generic work from this directory back into the main codebase.
|
||||
|
||||
### How to create a new project
|
||||
|
||||
Just create a new directory and add your code there.
|
||||
|
||||
### How to make your Controller/Task/Worker reusable by other projects
|
||||
|
||||
Just add your Controller/Task/Worker to the `__init__.py` file of the contrib directory.
|
||||
|
||||
3
examples/scaffolding/contrib/__init__.py
Normal file
3
examples/scaffolding/contrib/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from tensorrt_llm.scaffolding import * # noqa
|
||||
|
||||
__all__ = []
|
||||
@ -1,9 +1,8 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
|
||||
from tensorrt_llm.scaffolding.controller import NativeGenerationController
|
||||
from tensorrt_llm.scaffolding.scaffolding_llm import ScaffoldingLlm
|
||||
from tensorrt_llm.scaffolding.worker import TRTLLMWorker
|
||||
from tensorrt_llm.scaffolding import (NativeGenerationController,
|
||||
ScaffoldingLlm, TRTLLMWorker)
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
|
||||
@ -0,0 +1,17 @@
|
||||
from .controller import (BestOfNController, Controller, MajorityVoteController,
|
||||
NativeGenerationController, NativeRewardController,
|
||||
ScaffoldingOutput)
|
||||
from .math_utils import (extract_answer_from_boxed, extract_answer_with_regex,
|
||||
get_digit_majority_vote_result)
|
||||
from .scaffolding_llm import ScaffoldingLlm
|
||||
from .task import GenerationTask, RewardTask, Task, TaskStatus
|
||||
from .worker import OpenaiWorker, TRTLLMWorker, TRTOpenaiWorker, Worker
|
||||
|
||||
# Public API of the scaffolding package; order preserved from the
# original declaration (controllers, tasks, workers, then math utils).
__all__ = [
    "ScaffoldingLlm",
    "ScaffoldingOutput",
    "Controller",
    "NativeGenerationController",
    "NativeRewardController",
    "MajorityVoteController",
    "BestOfNController",
    "Task",
    "GenerationTask",
    "RewardTask",
    "Worker",
    "OpenaiWorker",
    "TRTOpenaiWorker",
    "TRTLLMWorker",
    "TaskStatus",
    "extract_answer_from_boxed",
    "extract_answer_with_regex",
    "get_digit_majority_vote_result",
]
|
||||
@ -3,15 +3,17 @@ from collections import Counter
|
||||
from typing import List
|
||||
|
||||
|
||||
def extract_answer(string: str,
|
||||
extract_from_boxed: bool = True,
|
||||
extract_regex: str = r"The final answer is (.+)$"):
|
||||
def extract_answer_with_regex(string: str,
                              extract_regex: str = r"The final answer is (.+)$"
                              ):
    """Search *string* with *extract_regex* and return the first capture
    group, or ``None`` when the pattern does not match.

    Args:
        string: Text to scan (e.g. a model's output).
        extract_regex: Pattern with one capture group holding the answer.
    """
    found = re.search(extract_regex, string)
    return found.group(1) if found else None
|
||||
|
||||
|
||||
def extract_answer_from_boxed(string: str):
|
||||
"""Extract Answer String from \\boxed expression or based on regex"""
|
||||
if not extract_from_boxed:
|
||||
match = re.search(extract_regex, string)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
||||
if "\\boxed" not in string:
|
||||
return None
|
||||
@ -72,12 +74,13 @@ def get_majority_result(
|
||||
def get_digit_majority_vote_result(results: List[str]) -> str:
|
||||
|
||||
def is_digit(result: str):
|
||||
extracted_answer = extract_answer(result)
|
||||
extracted_answer = extract_answer_from_boxed(result)
|
||||
if extracted_answer is None:
|
||||
return False
|
||||
return extracted_answer.isdigit()
|
||||
|
||||
vote_result = get_majority_result(results,
|
||||
result_extractor=extract_answer,
|
||||
result_validator=is_digit)[0]
|
||||
vote_result = get_majority_result(
|
||||
results,
|
||||
result_extractor=extract_answer_from_boxed,
|
||||
result_validator=is_digit)[0]
|
||||
return vote_result if vote_result else results[0]
|
||||
|
||||
@ -3,9 +3,9 @@
|
||||
|
||||
from scaffolding.test_worker import create_trtllm_worker
|
||||
|
||||
from tensorrt_llm.scaffolding.controller import (MajorityVoteController,
|
||||
NativeGenerationController)
|
||||
from tensorrt_llm.scaffolding.scaffolding_llm import ScaffoldingLlm
|
||||
from tensorrt_llm.scaffolding import (MajorityVoteController,
|
||||
NativeGenerationController,
|
||||
ScaffoldingLlm)
|
||||
|
||||
|
||||
def create_scaffolding_llm_with_native_generation_controller(
|
||||
|
||||
@ -8,12 +8,11 @@ import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
from llmapi.apps.openai_server import RemoteOpenAIServer
|
||||
|
||||
from tensorrt_llm.scaffolding.task import GenerationTask, TaskStatus
|
||||
from tensorrt_llm.scaffolding.worker import TRTLLMWorker, TRTOpenaiWorker
|
||||
from tensorrt_llm.scaffolding import (GenerationTask, TaskStatus, TRTLLMWorker,
|
||||
TRTOpenaiWorker)
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
||||
from llmapi.test_llm import get_model_path
|
||||
@ -61,41 +60,11 @@ def async_client(server: RemoteOpenAIServer):
|
||||
return server.get_async_client()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="module")
|
||||
async def test_single_completion(async_client: openai.OpenAI, model_name):
|
||||
completion = await async_client.completions.create(
|
||||
model=model_name,
|
||||
prompt="Hello, my name is",
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
choice = completion.choices[0]
|
||||
assert len(choice.text) >= 5
|
||||
assert choice.finish_reason == "length"
|
||||
assert completion.id is not None
|
||||
assert completion.choices is not None and len(completion.choices) == 1
|
||||
completion_tokens = 5
|
||||
prompt_tokens = 6
|
||||
assert completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=completion_tokens,
|
||||
prompt_tokens=prompt_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens)
|
||||
|
||||
# test using token IDs
|
||||
completion = await async_client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
assert len(completion.choices[0].text) >= 1
|
||||
|
||||
|
||||
def create_trtoai_worker(model_name, async_client):
|
||||
return TRTOpenaiWorker(
|
||||
async_client=async_client,
|
||||
model=model_name,
|
||||
max_tokens=5,
|
||||
)
|
||||
|
||||
|
||||
@ -111,8 +80,8 @@ async def test_trtoai_worker_generation(default_prompt, model_name,
|
||||
def create_trtllm_worker(model_path):
|
||||
return TRTLLMWorker.init_with_new_llm(str(model_path),
|
||||
backend="pytorch",
|
||||
max_batch_size=32,
|
||||
max_num_tokens=4096,
|
||||
max_batch_size=1,
|
||||
max_num_tokens=5,
|
||||
temperature=0.9)
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user