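# Benchmark the DeepConf scaffolding controllers on a JSONL question dataset.
# Example invocation (script name illustrative):
#   python deepconf_bench.py --model_dir DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B \
#       --run_type offline_majority_vote --dataset brumo_2025.jsonl
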
import argparse
import json
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict

import numpy as np
from utils import equal_func, prepare_prompt

from tensorrt_llm.scaffolding import (NativeGenerationController,
                                      ScaffoldingLlm, TRTLLMWorker,
                                      extract_answer_from_boxed)
from tensorrt_llm.scaffolding.contrib.DeepConf import (
    DeepConfOfflineController, DeepConfOfflineMajorityVoteController,
    DeepConfOnlineController, DeepConfOnlineMajorityVoteController)

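# Map each --run_type choice onto the DeepConf controller that implements it.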
_RUN_TYPE_TO_IMPL = {
    "offline": DeepConfOfflineController,
    "online": DeepConfOnlineController,
    "offline_majority_vote": DeepConfOfflineMajorityVoteController,
    "online_majority_vote": DeepConfOnlineMajorityVoteController,
}


def parse_arguments():
    parser = argparse.ArgumentParser()
    # e.g. DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B
    parser.add_argument(
        '--model_dir',
        type=str,
        required=True,
        help="Path to the directory containing the generation model")
    parser.add_argument('--run_type',
                        type=str,
                        required=True,
                        choices=list(_RUN_TYPE_TO_IMPL.keys()),
                        help="Type of the run. Available choices: %(choices)s")
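    # DeepConf sampling knobs (semantics assumed from the DeepConf design):
    # `conf_group_size` is the token-window length used to compute group
    # confidence, and `conf_threshold` is the level below which a trace is
    # considered low-confidence (and, in online mode, stopped early).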
    parser.add_argument('--warmup_sample_num', type=int, default=16)
    parser.add_argument('--sample_num', type=int, default=256)
    parser.add_argument('--conf_group_size', type=int, default=2048)
    parser.add_argument('--conf_threshold', type=float, default=0.5)
    parser.add_argument('--vote_policy',
                        type=str,
                        default="top10_bottom_window_filtered")
    parser.add_argument('--confidence_percentile', type=int, default=10)
    parser.add_argument('--logprobs_topk', type=int, default=20)
    parser.add_argument('--max_tokens', type=int, default=64000)
    parser.add_argument('--temperature', type=float, default=0.6)
    parser.add_argument('--top_p', type=float, default=0.95)
    parser.add_argument('--top_k', type=int, default=0)
    parser.add_argument('--qid', type=int, default=-1)
    parser.add_argument('--dataset', type=str, default="brumo_2025.jsonl")
    parser.add_argument('--repeat_times', type=int, default=1)
    parser.add_argument('--tensor_parallel_size', type=int, default=1)
    args = parser.parse_args()
    return args


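# Per-vote-policy accuracy and token accounting for one benchmark run.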
@dataclass
class BenchResult:
    right_answer_count: int = 0
    total_answer_count: int = 0
    accuracy: float = 0.0
    generated_tokens: int = 0


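# Build a ScaffoldingLlm around `controller`, run all prompts for
# `repeat_times` rounds, and, for majority-vote controllers, score each vote
# policy against the ground truth.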
def run_scaffolding_llm(prompts,
                        proposer_worker,
                        controller,
                        repeat_times=1,
                        ground_truth=None,
                        **kwargs):
    llm = ScaffoldingLlm(
        controller,
        {
            NativeGenerationController.WorkerTag.GENERATION: proposer_worker,
        },
    )

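    # Majority-vote controllers report one voted answer per vote policy, so
    # accuracy is tracked separately for each policy.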
    is_majority_vote = isinstance(
        controller, (DeepConfOnlineMajorityVoteController,
                     DeepConfOfflineMajorityVoteController))
    vote_policy_to_bench_result: Dict[str, BenchResult] = {}
    times = []
    for i in range(repeat_times):
        print(f"=========== round {i} ===========")
        start_time = time.time()
        results = llm.generate(prompts)
        times.append(time.time() - start_time)

        for j, result in enumerate(results):
            print(
                f"result {j}: {extract_answer_from_boxed(result.outputs[0].text)}"
            )

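            # Score the voted answer of every vote policy against the ground
            # truth for this question.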
            if is_majority_vote and ground_truth is not None:
                vote_policy_to_voted_task = result.cur_output.vote_policy_to_voted_task
                for vote_policy, voted_task in vote_policy_to_voted_task.items():
                    bench_result = vote_policy_to_bench_result.get(
                        vote_policy, BenchResult())

                    voted_answer = voted_task.customized_result_fields[
                        'extracted_answer']
                    if equal_func(voted_answer, ground_truth[j]):
                        bench_result.right_answer_count += 1
                    bench_result.total_answer_count += 1
                    bench_result.generated_tokens += result.cur_output.output_token_num

                    vote_policy_to_bench_result[vote_policy] = bench_result

print(f"e2e inference median time cost: {np.median(times):.2f} seconds")
|
|
|
|
if is_majority_vote:
|
|
for vote_policy, bench_result in vote_policy_to_bench_result.items():
|
|
bench_result.accuracy = bench_result.right_answer_count / bench_result.total_answer_count
|
|
print(
|
|
f"vote_policy: {vote_policy}, accuracy: {bench_result.accuracy}"
|
|
)
|
|
|
|
print(f"generated tokens: {bench_result.generated_tokens}")
|
|
|
|
llm.shutdown(shutdown_workers=True)
|
|
|
|
|
|
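# Run a single DeepConf controller (offline or online) over the prompts.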
def test_single_vote_controller(prompts,
                                proposer_worker,
                                conf_group_size,
                                conf_threshold,
                                temperature,
                                max_tokens,
                                logprobs_topk,
                                top_p,
                                run_type="offline",
                                **kwargs):
    generation_controller = NativeGenerationController(
        sampling_params={
            "temperature": temperature,
            "max_tokens": max_tokens,
            "num_logprobs": logprobs_topk,
            "top_p": top_p,
        })
    DeepConfControllerImpl = _RUN_TYPE_TO_IMPL[run_type]
    prototype_controller = DeepConfControllerImpl(
        generation_controller=generation_controller,
        conf_group_size=conf_group_size,
        conf_threshold=conf_threshold,
    )
    run_scaffolding_llm(prompts, proposer_worker, prototype_controller,
                        **kwargs)


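# Run a DeepConf majority-vote controller: an offline warmup pass followed by
# online final sampling, with answers aggregated per vote policy.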
def test_majority_vote_controller(prompts,
                                  proposer_worker,
                                  conf_group_size,
                                  conf_threshold,
                                  logprobs_topk,
                                  temperature,
                                  max_tokens,
                                  top_p,
                                  top_k,
                                  sample_num,
                                  warmup_sample_num,
                                  vote_policy,
                                  confidence_percentile,
                                  run_type="offline_majority_vote",
                                  **kwargs):
    generation_controller = NativeGenerationController(
        sampling_params={
            "temperature": temperature,
            "max_tokens": max_tokens,
            "num_logprobs": logprobs_topk,
            "top_p": top_p,
            "top_k": top_k,
        })
    deepconf_controller_kwargs = {
        "generation_controller": generation_controller,
        "conf_group_size": conf_group_size,
        "conf_threshold": conf_threshold,
    }
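    # Warmup runs the offline controller (complete traces, confidence measured
    # afterwards); the final pass runs the online controller so low-confidence
    # traces can be stopped early. Semantics assumed from the DeepConf design.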
    warmup_generation_controller = DeepConfOfflineController(
        **deepconf_controller_kwargs)
    final_generation_controller = DeepConfOnlineController(
        **deepconf_controller_kwargs)
    DeepConfMajorityVoteControllerImpl = _RUN_TYPE_TO_IMPL[run_type]
    majority_vote_controller = DeepConfMajorityVoteControllerImpl(
        generation_controller=warmup_generation_controller,
        warmup_generation_controller=warmup_generation_controller,
        final_generation_controller=final_generation_controller,
        sample_num=sample_num,
        vote_policy=vote_policy,
        warmup_sample_num=warmup_sample_num,
        confidence_percentile=confidence_percentile)
    run_scaffolding_llm(prompts, proposer_worker, majority_vote_controller,
                        **kwargs)


def main():
    args = parse_arguments()
    kwargs = {
        "sample_num": args.sample_num,
        "conf_group_size": args.conf_group_size,
        "conf_threshold": args.conf_threshold,
        "vote_policy": args.vote_policy,
        "warmup_sample_num": args.warmup_sample_num,
        "confidence_percentile": args.confidence_percentile,
        "logprobs_topk": args.logprobs_topk,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "top_k": args.top_k,
        "repeat_times": args.repeat_times,
        "max_tokens": args.max_tokens,
    }

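    # Spin up the in-process TRT-LLM worker that serves all generation tasks.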
    llm_worker = TRTLLMWorker.init_with_new_llm(
        args.model_dir,
        backend="pytorch",
        max_batch_size=2048,
        max_num_tokens=args.max_tokens,
    )
    print("init llm worker done")

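    # Each dataset line is a JSON object with at least a 'question' field and,
    # for scoring, an 'answer' field.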
    dataset_path = Path(__file__).parent / args.dataset
    with open(dataset_path, 'r', encoding='utf-8') as file:
        question_data = [json.loads(line.strip()) for line in file]

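    # qid == -1 runs the whole dataset; otherwise benchmark a single question.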
    if args.qid != -1:
        question_data = [question_data[args.qid]]
    prompts = [
        prepare_prompt(question['question'], llm_worker.tokenizer)
        for question in question_data
    ]
    ground_truth = [
        str(question.get('answer', '')).strip() for question in question_data
    ]
    kwargs["ground_truth"] = ground_truth

    print(f"loaded {len(prompts)} prompts")

    if args.run_type in ("offline", "online"):
        test_single_vote_controller(prompts,
                                    llm_worker,
                                    run_type=args.run_type,
                                    **kwargs)
    elif args.run_type in ("offline_majority_vote", "online_majority_vote"):
        test_majority_vote_controller(prompts,
                                      llm_worker,
                                      run_type=args.run_type,
                                      **kwargs)

    llm_worker.shutdown()
    print('llm worker shutdown done')


if __name__ == "__main__":
    main()