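"""Majority-vote accuracy benchmark built on tensorrt_llm.scaffolding.

Runs a MajorityVoteController over an AIME24-style JSONL test set and reports
accuracy, optionally checking it against a threshold and collecting benchmark
statistics.

Example invocation (script name and paths are illustrative, not taken from the
repository):

    python scaffolding_majority_vote_bench.py \
        --model_dir /models/DeepSeek-R1-Distill-Qwen-7B \
        --jsonl_file ./test.jsonl \
        --sample_num 10 \
        --concurrency 4 \
        --static_with_benchmark
"""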
import argparse
import asyncio
import json

from tensorrt_llm.scaffolding import (GenerationTokenCounter,
                                      MajorityVoteController,
                                      NativeGenerationController,
                                      ScaffoldingBenchRequest, ScaffoldingLlm,
                                      TRTLLMWorker, async_scaffolding_benchmark,
                                      extract_answer_from_boxed)

def parse_arguments():
    parser = argparse.ArgumentParser()
    # e.g. DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B
    parser.add_argument(
        '--model_dir',
        type=str,
        required=True,
        help="Path to the directory containing the generation model")
    # https://github.com/QwenLM/Qwen2.5-Math/blob/main/evaluation/data/aime24/test.jsonl
    parser.add_argument('--jsonl_file', type=str, default='./test.jsonl')
    parser.add_argument('--threshold', type=float, default=None)
    parser.add_argument('--sample_num', type=int, default=10)
    parser.add_argument('--concurrency', type=int, default=None)
    parser.add_argument('--static_with_benchmark', action='store_true')
    args = parser.parse_args()
    return args

def load_test_file(jsonl_file: str):
    data = []
    with open(jsonl_file, "r", encoding="utf-8") as file:
        for line in file:
            if line.strip():
                data.append(json.loads(line))
    return data

def main():
    args = parse_arguments()
    workers = {}

    llm_worker = TRTLLMWorker.init_with_new_llm(
        args.model_dir,
        backend="pytorch",
        max_batch_size=32,
        max_num_tokens=4096,
    )

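    # The scaffolding pipeline composes two controllers: a NativeGenerationController
    # that samples candidate solutions, wrapped by a MajorityVoteController that
    # aggregates args.sample_num samples per prompt into a single answer.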
    prototype_generation_controller = NativeGenerationController(
        sampling_params={
            "max_tokens": 4096,
            "top_p": 0.9,
            "temperature": 0.9,
        })
    workers[NativeGenerationController.WorkerTag.GENERATION] = llm_worker

    prototype_majority_vote_controller = MajorityVoteController(
        generation_controller=prototype_generation_controller,
        default_sample_num=args.sample_num,
    )

    llm = ScaffoldingLlm(
        prototype_majority_vote_controller,
        workers=workers,
    )

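    # Each JSONL record is expected to carry a "problem" field (the prompt) and
    # an integer "answer" field (the reference), matching the AIME24 data linked
    # in parse_arguments().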
    test_dataset = load_test_file(args.jsonl_file)
    total_count = 0
    correct_count = 0
    controller_name = "MajorityVoteController"

    prompts = [test_case["problem"] for test_case in test_dataset]

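    # Two execution paths: with --static_with_benchmark the requests go through
    # async_scaffolding_benchmark at the requested concurrency and per-request
    # timing plus token statistics are collected; otherwise the prompts are run
    # through llm.generate() directly.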
    if args.static_with_benchmark or args.concurrency:
        if args.concurrency is None:
            args.concurrency = 1

    if args.static_with_benchmark:
        task_collection_types = {"token_counter": GenerationTokenCounter}

        requests = [
            ScaffoldingBenchRequest(prompt=prompt) for prompt in prompts
        ]

        results, requests_execution_time, total_time = asyncio.run(
            async_scaffolding_benchmark(llm, task_collection_types, requests,
                                        args.concurrency))
    else:
        results = llm.generate(prompts)

    print('main shutting down...')
    llm.shutdown()
    llm_worker.shutdown()
    print('main shut down done')

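    # Score each result by extracting the boxed answer from the generated text
    # and comparing it against the integer reference answer.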
    for i in range(len(results)):
        result = results[i]
        test_case = test_dataset[i]
        ref_answer = int(test_case["answer"])
        output = result.outputs[0]
        extracted_answer = extract_answer_from_boxed(output.text)
        try:
            # print(f"[QUESTION]:\n{prompt}\n\n[OUTPUT]\n\n{output.output_str}\n\n")
            answer = int(extracted_answer)
            print(f'Answer={answer}, reference={ref_answer}')
            if answer == ref_answer:
                correct_count += 1
        except (TypeError, ValueError):
            print(f'extracted_answer={extracted_answer} is not an integer.')
        total_count += 1
    print(
        f'Controller {controller_name} Accuracy: {correct_count} out of {total_count}'
    )

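    # Optional accuracy gate: report failure when accuracy falls below --threshold.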
    if args.threshold is not None:
        accuracy = correct_count / total_count
        if accuracy < args.threshold:
            print(
                f'Accuracy check failed with {correct_count}/{total_count} < {args.threshold}'
            )
        else:
            print(f'Accuracy check passed with threshold={args.threshold}')

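    # Benchmark statistics are only available on the async_scaffolding_benchmark path.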
    if args.static_with_benchmark:
        print(f'Total time: {total_time}')
        print(
            f'Average requests execution time: {sum(requests_execution_time) / len(requests_execution_time)}'
        )
        total_token_count = 0
        for result in results:
            total_token_count += result.task_collections[
                "token_counter"].generation_token_count
        print(f'Average output token count: {total_token_count / len(results)}')

if __name__ == '__main__':
    main()