mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[None][feat] Add benchmark to DeepConf (#8776)
Signed-off-by: Dong Cao <docao@nvidia.com>
This commit is contained in:
parent
497a07021d
commit
2ff772ef71
30
examples/scaffolding/contrib/DeepConf/brumo_2025.jsonl
Normal file
30
examples/scaffolding/contrib/DeepConf/brumo_2025.jsonl
Normal file
@ -0,0 +1,30 @@
|
||||
{"question": "One hundred concentric circles are labelled $C_{1}, C_{2}, C_{3}, \\ldots, C_{100}$. Each circle $C_{n}$ is inscribed within an equilateral triangle whose vertices are points on $C_{n+1}$. Given $C_{1}$ has a radius of $1$, what is the radius of $C_{100}$ ?\n", "answer": "2^{99}"}
|
||||
{"question": "An infinite geometric sequence with common ratio $r$ sums to $91$. A new sequence starting with the same term has common ratio $r^{3}$. The sum of the new sequence produced is $81$. What was the common ratio of the original sequence?\n", "answer": "\\frac{1}{9}"}
|
||||
{"question": "Let $A, B, C, D$, and $E$ be five equally spaced points on a line in that order. Let $F, G, H$, and $I$ all be on the same side of line $A E$ such that triangles $A F B, B G C, C H D$, and $D I E$ are equilateral with side length $1$. Let $S$ be the region consisting of the interiors of all four triangles. Compute the length of segment $A I$ that is contained in $S$.", "answer": "\\frac{\\sqrt{13}}{2}"}
|
||||
{"question": "If $5 f(x)-x f\\left(\\frac{1}{x}\\right)=\\frac{1}{17} x^{2}$, determine $f(3)$.", "answer": "\\frac{1}{9}"}
|
||||
{"question": "How many ways are there to arrange $1,2,3,4,5,6$ such that no two consecutive numbers have the same remainder when divided by $3$ ?", "answer": "240"}
|
||||
{"question": "Joshua is playing with his number cards. He has $9$ cards of $9$ lined up in a row. He puts a multiplication sign between two of the $9 \\mathrm{~s}$ and calculates the product of the two strings of $9 \\mathrm{~s}$. For example, one possible result is $999 \\times 999999=998999001$. Let $S$ be the sum of all possible distinct results (note that $999 \\times 999999$ yields the same result as $999999 \\times 999$ ). What is the sum of digits of $S$ ?", "answer": "72"}
|
||||
{"question": "Bruno the Bear is tasked to organize $16$ identical brown balls into $7$ bins labeled 1-7. He must distribute the balls among the bins so that each odd-labeled bin contains an odd number of balls, and each even-labeled bin contains an even number of balls (with $0$ considered even). In how many ways can Bruno do this?", "answer": "924"}
|
||||
{"question": "Let $f(n)$ be the number obtained by increasing every prime factor in $f$ by one. For instance, $f(12)=(2+1)^{2}(3+1)=36$. What is the lowest $n$ such that $6^{2025}$ divides $f^{(n)}(2025)$, where $f^{(n)}$ denotes the $n$th iteration of $f$ ?", "answer": "20"}
|
||||
{"question": "How many positive integer divisors of $63^{10}$ do not end in a $1$ ?", "answer": "173"}
|
||||
{"question": "Bruno is throwing a party and invites $n$ guests. Each pair of party guests are either friends or enemies. Each guest has exactly $12$ enemies. All guests believe the following: the friend of an enemy is an enemy. Calculate the sum of all possible values of $n$. (Please note: Bruno is not a guest at his own party)", "answer": "100"}
|
||||
{"question": "In acute $\\triangle A B C$, let $D$ be the foot of the altitude from $A$ to $B C$ and $O$ be the circumcenter. Suppose that the area of $\\triangle A B D$ is equal to the area of $\\triangle A O C$. Given that $O D=2$ and $B D=3$, compute $A D$.", "answer": "3+2\\sqrt{2}"}
|
||||
{"question": "Alice has $10$ gifts $g_{1}, g_{2}, \\ldots, g_{10}$ and $10$ friends $f_{1}, f_{2}, \\ldots, f_{10}$. Gift $g_{i}$ can be given to friend $f_{j}$ if\n\n$$\ni-j=-1,0, \\text { or } 1 \\quad(\\bmod 10)\n$$\n\nHow many ways are there for Alice to pair the $10$ gifts with the $10$ friends such that each friend receives one gift?", "answer": "125"}
|
||||
{"question": "Let $\\triangle A B C$ be an equilateral triangle with side length $1$. A real number $d$ is selected uniformly at random from the open interval $(0,0.5)$. Points $E$ and $F$ lie on sides $A C$ and $A B$, respectively, such that $A E=d$ and $A F=1-d$. Let $D$ be the intersection of lines $B E$ and $C F$.\n\nConsider line $\\ell$ passing through both points of intersection of the circumcircles of triangles $\\triangle D E F$ and $\\triangle D B C . O$ is the circumcenter of $\\triangle D E F$. Line $\\ell$ intersects line $\\overleftrightarrow{B C}$ at point $P$, and point $Q$ lies on $A P$ such that $\\angle A Q B=120^{\\circ}$. What is the probability that the line segment $\\overline{Q O}$ has length less than $\\frac{1}{3}$ ?", "answer": "\\frac{1}{3}"}
|
||||
{"question": "Define sequence $\\left\\{a_{n}\\right\\}_{n=1}^{\\infty}$ such that $a_{1}=\\frac{\\pi}{3}$ and $a_{n+1}=\\cot ^{-1}\\left(\\csc \\left(a_{n}\\right)\\right)$ for all positive integers $n$. Find the value of\n\n$$\n\\frac{1}{\\cos \\left(a_{1}\\right) \\cos \\left(a_{2}\\right) \\cos \\left(a_{3}\\right) \\cdots \\cos \\left(a_{16}\\right)}\n$$", "answer": "7"}
|
||||
{"question": "Define $\\{x\\}$ to be the fractional part of $x$. For example, $\\{20.25\\}=0.25$ and $\\{\\pi\\}=\\pi-3$. Let $A=\\sum_{a=1}^{96} \\sum_{n=1}^{96}\\left\\{\\frac{a^{n}}{97}\\right\\}$, where $\\{x\\}$ denotes the fractional part of $x$. Compute $A$ rounded to the nearest integer.", "answer": "4529"}
|
||||
{"question": "Find the smallest positive integer $n$ such that $n$ is divisible by exactly $25$ different positive integers.", "answer": "1296"}
|
||||
{"question": "Two squares, $A B C D$ and $A E F G$, have equal side length $x$. They intersect at $A$ and $O$. Given that $C O=2$ and $O A=2 \\sqrt{2}$, what is $x$ ?", "answer": "1+\\sqrt{3}"}
|
||||
{"question": "Bruno and Brutus are running on a circular track with a $20$ foot radius. Bruno completes $5$ laps every hour, while Brutus completes $7$ laps every hour. If they start at the same point but run in opposite directions, how far along the track's circumference (in feet) from the starting point are they when they meet for the sixth time? Note: Do not count the moment they start running as a meeting point.", "answer": "20\\pi"}
|
||||
{"question": "What is the smallest positive integer $n$ such that $z^{n}-1$ and $(z-\\sqrt{3})^{n}-1$ share a common complex root?", "answer": "12"}
|
||||
{"question": "Consider a pond with lily pads numbered from $1$ to $12$ arranged in a circle. Bruno the frog starts on lily pad 1. Each turn, Bruno has an equal probability of making one of three moves: jumping $4$ lily pads clockwise, jumping $2$ lily pads clockwise, or jumping $1$ lily pad counterclockwise. What is the expected number of turns for Bruno to return to lily pad $1$ for the first time?", "answer": "12"}
|
||||
{"question": "$4$ bears - Aruno, Bruno, Cruno and Druno - are each given a card with a positive integer and are told that the sum of their $4$ numbers is $17$. They cannot show each other their cards, but discuss a series of observations in the following order:\n\nAruno: \"I think it is possible that the other three bears all have the same card.\"\nBruno: \"At first, I thought it was possible for the other three bears to have the same card. Now I know it is impossible for them to have the same card.\"\nCruno: \"I think it is still possible that the other three bears have the same card.\"\nDruno: \"I now know what card everyone has.\"\nWhat is the product of their four card values?", "answer": "160"}
|
||||
{"question": "Digits $1$ through $9$ are placed on a $3 x 3$ square such that all rows and columns sum to the same value. Please note that diagonals do not need to sum to the same value. How many ways can this be done?", "answer": "72"}
|
||||
{"question": "Define the operation $\\oplus$ by\n\n$$\nx \\oplus y=x y-2 x-2 y+6 .\n$$\n\nCompute all complex numbers $a$ such that\n\n$$\na \\oplus(a \\oplus(a \\oplus a))=a .\n$$", "answer": "2,3,\\frac{3+i\\sqrt{3}}{2},\\frac{3-i\\sqrt{3}}{2}"}
|
||||
{"question": "Define the function $f$ on positive integers\n\n$$\nf(n)= \\begin{cases}\\frac{n}{2} & \\text { if } n \\text { is even } \\\\ n+1 & \\text { if } n \\text { is odd }\\end{cases}\n$$\n\nLet $S(n)$ equal the smallest positive integer $k$ such that $f^{k}(n)=1$. How many positive integers satisfy $S(n)=11$ ?", "answer": "89"}
|
||||
{"question": "Let $A B C D E F$ be a convex cyclic hexagon. Suppose that $A B=D E=\\sqrt{5}, B C=E F=3$, and $C D=F A=\\sqrt{20}$. Compute the circumradius of $A B C D E F$.", "answer": "\\frac{1+\\sqrt{31}}{2}"}
|
||||
{"question": "A repetend is the infinitely repeated digit sequence of a repeating decimal. What are the last three digits of the repetend of the decimal representation of $\\frac{1}{727}$, given that the repetend has a length of $726$ ? Express the answer as a three-digit number. Include preceding zeros if there are any.", "answer": "337"}
|
||||
{"question": "Consider a $54$-deck of cards, i.e. a standard $52$-card deck together with two jokers. Ada draws cards from the deck until Ada has drawn an ace, a king, and a queen. How many cards does Ada pick up on average?", "answer": "\\frac{737}{39}"}
|
||||
{"question": "Let $\\omega$ be a circle, and let a line $\\ell$ intersect $\\omega$ at two points, $P$ and $Q$. Circles $\\omega_{1}$ and $\\omega_{2}$ are internally tangent to $\\omega$ at points $X$ and $Y$, respectively, and both are tangent to $\\ell$ at a common point $D$. Similarly, circles $\\omega_{3}$ and $\\omega_{4}$ are externally tangent to $\\omega$ at $X$ and $Y$, respectively, and are tangent to $\\ell$ at points $E$ and $F$, respectively.\nGiven that the radius of $\\omega$ is $13$, the segment $\\overline{P Q}=24$, and $\\overline{Y D}=\\overline{Y E}$, find the length of segment $\\overline{Y F}$.", "answer": "5\\sqrt{2}"}
|
||||
{"question": "Let $f$ be a degree $7$ polynomial satisfying\n$$\nf(k)=\\frac{1}{k^{2}}\n$$\n\nfor $k \\in\\{1 \\cdot 2,2 \\cdot 3, \\ldots, 8 \\cdot 9\\}$. Find $f(90)-\\frac{1}{90^{2}}$.", "answer": "-\\frac{2431}{50}"}
|
||||
{"question": "Let $\\triangle A B C$ be an isosceles triangle with $A B=A C$. Let $D$ be a point on the circumcircle of $\\triangle A B C$ on minor arc $A B$. Let $\\overline{A D}$ intersect the extension of $\\overline{B C}$ at $E$. Let $F$ be the midpoint of segment $A C$, and let $G$ be the intersection of $\\overline{E F}$ and $\\overline{A B}$. Let the extension of $\\overline{D G}$ intersect $\\overline{A C}$ and the circumcircle of $\\triangle A B C$ at $H$ and $I$, respectively. Given that $D G=3, G H=5$, and $H I=1$, compute the length of $A E$.", "answer": "\\frac{9\\sqrt{30}}{4}"}
|
||||
@ -1,8 +1,16 @@
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
from utils import equal_func, prepare_prompt
|
||||
|
||||
from tensorrt_llm.scaffolding import (NativeGenerationController,
|
||||
ScaffoldingLlm, TRTLLMWorker)
|
||||
ScaffoldingLlm, TRTLLMWorker,
|
||||
extract_answer_from_boxed)
|
||||
from tensorrt_llm.scaffolding.contrib.DeepConf import (
|
||||
DeepConfOfflineController, DeepConfOfflineMajorityVoteController,
|
||||
DeepConfOnlineController, DeepConfOnlineMajorityVoteController)
|
||||
@ -28,35 +36,91 @@ def parse_arguments():
|
||||
required=True,
|
||||
choices=list(_RUN_TYPE_TO_IMPL.keys()),
|
||||
help="Type of the run. Available choices: %(choices)s")
|
||||
parser.add_argument('--sample_num', type=int, default=20)
|
||||
parser.add_argument('--conf_group_size', type=int, default=128)
|
||||
parser.add_argument('--warmup_sample_num', type=int, default=16)
|
||||
parser.add_argument('--sample_num', type=int, default=256)
|
||||
parser.add_argument('--conf_group_size', type=int, default=2048)
|
||||
parser.add_argument('--conf_threshold', type=float, default=0.5)
|
||||
parser.add_argument('--vote_policy',
|
||||
type=str,
|
||||
default="top10_bottom_window_filtered")
|
||||
parser.add_argument('--warmup_sample_num', type=int, default=5)
|
||||
parser.add_argument('--confidence_percentile', type=int, default=90)
|
||||
parser.add_argument('--confidence_percentile', type=int, default=10)
|
||||
parser.add_argument('--logprobs_topk', type=int, default=20)
|
||||
parser.add_argument('--max_tokens', type=int, default=8192)
|
||||
parser.add_argument('--max_tokens', type=int, default=64000)
|
||||
parser.add_argument('--temperature', type=float, default=0.6)
|
||||
parser.add_argument('--top_p', type=float, default=0.95)
|
||||
parser.add_argument('--top_k', type=int, default=0)
|
||||
parser.add_argument('--qid', type=int, default=-1)
|
||||
parser.add_argument('--dataset', type=str, default="brumo_2025.jsonl")
|
||||
parser.add_argument('--repeat_times', type=int, default=1)
|
||||
parser.add_argument('--tensor_parallel_size', type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def run_scaffolding_llm(prompts, proposer_worker, controller):
|
||||
@dataclass
|
||||
class BenchResult:
|
||||
right_answer_count: int = 0
|
||||
total_answer_count: int = 0
|
||||
accuracy: float = 0.0
|
||||
generated_tokens: int = 0
|
||||
|
||||
|
||||
def run_scaffolding_llm(prompts,
|
||||
proposer_worker,
|
||||
controller,
|
||||
repeat_times=1,
|
||||
ground_truth=None,
|
||||
**kwargs):
|
||||
llm = ScaffoldingLlm(
|
||||
controller,
|
||||
{
|
||||
NativeGenerationController.WorkerTag.GENERATION: proposer_worker,
|
||||
},
|
||||
)
|
||||
time_start = time.time()
|
||||
results = llm.generate(prompts)
|
||||
time_end = time.time()
|
||||
print(f"time cost: {time_end - time_start} seconds")
|
||||
for i, result in enumerate(results):
|
||||
print(f"result {i}:\n{result.outputs[0].text}")
|
||||
|
||||
is_majority_vote = isinstance(
|
||||
controller, DeepConfOnlineMajorityVoteController) or isinstance(
|
||||
controller, DeepConfOfflineMajorityVoteController)
|
||||
vote_policy_to_bench_result: Dict[str, BenchResult] = {}
|
||||
times = []
|
||||
for i in range(repeat_times):
|
||||
print(f"=========== round {i} ===========")
|
||||
start_time = time.time()
|
||||
results = llm.generate(prompts)
|
||||
times.append(time.time() - start_time)
|
||||
|
||||
for j, result in enumerate(results):
|
||||
print(
|
||||
f"result {j}: {extract_answer_from_boxed(result.outputs[0].text)}"
|
||||
)
|
||||
|
||||
if is_majority_vote and ground_truth is not None:
|
||||
vote_policy_to_voted_task = result.cur_output.vote_policy_to_voted_task
|
||||
for vote_policy, voted_task in vote_policy_to_voted_task.items(
|
||||
):
|
||||
bench_result = vote_policy_to_bench_result.get(
|
||||
vote_policy, BenchResult())
|
||||
|
||||
voted_answer = voted_task.customized_result_fields[
|
||||
'extracted_answer']
|
||||
if equal_func(voted_answer, ground_truth[j]):
|
||||
bench_result.right_answer_count += 1
|
||||
bench_result.total_answer_count += 1
|
||||
bench_result.generated_tokens += result.cur_output.output_token_num
|
||||
|
||||
vote_policy_to_bench_result[vote_policy] = bench_result
|
||||
|
||||
print(f"e2e inference median time cost: {np.median(times):.2f} seconds")
|
||||
|
||||
if is_majority_vote:
|
||||
for vote_policy, bench_result in vote_policy_to_bench_result.items():
|
||||
bench_result.accuracy = bench_result.right_answer_count / bench_result.total_answer_count
|
||||
print(
|
||||
f"vote_policy: {vote_policy}, accuracy: {bench_result.accuracy}"
|
||||
)
|
||||
|
||||
print(f"generated tokens: {bench_result.generated_tokens}")
|
||||
|
||||
llm.shutdown(shutdown_workers=True)
|
||||
|
||||
|
||||
@ -83,7 +147,8 @@ def test_single_vote_controller(prompts,
|
||||
conf_group_size=conf_group_size,
|
||||
conf_threshold=conf_threshold,
|
||||
)
|
||||
run_scaffolding_llm(prompts, proposer_worker, prototype_controller)
|
||||
run_scaffolding_llm(prompts, proposer_worker, prototype_controller,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def test_majority_vote_controller(prompts,
|
||||
@ -94,6 +159,7 @@ def test_majority_vote_controller(prompts,
|
||||
temperature,
|
||||
max_tokens,
|
||||
top_p,
|
||||
top_k,
|
||||
sample_num,
|
||||
warmup_sample_num,
|
||||
vote_policy,
|
||||
@ -106,6 +172,7 @@ def test_majority_vote_controller(prompts,
|
||||
"max_tokens": max_tokens,
|
||||
"num_logprobs": logprobs_topk,
|
||||
"top_p": top_p,
|
||||
"top_k": top_k,
|
||||
})
|
||||
DeepConfControllerKwargs = {
|
||||
"generation_controller": generation_controller,
|
||||
@ -125,7 +192,8 @@ def test_majority_vote_controller(prompts,
|
||||
vote_policy=vote_policy,
|
||||
warmup_sample_num=warmup_sample_num,
|
||||
confidence_percentile=confidence_percentile)
|
||||
run_scaffolding_llm(prompts, proposer_worker, majority_vote_controller)
|
||||
run_scaffolding_llm(prompts, proposer_worker, majority_vote_controller,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def main():
|
||||
@ -138,25 +206,39 @@ def main():
|
||||
"warmup_sample_num": args.warmup_sample_num,
|
||||
"confidence_percentile": args.confidence_percentile,
|
||||
"logprobs_topk": args.logprobs_topk,
|
||||
"max_tokens": args.max_tokens,
|
||||
"temperature": args.temperature,
|
||||
"top_p": args.top_p,
|
||||
"top_k": args.top_k,
|
||||
"repeat_times": args.repeat_times,
|
||||
"max_tokens": args.max_tokens,
|
||||
}
|
||||
|
||||
prompts = [
|
||||
"Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\r\n\r\n",
|
||||
"There exist real numbers $x$ and $y$, both greater than 1, such that $\\log_x\\left(y^x\\right)=\\log_y\\left(x^{4y}\\right)=10$. Find $xy$.",
|
||||
"Find the largest possible real part of \\[(75+117i)z+\\frac{96+144i}{z}\\]where $z$ is a complex number with $|z|=4$.",
|
||||
]
|
||||
|
||||
llm_worker = TRTLLMWorker.init_with_new_llm(
|
||||
args.model_dir,
|
||||
backend="pytorch",
|
||||
max_batch_size=32,
|
||||
max_num_tokens=kwargs.get("max_tokens"),
|
||||
max_batch_size=2048,
|
||||
max_num_tokens=args.max_tokens,
|
||||
)
|
||||
print(f"init llm worker done")
|
||||
|
||||
dataset_path = Path(__file__).parent / args.dataset
|
||||
with open(dataset_path, 'r', encoding='utf-8') as file:
|
||||
question_data = [json.loads(line.strip()) for line in file]
|
||||
|
||||
if args.qid != -1:
|
||||
question_data = [question_data[args.qid]]
|
||||
prompts = [
|
||||
prepare_prompt(question_data['question'], llm_worker.tokenizer)
|
||||
for question_data in question_data
|
||||
]
|
||||
ground_truth = [
|
||||
str(question_data.get('answer', '')).strip()
|
||||
for question_data in question_data
|
||||
]
|
||||
kwargs["ground_truth"] = ground_truth
|
||||
|
||||
print(f"has {len(prompts)} prompts")
|
||||
|
||||
if args.run_type == "offline" or args.run_type == "online":
|
||||
test_single_vote_controller(prompts,
|
||||
llm_worker,
|
||||
|
||||
49
examples/scaffolding/contrib/DeepConf/utils.py
Normal file
49
examples/scaffolding/contrib/DeepConf/utils.py
Normal file
@ -0,0 +1,49 @@
|
||||
from dynasor.core.evaluator import math_equal
|
||||
|
||||
|
||||
def quick_parse(text: str) -> str:
|
||||
"""Parse LaTeX text content."""
|
||||
if "\\text{" in text and "}" in text:
|
||||
# Find all occurrences of \text{...} and remove them
|
||||
while "\\text{" in text:
|
||||
start = text.find("\\text{")
|
||||
if start == -1:
|
||||
break
|
||||
end = text.find("}", start)
|
||||
if end == -1:
|
||||
break
|
||||
# Replace \text{content} with just content
|
||||
content = text[start + 6 : end] # 6 is length of '\text{'
|
||||
text = text[:start] + content + text[end + 1 :]
|
||||
return text
|
||||
|
||||
|
||||
def equal_func(answer: str, ground_truth: str) -> bool:
|
||||
"""Check if answer equals ground truth."""
|
||||
answer = quick_parse(answer)
|
||||
if len(answer) == 1 and answer.isalpha() and len(ground_truth) == 1 and ground_truth.isalpha():
|
||||
return answer.lower() == ground_truth.lower()
|
||||
else:
|
||||
return math_equal(answer, ground_truth)
|
||||
|
||||
|
||||
def prepare_prompt(question: str, tokenizer, model_type: str = "deepseek") -> str:
|
||||
"""Prepare prompt for a single question."""
|
||||
if model_type == "deepseek":
|
||||
# Format prompt using chat template for DeepSeek
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "该助手为DeepSeek-R1,由深度求索公司创造。\n今天是2025年5月28日,星期一。\n",
|
||||
},
|
||||
{"role": "user", "content": question},
|
||||
]
|
||||
else:
|
||||
# Format for GPT-like models
|
||||
messages = [{"role": "user", "content": question}]
|
||||
|
||||
full_prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
return full_prompt
|
||||
89
tensorrt_llm/scaffolding/contrib/DeepConf/README.md
Normal file
89
tensorrt_llm/scaffolding/contrib/DeepConf/README.md
Normal file
@ -0,0 +1,89 @@
|
||||
# DeepConf
|
||||
|
||||
This document shows how to speed up reasoning models without training or fine-tuning by using **DeepConf** ([Deep Think with Confidence](https://arxiv.org/abs/2508.15260)) in TensorRT-LLM.
|
||||
|
||||
## Overview
|
||||
|
||||
Deep Think with Confidence (DeepConf) is a parallel thinking method that enhances both LLM reasoning performance and efficiency at test time. It leverages model-internal confidence signals to dynamically filter low-quality reasoning traces during or after generation. It requires no additional model training or hyperparameter tuning and can be seamlessly integrated into existing serving frameworks. It achieves up to 99.9% accuracy on AIME 2025 while reducing generated tokens by up to 84.7% compared to standard thinking approaches.
|
||||
|
||||
## Usage
|
||||
|
||||
The core logic for **DeepConf** lives in `deep_conf_controller.py`, which contains four core classes organized in two layers:
|
||||
|
||||
**Base Controllers** (building blocks):
|
||||
1. **DeepConfOfflineController**: Wraps generation with confidence tracking, collecting logprobs for all generated tokens to build a `ConfidenceInfo` object for post-generation analysis. Serves as the foundation for offline voting and warmup phases.
|
||||
2. **DeepConfOnlineController**: Implements streaming generation with real-time confidence monitoring and dynamic early stopping when confidence thresholds are met. Serves as the foundation for online voting with early termination.
|
||||
|
||||
**Voting Controllers** (composed from base controllers):
|
||||
|
||||
> Thanks to the high flexibility of the **scaffolding framework**, we can easily leverage base controllers to handle confidence-related operations, while voting controllers remain focused on algorithmic logic without worrying about confidence tracking or online early-exit implementation details.
|
||||
|
||||
|
||||
3. **DeepConfOfflineMajorityVoteController**: Orchestrates parallel generation using multiple `DeepConfOfflineController` instances, then aggregates results via configurable majority voting strategies.
|
||||
4. **DeepConfOnlineMajorityVoteController**: Two-phase orchestration combining both base controllers: uses `DeepConfOfflineController` for warmup samples to calibrate thresholds, then `DeepConfOnlineController` for final samples with early stopping, aggregating all results through majority voting.
|
||||
|
||||
|
||||
You can adjust the behavior of DeepConf by passing different parameter values:
|
||||
|
||||
| Parameter | Description |
|
||||
|-----------|-------------|
|
||||
| `warmup_sample_num` | Number of warmup samples for calibrating confidence threshold |
|
||||
| `sample_num` | Total samples for majority voting (warmup + final) |
|
||||
| `conf_group_size` | Token chunk size for confidence checking intervals |
|
||||
| `conf_threshold` | Base confidence threshold for early stopping |
|
||||
| `confidence_percentile` | Percentile for computing threshold from warmup (lower = earlier stopping) |
|
||||
| `logprobs_topk` | Number of top logprobs to track per token |
|
||||
|
||||
### Quick Start
|
||||
|
||||
#### Offline Mode
|
||||
|
||||
```bash
|
||||
python3 examples/scaffolding/contrib/DeepConf/run_generation.py --model_dir deepseek-ai/DeepSeek-R1-0528-Qwen3-8B --run_type offline_majority_vote
|
||||
```
|
||||
|
||||
#### Online Mode
|
||||
|
||||
```bash
|
||||
python3 examples/scaffolding/contrib/DeepConf/run_generation.py --model_dir deepseek-ai/DeepSeek-R1-0528-Qwen3-8B --run_type online_majority_vote
|
||||
```
|
||||
|
||||
> **Note**: `run_generation.py` supports various configurable parameters (e.g., `--sample_num`, `--conf_group_size`, `--confidence_percentile`). See the parameter table above or check the code for detailed options.
|
||||
|
||||
## Results
|
||||
|
||||
Evaluated on the **brumo_2025.jsonl** dataset with the configuration of `warmup_sample_num=16`, `sample_num=256`, `conf_group_size=2048`, `confidence_percentile=10`, and `logprobs_topk=20`, the online mode achieves a 54.5% reduction in output tokens and approximately 1.92x speedup.
|
||||
|
||||
| Mode | Mean Gen Time | Mean Tokens |
|
||||
|---------|-----------------|---------------|
|
||||
| Online | 1506.4s | ~2.0M |
|
||||
| Offline | 2891.4s | ~4.4M |
|
||||
|
||||
Under the same configuration, confidence-based voting methods significantly improve accuracy, with `top10_bottom_window_filtered` boosting the accuracy from 88.14% (`basic_majority_vote`) to 94.92%.
|
||||
|
||||
| Vote Policy | Accuracy |
|
||||
|------------------------------|------------|
|
||||
| top10_bottom_window_filtered | 0.9492 |
|
||||
| top10_tail_filtered | 0.9153 |
|
||||
| mean_confidence_weighted | 0.8983 |
|
||||
| tail_confidence_weighted | 0.8983 |
|
||||
| bottom_window_weighted | 0.8983 |
|
||||
| min_window_weighted | 0.8983 |
|
||||
| basic_majority_vote | 0.8814 |
|
||||
| single_vote | 0.7966 |
|
||||
|
||||
## References
|
||||
|
||||
- Blog post: [Deep Think with Confidence](https://jiaweizzhao.github.io/deepconf/)
|
||||
- Paper: [https://arxiv.org/abs/2508.15260](https://arxiv.org/abs/2508.15260)
|
||||
- Codebase: [https://github.com/facebookresearch/deepconf](https://github.com/facebookresearch/deepconf)
|
||||
|
||||
If you use DeepConf for your research, please cite the [paper](https://arxiv.org/abs/2508.15260):
|
||||
```
|
||||
@article{fu2025deep,
|
||||
title={Deep think with confidence},
|
||||
author={Fu, Yichao and Wang, Xuewei and Tian, Yuandong and Zhao, Jiawei},
|
||||
journal={arXiv preprint arXiv:2508.15260},
|
||||
year={2025}
|
||||
}
|
||||
```
|
||||
@ -100,6 +100,10 @@ def weighted_majority_vote(tasks: List[Task],
|
||||
|
||||
def majority_vote(tasks_list: List[List[Task]], vote_policy: str = 'majority'):
|
||||
tasks = [tasks[0] for tasks in tasks_list]
|
||||
output_token_num = sum([
|
||||
len(task.customized_result_fields['confidence_info'].conf_list)
|
||||
for task in tasks
|
||||
])
|
||||
for task in tasks:
|
||||
task.customized_result_fields[
|
||||
'extracted_answer'] = extract_answer_from_boxed(task.output_str)
|
||||
@ -113,7 +117,7 @@ def majority_vote(tasks_list: List[List[Task]], vote_policy: str = 'majority'):
|
||||
print(
|
||||
"Warning: No valid tasks, maybe you should increase max_output_len, a random task will be returned"
|
||||
)
|
||||
return random.choice(tasks)
|
||||
return {}, random.choice(tasks)
|
||||
|
||||
answers = [
|
||||
task.customized_result_fields['extracted_answer']
|
||||
@ -123,53 +127,44 @@ def majority_vote(tasks_list: List[List[Task]], vote_policy: str = 'majority'):
|
||||
task.customized_result_fields['confidence_info'] for task in valid_tasks
|
||||
]
|
||||
|
||||
match vote_policy:
|
||||
case 'majority':
|
||||
return basic_majority_vote(valid_tasks, answers=answers)
|
||||
case 'mean_confidence_weighted':
|
||||
mean_confidences = [conf.mean_conf for conf in confidences]
|
||||
return weighted_majority_vote(valid_tasks,
|
||||
answers=answers,
|
||||
confidences=mean_confidences,
|
||||
type=vote_policy)
|
||||
case 'tail_confidence_weighted':
|
||||
tail_confidences = [conf.tail_mean_conf for conf in confidences]
|
||||
return weighted_majority_vote(valid_tasks,
|
||||
answers=answers,
|
||||
confidences=tail_confidences,
|
||||
type=vote_policy)
|
||||
case 'bottom_window_weighted':
|
||||
bottom_window_confidences = [
|
||||
conf.bottom_window_mean_conf for conf in confidences
|
||||
]
|
||||
return weighted_majority_vote(valid_tasks,
|
||||
answers=answers,
|
||||
confidences=bottom_window_confidences,
|
||||
type=vote_policy)
|
||||
case 'min_window_weighted':
|
||||
min_window_confidences = [
|
||||
conf.min_window_conf for conf in confidences
|
||||
]
|
||||
return weighted_majority_vote(valid_tasks,
|
||||
answers=answers,
|
||||
confidences=min_window_confidences,
|
||||
type=vote_policy)
|
||||
case 'top10_tail_filtered':
|
||||
tail_confidences = [conf.tail_mean_conf for conf in confidences]
|
||||
return weighted_majority_vote(valid_tasks,
|
||||
answers=answers,
|
||||
confidences=tail_confidences,
|
||||
filter_top_percent=0.1,
|
||||
type=vote_policy)
|
||||
case 'top10_bottom_window_filtered':
|
||||
bottom_window_confidences = [
|
||||
conf.bottom_window_mean_conf for conf in confidences
|
||||
]
|
||||
return weighted_majority_vote(valid_tasks,
|
||||
answers=answers,
|
||||
confidences=bottom_window_confidences,
|
||||
filter_top_percent=0.1,
|
||||
type=vote_policy)
|
||||
case _:
|
||||
raise NotImplementedError(
|
||||
f"Vote policy '{vote_policy}' is not implemented")
|
||||
mean_confs = [conf.mean_conf for conf in confidences]
|
||||
tail_confs = [conf.tail_mean_conf for conf in confidences]
|
||||
bottom_window_confs = [conf.bottom_window_mean_conf for conf in confidences]
|
||||
min_window_confs = [conf.min_window_conf for conf in confidences]
|
||||
vote_policy_to_voted_task = {
|
||||
'single_vote':
|
||||
random.choice(valid_tasks),
|
||||
'basic_majority_vote':
|
||||
basic_majority_vote(valid_tasks, answers=answers),
|
||||
'mean_confidence_weighted':
|
||||
weighted_majority_vote(valid_tasks,
|
||||
answers=answers,
|
||||
confidences=mean_confs),
|
||||
'tail_confidence_weighted':
|
||||
weighted_majority_vote(valid_tasks,
|
||||
answers=answers,
|
||||
confidences=tail_confs),
|
||||
'bottom_window_weighted':
|
||||
weighted_majority_vote(valid_tasks,
|
||||
answers=answers,
|
||||
confidences=bottom_window_confs),
|
||||
'min_window_weighted':
|
||||
weighted_majority_vote(valid_tasks,
|
||||
answers=answers,
|
||||
confidences=min_window_confs),
|
||||
'top10_tail_filtered':
|
||||
weighted_majority_vote(valid_tasks,
|
||||
answers=answers,
|
||||
confidences=tail_confs,
|
||||
filter_top_percent=0.1),
|
||||
'top10_bottom_window_filtered':
|
||||
weighted_majority_vote(valid_tasks,
|
||||
answers=answers,
|
||||
confidences=bottom_window_confs,
|
||||
filter_top_percent=0.1),
|
||||
}
|
||||
|
||||
voted_task = vote_policy_to_voted_task[vote_policy]
|
||||
voted_task.result.vote_policy_to_voted_task = vote_policy_to_voted_task
|
||||
voted_task.result.output_token_num = output_token_num
|
||||
return voted_task
|
||||
|
||||
Loading…
Reference in New Issue
Block a user