TensorRT-LLMs/examples/scaffolding/contrib/DeepConf/run_generation.py
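# DeepConf scaffolding benchmark: runs one of the DeepConf controllers over a
# JSONL question set and reports answer accuracy and generated-token counts.
#
# Illustrative invocation (the model path below is a placeholder, not a value
# shipped with this example):
#   python run_generation.py \
#       --model_dir /path/to/DeepSeek-R1-Distill-Qwen-7B \
#       --run_type offline_majority_vote \
#       --dataset brumo_2025.jsonl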

import argparse
import json
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict

import numpy as np
from utils import equal_func, prepare_prompt

from tensorrt_llm.scaffolding import (NativeGenerationController,
                                      ScaffoldingLlm, TRTLLMWorker,
                                      extract_answer_from_boxed)
from tensorrt_llm.scaffolding.contrib.DeepConf import (
    DeepConfOfflineController, DeepConfOfflineMajorityVoteController,
    DeepConfOnlineController, DeepConfOnlineMajorityVoteController)
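# Map each --run_type choice to the DeepConf controller class that implements it.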
_RUN_TYPE_TO_IMPL = {
"offline": DeepConfOfflineController,
"online": DeepConfOnlineController,
"offline_majority_vote": DeepConfOfflineMajorityVoteController,
"online_majority_vote": DeepConfOnlineMajorityVoteController,
}
def parse_arguments():
    parser = argparse.ArgumentParser()
    # e.g. DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B
    parser.add_argument(
        '--model_dir',
        type=str,
        required=True,
        help="Path to the directory containing the generation model")
    parser.add_argument('--run_type',
                        type=str,
                        required=True,
                        choices=list(_RUN_TYPE_TO_IMPL.keys()),
                        help="Type of the run. Available choices: %(choices)s")
    parser.add_argument('--warmup_sample_num', type=int, default=16)
    parser.add_argument('--sample_num', type=int, default=256)
    parser.add_argument('--conf_group_size', type=int, default=2048)
    parser.add_argument('--conf_threshold', type=float, default=0.5)
    parser.add_argument('--vote_policy',
                        type=str,
                        default="top10_bottom_window_filtered")
    parser.add_argument('--confidence_percentile', type=int, default=10)
    parser.add_argument('--logprobs_topk', type=int, default=20)
    parser.add_argument('--max_tokens', type=int, default=64000)
    parser.add_argument('--temperature', type=float, default=0.6)
    parser.add_argument('--top_p', type=float, default=0.95)
    parser.add_argument('--top_k', type=int, default=0)
    parser.add_argument('--qid', type=int, default=-1)
    parser.add_argument('--dataset', type=str, default="brumo_2025.jsonl")
    parser.add_argument('--repeat_times', type=int, default=1)
    parser.add_argument('--tensor_parallel_size', type=int, default=1)
    args = parser.parse_args()
    return args
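
# Per-vote-policy accuracy and token-count bookkeeping, accumulated across
# repeated benchmark rounds.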
@dataclass
class BenchResult:
    right_answer_count: int = 0
    total_answer_count: int = 0
    accuracy: float = 0.0
    generated_tokens: int = 0
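
# Drive the ScaffoldingLlm end to end: build the LLM around the given
# controller, generate for all prompts `repeat_times` times, print the boxed
# answer extracted from each result, and, for majority-vote controllers with
# ground truth available, tally per-vote-policy accuracy and generated tokens.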
def run_scaffolding_llm(prompts,
                        proposer_worker,
                        controller,
                        repeat_times=1,
                        ground_truth=None,
                        **kwargs):
    llm = ScaffoldingLlm(
        controller,
        {
            NativeGenerationController.WorkerTag.GENERATION: proposer_worker,
        },
    )
    is_majority_vote = isinstance(
        controller, DeepConfOnlineMajorityVoteController) or isinstance(
            controller, DeepConfOfflineMajorityVoteController)
    vote_policy_to_bench_result: Dict[str, BenchResult] = {}
    times = []
    for i in range(repeat_times):
        print(f"=========== round {i} ===========")
        start_time = time.time()
        results = llm.generate(prompts)
        times.append(time.time() - start_time)
        for j, result in enumerate(results):
            print(
                f"result {j}: {extract_answer_from_boxed(result.outputs[0].text)}"
            )
            if is_majority_vote and ground_truth is not None:
                vote_policy_to_voted_task = result.cur_output.vote_policy_to_voted_task
                for vote_policy, voted_task in vote_policy_to_voted_task.items():
                    bench_result = vote_policy_to_bench_result.get(
                        vote_policy, BenchResult())
                    voted_answer = voted_task.customized_result_fields[
                        'extracted_answer']
                    if equal_func(voted_answer, ground_truth[j]):
                        bench_result.right_answer_count += 1
                    bench_result.total_answer_count += 1
                    bench_result.generated_tokens += result.cur_output.output_token_num
                    vote_policy_to_bench_result[vote_policy] = bench_result

    print(f"e2e inference median time cost: {np.median(times):.2f} seconds")
    if is_majority_vote:
        for vote_policy, bench_result in vote_policy_to_bench_result.items():
            bench_result.accuracy = bench_result.right_answer_count / bench_result.total_answer_count
            print(
                f"vote_policy: {vote_policy}, accuracy: {bench_result.accuracy}"
            )
            print(f"generated tokens: {bench_result.generated_tokens}")

    llm.shutdown(shutdown_workers=True)
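
# Benchmark a single-vote DeepConf controller ("offline" or "online"): wrap a
# NativeGenerationController with the confidence settings and run it through
# run_scaffolding_llm.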
def test_single_vote_controller(prompts,
                                proposer_worker,
                                conf_group_size,
                                conf_threshold,
                                temperature,
                                max_tokens,
                                logprobs_topk,
                                top_p,
                                run_type="offline",
                                **kwargs):
    generation_controller = NativeGenerationController(
        sampling_params={
            "temperature": temperature,
            "max_tokens": max_tokens,
            "num_logprobs": logprobs_topk,
            "top_p": top_p,
        })
    DeepConfControllerImpl = _RUN_TYPE_TO_IMPL[run_type]
    prototype_controller = DeepConfControllerImpl(
        generation_controller=generation_controller,
        conf_group_size=conf_group_size,
        conf_threshold=conf_threshold,
    )
    run_scaffolding_llm(prompts, proposer_worker, prototype_controller,
                        **kwargs)
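
# Benchmark a DeepConf majority-vote controller: an offline controller is
# passed for the warmup phase and an online controller for the final phase,
# and voted answers are scored per vote policy by run_scaffolding_llm.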
def test_majority_vote_controller(prompts,
                                  proposer_worker,
                                  conf_group_size,
                                  conf_threshold,
                                  logprobs_topk,
                                  temperature,
                                  max_tokens,
                                  top_p,
                                  top_k,
                                  sample_num,
                                  warmup_sample_num,
                                  vote_policy,
                                  confidence_percentile,
                                  run_type="offline_majority_vote",
                                  **kwargs):
    generation_controller = NativeGenerationController(
        sampling_params={
            "temperature": temperature,
            "max_tokens": max_tokens,
            "num_logprobs": logprobs_topk,
            "top_p": top_p,
            "top_k": top_k,
        })
    DeepConfControllerKwargs = {
        "generation_controller": generation_controller,
        "conf_group_size": conf_group_size,
        "conf_threshold": conf_threshold,
    }
    warmup_generation_controller = DeepConfOfflineController(
        **DeepConfControllerKwargs)
    final_generation_controller = DeepConfOnlineController(
        **DeepConfControllerKwargs)
    DeepConfMajorityVoteControllerImpl = _RUN_TYPE_TO_IMPL[run_type]
    majority_vote_controller = DeepConfMajorityVoteControllerImpl(
        generation_controller=warmup_generation_controller,
        warmup_generation_controller=warmup_generation_controller,
        final_generation_controller=final_generation_controller,
        sample_num=sample_num,
        vote_policy=vote_policy,
        warmup_sample_num=warmup_sample_num,
        confidence_percentile=confidence_percentile)
    run_scaffolding_llm(prompts, proposer_worker, majority_vote_controller,
                        **kwargs)
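
# Parse CLI arguments, load the JSONL dataset, build the TRT-LLM worker, and
# dispatch to the single-vote or majority-vote benchmark based on --run_type.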
def main():
    args = parse_arguments()
    kwargs = {
        "sample_num": args.sample_num,
        "conf_group_size": args.conf_group_size,
        "conf_threshold": args.conf_threshold,
        "vote_policy": args.vote_policy,
        "warmup_sample_num": args.warmup_sample_num,
        "confidence_percentile": args.confidence_percentile,
        "logprobs_topk": args.logprobs_topk,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "top_k": args.top_k,
        "repeat_times": args.repeat_times,
        "max_tokens": args.max_tokens,
    }
    llm_worker = TRTLLMWorker.init_with_new_llm(
        args.model_dir,
        backend="pytorch",
        max_batch_size=2048,
        max_num_tokens=args.max_tokens,
    )
    print("init llm worker done")

    dataset_path = Path(__file__).parent / args.dataset
    with open(dataset_path, 'r', encoding='utf-8') as file:
        question_data = [json.loads(line.strip()) for line in file]
    if args.qid != -1:
        question_data = [question_data[args.qid]]
    prompts = [
        prepare_prompt(item['question'], llm_worker.tokenizer)
        for item in question_data
    ]
    ground_truth = [
        str(item.get('answer', '')).strip() for item in question_data
    ]
    kwargs["ground_truth"] = ground_truth
    print(f"has {len(prompts)} prompts")

    if args.run_type == "offline" or args.run_type == "online":
        test_single_vote_controller(prompts,
                                    llm_worker,
                                    run_type=args.run_type,
                                    **kwargs)
    elif args.run_type == "offline_majority_vote" or args.run_type == "online_majority_vote":
        test_majority_vote_controller(prompts,
                                      llm_worker,
                                      run_type=args.run_type,
                                      **kwargs)

    llm_worker.shutdown()
    print('llm worker shutdown done')


if __name__ == "__main__":
    main()