TensorRT-LLMs/examples/scaffolding/token_budget_majority_vote.py
WeiHaocheng 0f01826dde
feat: support task collection for to collect information (#3328) (#3824)
Signed-off-by: fredw (generated by with_the_same_user script) <20514172+WeiHaocheng@users.noreply.github.com>
2025-05-09 17:09:01 +08:00

115 lines
4.4 KiB
Python

import argparse
import copy
from typing import List
from tensorrt_llm.scaffolding import (Controller, GenerationTokenCounter,
NativeGenerationController,
ParallelProcess, ScaffoldingLlm, Task,
TRTLLMWorker, extract_answer_from_boxed,
get_digit_majority_vote_result,
with_task_collection)
@with_task_collection("token_counter", GenerationTokenCounter)
class TokenBudgetMajorityVoteController(Controller):
    """Majority-vote controller that keeps sampling until a token budget is hit.

    Each turn clones the wrapped generation controller
    ``sumple_num_per_turn`` times and runs the clones in parallel (via
    ``ParallelProcess``) on deep copies of the incoming tasks and kwargs.
    Turns repeat while the generation-token count reported by the
    ``"token_counter"`` task collection is below ``token_budget``. All sampled
    ``output_str`` values are then reduced with ``majority_vote`` and the
    winner is written to ``tasks[0].output_str``.

    Note: ``sumple_num_per_turn`` is a historical misspelling of
    "sample_num_per_turn"; it is kept because it is part of the public
    constructor signature.
    """

    def __init__(self, generation_controller: Controller, token_budget: int,
                 sumple_num_per_turn: int):
        """Create the controller.

        Args:
            generation_controller: Prototype controller that is cloned once
                per sample.
            token_budget: Sampling turns continue while the counted
                generation tokens are below this value.
            sumple_num_per_turn: Number of parallel samples launched per turn.

        Raises:
            ValueError: If ``token_budget`` or ``sumple_num_per_turn`` is less
                than 1. A zero sample count would never generate tokens and
                so would loop forever; a non-positive budget would run no
                turns and leave nothing to vote on.
        """
        if token_budget < 1:
            raise ValueError(
                f"token_budget must be a positive integer, got {token_budget}")
        if sumple_num_per_turn < 1:
            raise ValueError("sumple_num_per_turn must be a positive integer, "
                             f"got {sumple_num_per_turn}")
        super().__init__()
        self.generation_controller = generation_controller
        self.token_budget = token_budget
        self.sumple_num_per_turn = sumple_num_per_turn

    def clone(self):
        """Return a new controller with a cloned generation controller and
        the same budget / per-turn sample count."""
        generation_controller = self.generation_controller.clone()
        return TokenBudgetMajorityVoteController(generation_controller,
                                                 self.token_budget,
                                                 self.sumple_num_per_turn)

    def process(self, tasks: List[Task], **kwargs):
        """Sample in parallel turns until the token budget is reached, then vote.

        Args:
            tasks: Tasks to solve. Every sample works on a deep copy; the
                voted answer is written to ``tasks[0].output_str``.
            **kwargs: Deep-copied into every sampled generation controller's
                kwargs and also passed to ``majority_vote``.

        Yields:
            One ``ParallelProcess`` per sampling turn.
        """
        sample_num = self.sumple_num_per_turn
        candidates = []
        # The GenerationTokenCounter task collection tracks how many tokens
        # were generated under this controller; its count is re-read after
        # every turn to decide whether to continue.
        while self.task_collections[
                "token_counter"].generation_token_count < self.token_budget:
            tokens_before_turn = self.task_collections[
                "token_counter"].generation_token_count
            generation_controllers = [
                self.generation_controller.clone() for _ in range(sample_num)
            ]
            tasks_list = [copy.deepcopy(tasks) for _ in range(sample_num)]
            generation_kwargs_list = [
                copy.deepcopy(kwargs) for _ in range(sample_num)
            ]
            yield ParallelProcess(generation_controllers, tasks_list,
                                  generation_kwargs_list)
            for task_list in tasks_list:
                candidates.extend(task.output_str for task in task_list)
            # A turn that adds no tokens would be repeated forever without
            # ever reaching the budget; stop sampling instead of hanging.
            if self.task_collections[
                    "token_counter"].generation_token_count <= tokens_before_turn:
                break
        result = self.majority_vote(candidates, **kwargs)
        print(
            'final token count: ',
            str(self.task_collections["token_counter"].generation_token_count))
        assert isinstance(result, str), "majority_vote failed"
        # majority_vote yields only a string, so tasks[0] gets output_str but
        # no matching output_tokens or logits.
        tasks[0].output_str = result

    def majority_vote(self, candidates: List[str], **kwargs) -> str:
        """Return the voted answer among ``candidates``.

        Delegates to ``get_digit_majority_vote_result``; ``kwargs`` are
        ignored by this implementation.
        """
        return get_digit_majority_vote_result(candidates)
def parse_arguments():
    """Parse the command-line options of this example.

    Returns:
        argparse.Namespace with ``model_dir``, ``token_budget`` and
        ``sumple_num_per_turn``.
    """

    def positive_int(value: str) -> int:
        """argparse ``type`` converter that accepts only integers >= 1."""
        number = int(value)
        if number < 1:
            raise argparse.ArgumentTypeError(
                f"expected a positive integer, got {value!r}")
        return number

    parser = argparse.ArgumentParser()
    # e.g. DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B
    parser.add_argument(
        '--model_dir',
        type=str,
        required=True,
        help="Path to the directory containing the generation model")
    parser.add_argument(
        '--token_budget',
        type=positive_int,
        default=30384,
        help="Keep sampling while fewer than this many generation tokens "
        "have been produced")
    # '--sumple_num_per_turn' is a historical misspelling kept for backward
    # compatibility; '--sample_num_per_turn' is accepted as an alias. Both
    # store into args.sumple_num_per_turn.
    parser.add_argument(
        '--sumple_num_per_turn',
        '--sample_num_per_turn',
        dest='sumple_num_per_turn',
        type=positive_int,
        default=3,
        help="Number of parallel samples generated per turn")
    args = parser.parse_args()
    return args
def main():
    """Run the token-budget majority-vote example on one math prompt.

    Builds a TRT-LLM worker and a ScaffoldingLlm around
    TokenBudgetMajorityVoteController, generates an answer, prints the
    extracted boxed answer and always shuts the LLM and its workers down.
    """
    args = parse_arguments()
    llm_worker = TRTLLMWorker.init_with_new_llm(args.model_dir,
                                                backend="pytorch",
                                                max_batch_size=32,
                                                max_num_tokens=4096,
                                                temperature=0.9)
    prototype_generation_controller = NativeGenerationController(
        custom_sampling_params={
            "max_tokens": 4096,
            "top_p": 0.9,
        })
    workers = {NativeGenerationController.WorkerTag.GENERATION: llm_worker}
    prototype_majority_vote_controller = TokenBudgetMajorityVoteController(
        generation_controller=prototype_generation_controller,
        token_budget=args.token_budget,
        sumple_num_per_turn=args.sumple_num_per_turn,
    )
    llm = ScaffoldingLlm(
        prototype_majority_vote_controller,
        workers=workers,
    )
    prompt = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\r\n\r\n"
    try:
        result = llm.generate(prompt)
        extracted_answer = extract_answer_from_boxed(result.output.output_str)
        print(f'extracted_answer={extracted_answer}')
    finally:
        # Release the LLM and its workers even if generation raises.
        llm.shutdown(shutdown_workers=True)
    print('main shut down done')
if __name__ == "__main__":
    # Script entry point: run the example only when executed directly,
    # not when this module is imported.
    main()