# Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
# (synced 2026-01-14 06:27:45 +08:00; 116 lines, 4.2 KiB, Python)
import argparse
import copy
from typing import List

from tensorrt_llm.scaffolding import (Controller, GenerationTokenCounter,
                                      NativeGenerationController,
                                      ParallelProcess, ScaffoldingLlm, Task,
                                      TRTLLMWorker, extract_answer_from_boxed,
                                      get_digit_majority_vote_result,
                                      with_task_collection)


@with_task_collection("token_counter", GenerationTokenCounter)
class TokenBudgetMajorityVoteController(Controller):
    """Majority-vote controller that keeps sampling until a token budget is spent.

    Each round fans out ``sumple_num_per_turn`` parallel generations, each on a
    clone of ``generation_controller``.  Rounds repeat while the total number of
    generated tokens — tracked by the attached ``GenerationTokenCounter`` task
    collection — is still below ``token_budget``.  The accumulated candidate
    outputs are then reduced to a single answer by digit majority vote.
    """

    # NOTE(review): "sumple" looks like a typo for "sample", but the name is
    # part of the public constructor (and mirrored by the CLI flag), so it is
    # kept unchanged for backward compatibility.
    def __init__(self, generation_controller: Controller, token_budget: int,
                 sumple_num_per_turn: int):
        super().__init__()
        self.generation_controller = generation_controller
        self.token_budget = token_budget
        self.sumple_num_per_turn = sumple_num_per_turn

    def clone(self) -> "TokenBudgetMajorityVoteController":
        """Return a copy of this controller with a cloned generation controller."""
        return TokenBudgetMajorityVoteController(
            self.generation_controller.clone(), self.token_budget,
            self.sumple_num_per_turn)

    def process(self, tasks: List[Task], **kwargs):
        """Sample in parallel rounds until the token budget is exhausted, then
        write the majority-vote answer into ``tasks[0].output_str``.

        This is a generator driven by the scaffolding runtime: it yields
        ``ParallelProcess`` work items and resumes once they complete.
        """
        candidates: List[str] = []
        # Hoist the collection lookup; GenerationTokenCounter accumulates the
        # token count across all work yielded from this controller.
        token_counter = self.task_collections["token_counter"]
        while token_counter.generation_token_count < self.token_budget:
            sample_num = self.sumple_num_per_turn
            generation_controllers = [
                self.generation_controller.clone() for _ in range(sample_num)
            ]
            # Each parallel branch gets its own deep copy of tasks/kwargs so
            # the branches cannot observe each other's mutations.
            tasks_list = [copy.deepcopy(tasks) for _ in range(sample_num)]
            generation_kwargs_list = [
                copy.deepcopy(kwargs) for _ in range(sample_num)
            ]

            yield ParallelProcess(generation_controllers, tasks_list,
                                  generation_kwargs_list)

            for task_list in tasks_list:
                candidates.extend(task.output_str for task in task_list)

        result = self.majority_vote(candidates, **kwargs)
        print('final token count: ',
              str(token_counter.generation_token_count))

        assert isinstance(result, str), "majority_vote failed"
        # The task returned by majority vote does not have output_tokens and logits.
        tasks[0].output_str = result

    def majority_vote(self, candidates: List[str], **kwargs) -> str:
        """Reduce the candidate answer strings to one by digit majority vote."""
        return get_digit_majority_vote_result(candidates)


def parse_arguments():
    """Parse the example's command-line options.

    Returns:
        argparse.Namespace with ``model_dir`` (required), ``token_budget``
        and ``sumple_num_per_turn``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_dir',
        type=str,
        required=True,
        # e.g. DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B
        help="Path to the directory containing the generation model")
    parser.add_argument('--token_budget', type=int, default=30384)
    parser.add_argument('--sumple_num_per_turn', type=int, default=3)
    return parser.parse_args()


def main():
    """Run the token-budget majority-vote scaffolding example end to end.

    Builds a TRT-LLM worker, wires it to a NativeGenerationController wrapped
    in a TokenBudgetMajorityVoteController, generates an answer for a sample
    GSM8K-style question, prints the boxed answer, and shuts everything down.
    """
    args = parse_arguments()
    workers = {}

    llm_worker = TRTLLMWorker.init_with_new_llm(
        args.model_dir,
        max_batch_size=32,
        max_num_tokens=4096,
    )

    prototype_generation_controller = NativeGenerationController(
        sampling_params={
            "max_tokens": 4096,
            "top_p": 0.9,
            "temperature": 0.9,
        })
    workers[NativeGenerationController.WorkerTag.GENERATION] = llm_worker

    prototype_majority_vote_controller = TokenBudgetMajorityVoteController(
        generation_controller=prototype_generation_controller,
        token_budget=args.token_budget,
        sumple_num_per_turn=args.sumple_num_per_turn,
    )

    llm = ScaffoldingLlm(
        prototype_majority_vote_controller,
        workers=workers,
    )
    # GSM8K-style sample question; the model is expected to answer in \boxed{}.
    prompt = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\r\n\r\n"

    result = llm.generate(prompt)
    extracted_answer = extract_answer_from_boxed(result.outputs[0].text)
    print(f'extracted_answer={extracted_answer}')

    llm.shutdown(shutdown_workers=True)
    # Fixed: was an f-string with no placeholders (ruff F541).
    print('main shut down done')


if __name__ == '__main__':
    main()