# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
"""Tree-based inference controllers (MCTS and Tree of Thoughts) for scaffolding."""
import math
import random
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

from tensorrt_llm.executor.result import CompletionOutput, GenerationResult
from tensorrt_llm.scaffolding.controller import Controller, ParallelProcess
from tensorrt_llm.scaffolding.task import GenerationTask, Task

# "1. foo" / "2) bar" list items. The (?!\d) lookahead keeps decimals such as
# "1.5 is the answer" from being split into a bogus item "5 is the answer",
# and lines like "10 apples cost $2.50" are not list items at all.
_NUMBERED_THOUGHT_RE = re.compile(r'^\d+[.)](?!\d)\s*(.*)$')

# "Approach N: ..." / "Step N - ..." headers, optionally wrapped in Markdown
# heading or bold markers ("### Step 2 - ...", "**Approach 1:** ...").
_APPROACH_HEADER_RE = re.compile(
    r'^(?:[#>]+\s*)?(?:\*\*|__)?\s*(?:approach|step)\s*\d+\s*(?:\*\*|__)?\s*'
    r'[:\-\.](?:\*\*|__)?\s*(.*)$',
    flags=re.IGNORECASE)
# Numbered list header: '1. ...'
_NUMBERED_HEADER_RE = re.compile(r'^\d+\.\s*(.*)$')
# Bulleted list header: '- ...' or '* ...'
_BULLET_HEADER_RE = re.compile(r'^[\-\*]\s*(.*)$')
# First integer or decimal in a score value ('8', '8.5', '8/10').
_NUMBER_RE = re.compile(r'(\d+(?:\.\d+)?)')

# Field labels recognised in self-evaluation replies, in priority order.
_EVAL_LABELS = ('score', 'confidence', 'reason')


def _build_reward_task(prompt: str, candidate: str) -> GenerationTask:
    """Wrap ``candidate`` as an already-completed GenerationTask for scoring.

    The task's ``input_str`` is ``prompt`` and its result holds a single
    ``CompletionOutput`` whose text is ``candidate``. Callers pass the task to
    a reward controller's ``process`` and then read that controller's
    ``scores`` attribute.
    """
    # Kept as a function-local import, as in the original code (presumably to
    # avoid an import cycle at module load -- TODO confirm).
    from tensorrt_llm.sampling_params import SamplingParams

    reward_task = GenerationTask()
    reward_task.input_str = prompt
    completion_output = CompletionOutput(index=0, text=candidate)
    # __new__ bypasses GenerationResult.__init__: only the two attributes set
    # below are populated, which is all this module provides to the reward side.
    reward_result = GenerationResult.__new__(GenerationResult)
    reward_result._outputs = [completion_output]
    reward_result.sampling_params = SamplingParams()
    reward_task.result = reward_result
    return reward_task


def _append_stop_sequence(stop: Any, sequence: str) -> Any:
    """Return ``stop`` extended with ``sequence``, without mutating ``stop``.

    ``stop`` may be a list of strings, a single non-empty string, or anything
    else (treated as "no stop sequences yet"). An empty ``sequence`` is never
    added: an empty stop string is meaningless and may end generation at once.
    """
    if not sequence:
        return stop
    if isinstance(stop, list):
        return stop if sequence in stop else [*stop, sequence]
    if isinstance(stop, str) and stop:
        return [stop] if stop == sequence else [stop, sequence]
    return [sequence]


def _split_eval_label(line: str) -> Tuple[Optional[str], str]:
    """Identify which evaluation field, if any, ``line`` declares.

    Returns ``(label, value)`` with ``label`` one of ``_EVAL_LABELS`` or None.
    For ``"Label: value"`` lines only the text before the first colon is
    searched, so a reasoning sentence that merely mentions "score" is not
    mistaken for the score field. Lines without a colon must start with the
    label (after Markdown decoration). ``value`` is the text after the colon,
    or the whole line when there is no colon.
    """
    head, sep, tail = line.partition(':')
    if sep:
        label_text = head.lower()
        for label in _EVAL_LABELS:
            if label in label_text:
                return label, tail.strip()
        return None, line
    label_text = line.lstrip('#*_-> ').lower()
    for label in _EVAL_LABELS:
        if label_text.startswith(label):
            return label, line
    return None, line


@dataclass
class TreeNode:
    """Base class for tree nodes in tree-based inference methods.

    ``parent`` is excluded from ``repr`` and ``==``. It points back up the
    tree, so including it made ``repr`` print the whole tree from any node and
    made ``==`` between two structurally identical trees recurse
    parent -> children -> node -> parent ... until ``RecursionError``.
    """
    state: str = ""
    parent: Optional['TreeNode'] = field(default=None,
                                         repr=False,
                                         compare=False)
    children: List['TreeNode'] = field(default_factory=list)
    visits: int = 0
    value: float = 0.0
    is_terminal: bool = False
    depth: int = 0

    def is_leaf(self) -> bool:
        """Return True if the node has no children."""
        return len(self.children) == 0

    def is_root(self) -> bool:
        """Return True if the node has no parent."""
        return self.parent is None

    def add_child(self, child: 'TreeNode') -> 'TreeNode':
        """Attach ``child`` below this node, set its parent/depth, return it."""
        child.parent = self
        child.depth = self.depth + 1
        self.children.append(child)
        return child

    def get_path_to_root(self) -> List['TreeNode']:
        """Return the nodes from the root down to (and including) this node."""
        path = []
        current = self
        while current is not None:
            path.append(current)
            current = current.parent
        return list(reversed(path))


@dataclass
class MCTSNode(TreeNode):
    """Node for Monte Carlo Tree Search."""
    reward: float = 0.0
    untried_actions: List[str] = field(default_factory=list)

    def ucb1_score(self, exploration_constant: float = 1.414) -> float:
        """Calculate the UCB1 score for this node.

        Unvisited nodes, and nodes whose parent is missing or unvisited, score
        ``inf`` so they are always tried first.
        """
        if self.visits == 0:
            return float('inf')
        if self.parent is None or self.parent.visits == 0:
            return float('inf')
        exploitation = self.value / self.visits
        exploration = exploration_constant * math.sqrt(
            math.log(self.parent.visits) / self.visits)
        return exploitation + exploration

    def select_best_child(self,
                          exploration_constant: float = 1.414) -> 'MCTSNode':
        """Return the child with the highest UCB1 score (self if childless)."""
        if not self.children:
            return self
        return max(self.children,
                   key=lambda child: child.ucb1_score(exploration_constant))


@dataclass
class TOTNode(TreeNode):
    """Node for Tree of Thoughts."""
    thought: str = ""
    confidence: str = "Medium"
    reasoning: str = ""
    evaluation_score: float = 0.0

    def __post_init__(self):
        # Default the thought text to the state when no thought was given.
        if not self.thought and self.state:
            self.thought = self.state


class MCTSController(Controller):
    """Monte Carlo Tree Search Controller for scaffolding framework."""

    class WorkerTag(Enum):
        GENERATION = "generation"
        REWARD = "reward"

    def __init__(self,
                 generation_controller: Controller,
                 reward_controller: Optional[Controller] = None,
                 max_depth: int = 5,
                 max_iterations: int = 100,
                 exploration_constant: float = 1.414,
                 num_thoughts_per_step: int = 3,
                 expansion_parallel_samples: int = 1):
        super().__init__()
        self.generation_controller = generation_controller
        self.reward_controller = reward_controller
        self.max_depth = max_depth
        self.max_iterations = max_iterations
        self.exploration_constant = exploration_constant
        self.num_thoughts_per_step = num_thoughts_per_step
        # At least one expansion sample per iteration.
        self.expansion_parallel_samples = max(1, expansion_parallel_samples)

    def process(self, tasks: List[Task], **kwargs) -> Any:
        """Process a single task using MCTS with yield-based orchestration.

        Each of ``max_iterations`` rounds does the following:

        - **Selection:** descend by UCB1 to a leaf.
        - **Expansion:** sample ``expansion_parallel_samples`` expansion
          replies in parallel, merge and de-duplicate the parsed thoughts, and
          add up to ``num_thoughts_per_step`` children. Then pick a random
          child.
        - **Evaluation:**
          - With a reward controller, use its ``scores[0]`` (default 0.5).
          - Without one, use a word-count heuristic capped at 1.0.
        - **Backpropagation:** add the reward to every node on the path.

        Finally the best-average visited leaf's path is turned into a prompt
        for the final answer, whose result is stored on ``tasks[0]``. The
        ``goal`` keyword argument sets the problem goal.
        """
        assert len(
            tasks) == 1, "MCTS Controller only supports single task processing"
        task = tasks[0]
        goal = kwargs.get('goal', 'Solve the problem step by step')
        initial_state = getattr(task, 'input_str', str(task)) or ""
        root = MCTSNode(state=initial_state)

        for _iteration in range(self.max_iterations):
            # Selection
            node = root
            while not node.is_leaf() and not node.is_terminal:
                node = node.select_best_child(self.exploration_constant)

            # Expansion
            if not node.is_terminal and node.depth < self.max_depth:
                prompt = self._create_expansion_prompt(node, goal)
                # Parallelize expansion sampling if requested
                gen_controllers: List[Controller] = []
                gen_tasks_wrapped: List[List[GenerationTask]] = []
                gen_kwargs_list: List[Dict[str, Any]] = []
                for _sample in range(self.expansion_parallel_samples):
                    gen_task = GenerationTask.create_from_prompt(prompt)
                    gen_task.max_tokens = 200
                    gen_task.temperature = 0.7
                    gen_controllers.append(self.generation_controller.clone())
                    gen_tasks_wrapped.append([gen_task])
                    gen_kwargs_list.append({})
                if gen_controllers:
                    yield ParallelProcess(gen_controllers, gen_tasks_wrapped,
                                          gen_kwargs_list)

                # Collect and merge thoughts from all samples, keeping the
                # first occurrence of each distinct thought in order.
                merged_thoughts: List[str] = []
                seen = set()
                for [gen_task] in gen_tasks_wrapped:
                    for thought in self._parse_thoughts(gen_task.output_str
                                                        or ""):
                        if thought not in seen:
                            merged_thoughts.append(thought)
                            seen.add(thought)

                for thought in merged_thoughts[:self.num_thoughts_per_step]:
                    child_state = f"{node.state}\n{thought}".strip()
                    node.add_child(MCTSNode(state=child_state, reward=0.0))
                if node.children:
                    node = random.choice(node.children)

            # Evaluate node
            if self.reward_controller is not None:
                reward_task = _build_reward_task(initial_state, node.state)
                yield from self.reward_controller.process([reward_task])
                # Get reward from the reward controller
                scores = getattr(self.reward_controller, 'scores', None)
                reward = float(scores[0]) if scores else 0.5  # 0.5 = default
            else:
                reward = min(1.0, len(node.state.split()) / 100.0)

            # Backpropagation
            self._backpropagate(node, reward)

        # Pick best leaf
        best_leaf = self._select_best_leaf(root)
        path = best_leaf.get_path_to_root()

        # Final answer generation based on the best path
        steps_desc = []
        if path:
            steps_desc.append(f"Problem: {path[0].state}")
            for i, step_node in enumerate(path[1:], start=1):
                # Each child state's last line is the appended thought
                last_line = step_node.state.split('\n')[-1].strip()
                steps_desc.append(f"Step {i}: {last_line}")
        reasoning = "\n".join(steps_desc)
        final_prompt = (
            f"{goal}\n\nHere is a coherent reasoning trajectory.\n"
            f"{reasoning}\n\nNow provide the final answer succinctly.")
        final_task = GenerationTask.create_from_prompt(final_prompt)
        final_task.max_tokens = 256
        final_task.temperature = 0.2
        yield from self.generation_controller.process([final_task])

        # Assign the result to the original task
        tasks[0].result = final_task.result

    def _create_expansion_prompt(self, node: MCTSNode, goal: str) -> str:
        """Create a prompt for expanding a node."""
        return f"""Goal: {goal}

Current state:
{node.state}

Generate {self.num_thoughts_per_step} possible next steps or thoughts to progress toward the goal.
Each thought should be a coherent reasoning step or action.

Format your response as:
1. [First thought]
2. [Second thought]
3. [Third thought]
...
"""

    def _parse_thoughts(self, text: str) -> List[str]:
        """Parse generated thoughts from an expansion reply.

        Recognises numbered items ("1. ..." or "1) ...") and bullets
        ("- ..." / "* ..."). All other lines are ignored, and empty items
        are dropped.
        """
        thoughts: List[str] = []
        for raw_line in (text or "").strip().split('\n'):
            line = raw_line.strip()
            if not line:
                continue
            numbered = _NUMBERED_THOUGHT_RE.match(line)
            if numbered:
                thought = numbered.group(1).strip()
            elif line.startswith(('-', '*')):
                thought = line[1:].strip()
            else:
                continue
            if thought:
                thoughts.append(thought)
        return thoughts

    def _backpropagate(self, node: MCTSNode, reward: float):
        """Add one visit and ``reward`` to ``node`` and all its ancestors."""
        current = node
        while current is not None:
            current.visits += 1
            current.value += reward
            current = current.parent

    def _select_best_leaf(self, root: MCTSNode) -> MCTSNode:
        """Return the visited leaf with the highest average value (root if none)."""
        best_leaf = root
        best_score = -float('inf')

        def traverse(node: MCTSNode):
            nonlocal best_leaf, best_score
            if node.is_leaf() and node.visits > 0:
                avg_value = node.value / node.visits
                if avg_value > best_score:
                    best_score = avg_value
                    best_leaf = node
            for child in node.children:
                traverse(child)

        traverse(root)
        return best_leaf


class TOTController(Controller):
    """Tree of Thoughts Controller for scaffolding framework."""

    # Stop sequence added to the final-answer request so models that emit
    # R1-style <think> blocks stop at the closing tag. Set to "" to disable.
    final_answer_stop_tag: str = "</think>"

    class WorkerTag(Enum):
        GENERATION = "generation"
        REWARD = "reward"

    def __init__(self,
                 generation_controller: Controller,
                 reward_controller: Optional[Controller] = None,
                 max_depth: int = 4,
                 max_iterations: int = 50,
                 num_thoughts_per_step: int = 3,
                 selection_strategy: str = "best",
                 branch_factor: int = 2):
        super().__init__()
        self.generation_controller = generation_controller
        self.reward_controller = reward_controller
        self.max_depth = max_depth
        self.max_iterations = max_iterations
        self.num_thoughts_per_step = num_thoughts_per_step
        self.selection_strategy = selection_strategy  # "best", "vote", "random"
        self.branch_factor = max(1, branch_factor)

    def process(self, tasks: List[Task], **kwargs) -> Any:
        """Process a single task using Tree of Thoughts (level by level).

        For each depth up to ``max_depth``:

        1. Generate candidate approaches for every node of the current level
           in parallel. The total number of generation requests is capped by
           ``max_iterations``.
        2. Score up to ``num_thoughts_per_step`` candidates per node:
           - With a reward controller, one batched request is made; scores in
             [0, 1] are rescaled to [0, 10].
           - Otherwise each candidate is scored by an LLM self-evaluation.
        3. Keep ``branch_factor`` candidates per node according to
           ``selection_strategy`` and add them as children.

        The best leaf's path then prompts the final answer, whose result is
        stored on ``tasks[0]``.
        """
        assert len(
            tasks) == 1, "TOT Controller only supports single task processing"
        task = tasks[0]
        goal = kwargs.get('goal', 'Solve the problem step by step')
        root_state = getattr(task, 'input_str', str(task)) or ""

        root = TOTNode(state=root_state, thought="Initial problem")
        current_level: List[TOTNode] = [root]
        iterations = 0
        stop = False

        for _depth in range(self.max_depth):
            if stop:
                break
            next_level: List[TOTNode] = []

            # 1) Parallel generation for all nodes in the current level
            gen_controllers: List[Controller] = []
            gen_tasks_wrapped: List[List[GenerationTask]] = []
            gen_kwargs_list: List[Dict[str, Any]] = []
            node_order: List[TOTNode] = []
            for node in current_level:
                if stop:
                    break
                if node.is_terminal:
                    continue
                gen_prompt = self._generate_prompt(node, goal)
                gen_task = GenerationTask.create_from_prompt(gen_prompt)
                gen_task.max_tokens = 512
                gen_task.temperature = 0.8
                gen_controllers.append(self.generation_controller.clone())
                gen_tasks_wrapped.append([gen_task])
                gen_kwargs_list.append({})
                node_order.append(node)
                iterations += 1
                if iterations >= self.max_iterations:
                    stop = True
                    break
            if gen_controllers:
                yield ParallelProcess(gen_controllers, gen_tasks_wrapped,
                                      gen_kwargs_list)

            # 2) Parse thoughts per node, then score them. Keys of both dicts
            #    are indices into node_order.
            evaluated_by_node: Dict[int, List[Dict[str, Any]]] = {}
            # Single batched reward request across all nodes (leverages
            # worker concurrency); indices point into all_reward_tasks.
            all_reward_tasks: List[GenerationTask] = []
            node_to_task_indices: Dict[int, List[int]] = {}
            for idx, (node, [gen_task]) in enumerate(
                    zip(node_order, gen_tasks_wrapped)):
                thoughts = self._parse_approaches(gen_task.output_str or "")
                candidates = thoughts[:self.num_thoughts_per_step]
                if not candidates:
                    evaluated_by_node[idx] = []
                    continue

                if self.reward_controller is not None:
                    reward_indices_for_node: List[int] = []
                    for thought in candidates:
                        candidate_content = self._combine_state_and_thought(
                            node.state, thought)
                        reward_indices_for_node.append(len(all_reward_tasks))
                        all_reward_tasks.append(
                            _build_reward_task(root_state, candidate_content))
                    node_to_task_indices[idx] = reward_indices_for_node
                    # Placeholders; filled in after the batched reward call.
                    evaluated_by_node[idx] = [{
                        'thought': t,
                        'score': 0.0,
                        'confidence': 'Medium',
                        'reasoning': 'PRM score'
                    } for t in candidates]
                else:
                    # Fallback: sequential lightweight LLM self-eval
                    evaluated_thoughts: List[Dict[str, Any]] = []
                    for thought in candidates:
                        eval_prompt = self._evaluation_prompt(
                            thought, goal, node.state)
                        eval_task = GenerationTask.create_from_prompt(
                            eval_prompt)
                        eval_task.max_tokens = 256
                        eval_task.temperature = 0.3
                        yield from self.generation_controller.process(
                            [eval_task])
                        evaluation = self._parse_evaluation(
                            eval_task.output_str or "")
                        evaluated_thoughts.append({
                            'thought': thought,
                            'score': evaluation['score'],
                            'confidence': evaluation['confidence'],
                            'reasoning': evaluation['reasoning']
                        })
                    evaluated_by_node[idx] = evaluated_thoughts

            # Run all reward evaluations in a single batch
            if self.reward_controller is not None and all_reward_tasks:
                yield from self.reward_controller.process(all_reward_tasks)
                scores = getattr(self.reward_controller, 'scores', None) or []
                for node_idx, indices in node_to_task_indices.items():
                    thoughts_for_node = evaluated_by_node[node_idx]
                    for local_j, task_index in enumerate(indices):
                        if task_index >= len(scores):
                            continue  # keep the placeholder score
                        normalized_score = float(scores[task_index])
                        # Rescale [0, 1] scores onto the 0-10 self-eval scale.
                        if 0.0 <= normalized_score <= 1.0:
                            normalized_score *= 10.0
                        thoughts_for_node[local_j]['score'] = normalized_score
                        thoughts_for_node[local_j]['confidence'] = (
                            'High' if normalized_score >= 8.0 else
                            'Medium' if normalized_score >= 5.0 else 'Low')

            # 3) Selection and child creation
            for idx, node in enumerate(node_order):
                evaluated_thoughts = evaluated_by_node.get(idx, [])
                if not evaluated_thoughts:
                    continue
                selected_thoughts = self._select_thoughts(evaluated_thoughts)
                for thought_data in selected_thoughts:
                    child_state = self._combine_state_and_thought(
                        node.state, thought_data['thought'])
                    child = TOTNode(state=child_state,
                                    thought=thought_data['thought'],
                                    confidence=thought_data['confidence'],
                                    reasoning=thought_data['reasoning'],
                                    evaluation_score=thought_data['score'])
                    node.add_child(child)
                    next_level.append(child)

            if stop or not next_level:
                break
            current_level = next_level

        # Choose best leaf solution
        best_node = self._select_best_solution(root)
        path = best_node.get_path_to_root()
        steps_desc = []
        for i, n in enumerate(path):
            if i == 0:
                steps_desc.append(f"Problem: {n.state}")
            else:
                steps_desc.append(f"Step {i}: {n.thought}")
        reasoning = "\n".join(steps_desc)

        # Generate final solution based on selected thoughts
        final_prompt = (
            f"{goal}\n\nYou have the following proposed steps. Use them to produce the final answer.\n"
            f"{reasoning}\n\nProvide the final answer succinctly.")
        final_task = GenerationTask.create_from_prompt(final_prompt)
        final_task.max_tokens = 1024
        final_task.temperature = 0.2
        # If the model uses R1-style <think> blocks, stop at the closing tag to
        # avoid extra content. An empty tag disables this (never add "").
        if self.final_answer_stop_tag:
            final_task.stop = _append_stop_sequence(
                getattr(final_task, 'stop', None), self.final_answer_stop_tag)
        yield from self.generation_controller.process([final_task])
        tasks[0].result = final_task.result

    def _generate_prompt(self, node: TOTNode, goal: str) -> str:
        """Build the prompt asking for ``num_thoughts_per_step`` approaches."""
        return f"""Goal: {goal}

Current progress:
{node.state}

Generate {self.num_thoughts_per_step} different approaches or next steps to progress toward the goal.
Each approach should be distinct and well-reasoned.

Format your response as:
Approach 1: [detailed approach]
Approach 2: [detailed approach]
Approach 3: [detailed approach]"""

    def _parse_approaches(self, text: str) -> List[str]:
        """Split a generation reply into individual approaches.

        A new item starts at any of the following:

        - an 'Approach N:' / 'Step N:' header, optionally wrapped in Markdown
          heading/bold markers;
        - a 'N. ...' line;
        - a '- ...' or '* ...' line.

        Non-header lines continue the current item. If nothing is parsed, the
        reply falls back to blank-line paragraphs, then to the whole text.
        """
        approaches: List[str] = []
        current: List[str] = []

        def is_new_item(line: str) -> Optional[str]:
            """Return the header's content if ``line`` starts a new item, else None."""
            for pattern in (_APPROACH_HEADER_RE, _NUMBERED_HEADER_RE,
                            _BULLET_HEADER_RE):
                match = pattern.match(line)
                if match:
                    return match.group(1).strip()
            return None

        def flush_current():
            nonlocal current
            if current:
                content = ' '.join(current).strip()
                if content:
                    approaches.append(content)
            current = []

        for raw_line in (text or "").strip().split('\n'):
            line = raw_line.strip()
            if not line:
                continue
            header_content = is_new_item(line)
            if header_content is None:
                # Continuation of the current item
                current.append(line)
            else:
                # Start of a new item (a bare header starts an empty item)
                flush_current()
                current = [header_content] if header_content else []
        flush_current()

        if approaches:
            return approaches
        # Fallbacks if nothing was parsed: blank-line paragraphs, then the
        # whole text.
        paragraphs = [
            p.strip() for p in re.split(r'\n\s*\n', text or '') if p.strip()
        ]
        if paragraphs:
            return paragraphs
        whole = (text or '').strip()
        return [whole] if whole else []

    def _evaluation_prompt(self, thought: str, goal: str,
                           current_state: str) -> str:
        """Build the self-evaluation prompt for one proposed step."""
        return f"""Goal: {goal}

Current state:
{current_state}

Proposed next step:
{thought}

Evaluate this proposed step on a scale of 1-10 considering:
1. How well it progresses toward the goal
2. How feasible it is to execute
3. How likely it is to lead to a successful solution

Provide your evaluation in this format:
Score: [1-10]
Confidence: [High/Medium/Low]
Reasoning: [brief explanation]"""

    def _parse_evaluation(self, text: str) -> Dict[str, Any]:
        """Parse a self-evaluation reply into score, confidence and reasoning.

        Field labels are only recognised in a line's label position (see
        ``_split_eval_label``), so reasoning text that mentions "score" or
        "confidence" no longer overwrites those fields. Returns a dict with:

        - ``'score'``: float clamped to [0, 10], default 5.0;
        - ``'confidence'``: 'High', 'Medium' or 'Low';
        - ``'reasoning'``: str.
        """
        score = 5.0
        confidence = 'Medium'
        reasoning_lines: List[str] = []
        reasoning_started = False

        for raw_line in (text or "").strip().split('\n'):
            line = raw_line.strip()
            if not line:
                continue
            label, value = _split_eval_label(line)
            if label == 'score':
                # First number of the value (handles '8', '8.5', '8/10').
                nums = _NUMBER_RE.findall(value)
                if nums:
                    parsed = float(nums[0])
                    # Normalize if it's like '0.8/10', then clamp
                    if parsed <= 1.0 and '/10' in line.lower():
                        parsed *= 10.0
                    score = max(0.0, min(10.0, parsed))
            elif label == 'confidence':
                # Normalize to High/Medium/Low if possible
                v = value.lower()
                if 'high' in v or 'strong' in v:
                    confidence = 'High'
                elif 'low' in v or 'weak' in v:
                    confidence = 'Low'
                else:
                    confidence = 'Medium'
            elif label == 'reason':
                reasoning_started = True
                if value:
                    reasoning_lines.append(value)
            elif reasoning_started:
                reasoning_lines.append(line)

        reasoning = ' '.join(reasoning_lines).strip() or 'No reasoning provided'
        return {
            'score': score,
            'confidence': confidence,
            'reasoning': reasoning
        }

    def _select_thoughts(
            self,
            evaluated_thoughts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Pick up to ``branch_factor`` thoughts per ``selection_strategy``.

        Strategies:

        - "best": highest score first.
        - "vote": High > Medium > Low confidence, then score.
        - anything else: random sample.
        """
        if not evaluated_thoughts:
            return []
        limit = min(self.branch_factor, len(evaluated_thoughts))
        if self.selection_strategy == "best":
            return sorted(evaluated_thoughts,
                          key=lambda x: x['score'],
                          reverse=True)[:limit]
        if self.selection_strategy == "vote":
            # Confidence-aware selection: prioritize High > Medium > Low
            # confidence, then score
            def confidence_weight(conf: str) -> int:
                c = (conf or '').lower()
                if 'high' in c:
                    return 2
                if 'low' in c:
                    return 0
                return 1

            return sorted(
                evaluated_thoughts,
                key=lambda x: (confidence_weight(x.get('confidence', 'Medium')),
                               x.get('score', 0.0)),
                reverse=True)[:limit]
        return random.sample(evaluated_thoughts, limit)

    def _combine_state_and_thought(self, current_state: str,
                                   thought: str) -> str:
        """Combine current state with new thought."""
        if not current_state.strip():
            return thought
        return f"{current_state}\n\nNext step: {thought}"

    def _select_best_solution(self, root: TOTNode) -> TOTNode:
        """Return the leaf with the highest evaluation score (first on ties)."""
        best_node = root
        best_score = -float('inf')

        def traverse(node: TOTNode):
            nonlocal best_node, best_score
            if node.is_leaf():
                # Use the evaluation score if available, otherwise depth as a
                # heuristic (TOTNode always defines evaluation_score).
                score = node.evaluation_score if hasattr(
                    node, 'evaluation_score') else node.depth
                if score > best_score:
                    best_score = score
                    best_node = node
            for child in node.children:
                traverse(child)

        traverse(root)
        return best_node