Push drift_k_followups through to prompt

This commit is contained in:
Nathan Evans 2026-01-12 18:00:21 -08:00
parent a380a58f4b
commit 780a46d1b1
5 changed files with 11 additions and 8 deletions

View File

@ -65,7 +65,7 @@ Pay close attention specifically to the Sources tables as it contains the most r
Add sections and commentary to the response as appropriate for the length and format. Add sections and commentary to the response as appropriate for the length and format.
Additionally provide a score between 0 and 100 representing how well the response addresses the overall research question: {global_query}. Based on your response, suggest up to five follow-up questions that could be asked to further explore the topic as it relates to the overall research question. Do not include scores or follow up questions in the 'response' field of the JSON, add them to the respective 'score' and 'follow_up_queries' keys of the JSON output. Format your response in JSON with the following keys and values: Additionally provide a score between 0 and 100 representing how well the response addresses the overall research question: {global_query}. Based on your response, suggest up to {followups} follow-up questions that could be asked to further explore the topic as it relates to the overall research question. Do not include scores or follow up questions in the 'response' field of the JSON, add them to the respective 'score' and 'follow_up_queries' keys of the JSON output. Format your response in JSON with the following keys and values:
{{'response': str, Put your answer, formatted in markdown, here. Do not answer the global query in this section. {{'response': str, Put your answer, formatted in markdown, here. Do not answer the global query in this section.
'score': int, 'score': int,

View File

@ -50,7 +50,7 @@ class DriftAction:
"""Check if the action is complete (i.e., an answer is available).""" """Check if the action is complete (i.e., an answer is available)."""
return self.answer is not None return self.answer is not None
async def search(self, search_engine: Any, global_query: str, scorer: Any = None): async def search(self, search_engine: Any, global_query: str, k_followups: int, scorer: Any = None):
""" """
Execute an asynchronous search using the search engine, and update the action with the results. Execute an asynchronous search using the search engine, and update the action with the results.
@ -71,7 +71,9 @@ class DriftAction:
return self return self
search_result = await search_engine.search( search_result = await search_engine.search(
drift_query=global_query, query=self.query query=self.query,
drift_query=global_query,
k_followups=k_followups,
) )
# Do not launch exception as it will roll up with other steps # Do not launch exception as it will roll up with other steps

View File

@ -40,6 +40,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
def __init__( def __init__(
self, self,
model: ChatModel, model: ChatModel,
config: DRIFTSearchConfig,
text_embedder: EmbeddingModel, text_embedder: EmbeddingModel,
entities: list[Entity], entities: list[Entity],
entity_text_embeddings: BaseVectorStore, entity_text_embeddings: BaseVectorStore,
@ -49,14 +50,13 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
covariates: dict[str, list[Covariate]] | None = None, covariates: dict[str, list[Covariate]] | None = None,
tokenizer: Tokenizer | None = None, tokenizer: Tokenizer | None = None,
embedding_vectorstore_key: str = EntityVectorStoreKey.ID, embedding_vectorstore_key: str = EntityVectorStoreKey.ID,
config: DRIFTSearchConfig | None = None,
local_system_prompt: str | None = None, local_system_prompt: str | None = None,
local_mixed_context: LocalSearchMixedContext | None = None, local_mixed_context: LocalSearchMixedContext | None = None,
reduce_system_prompt: str | None = None, reduce_system_prompt: str | None = None,
response_type: str | None = None, response_type: str | None = None,
): ):
"""Initialize the DRIFT search context builder with necessary components.""" """Initialize the DRIFT search context builder with necessary components."""
self.config = config or DRIFTSearchConfig() self.config = config
self.model = model self.model = model
self.text_embedder = text_embedder self.text_embedder = text_embedder
self.tokenizer = tokenizer or get_tokenizer() self.tokenizer = tokenizer or get_tokenizer()

View File

@ -156,7 +156,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
raise ValueError(error_msg) raise ValueError(error_msg)
async def _search_step( async def _search_step(
self, global_query: str, search_engine: LocalSearch, actions: list[DriftAction] self, global_query: str, k_followups: int, search_engine: LocalSearch, actions: list[DriftAction]
) -> list[DriftAction]: ) -> list[DriftAction]:
""" """
Perform an asynchronous search step by executing each DriftAction asynchronously. Perform an asynchronous search step by executing each DriftAction asynchronously.
@ -171,7 +171,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
list[DriftAction]: The results from executing the search actions asynchronously. list[DriftAction]: The results from executing the search actions asynchronously.
""" """
tasks = [ tasks = [
action.search(search_engine=search_engine, global_query=global_query) action.search(search_engine=search_engine, global_query=global_query, k_followups=k_followups)
for action in actions for action in actions
] ]
return await tqdm_asyncio.gather(*tasks, leave=False) return await tqdm_asyncio.gather(*tasks, leave=False)
@ -241,7 +241,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
) )
# Process actions # Process actions
results = await self._search_step( results = await self._search_step(
global_query=query, search_engine=self.local_search, actions=actions global_query=query, k_followups=self.context_builder.config.drift_k_followups, search_engine=self.local_search, actions=actions
) )
# Update query state # Update query state

View File

@ -76,6 +76,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
context_data=context_result.context_chunks, context_data=context_result.context_chunks,
response_type=self.response_type, response_type=self.response_type,
global_query=drift_query, global_query=drift_query,
followups=kwargs.get("k_followups", 0),
) )
else: else:
search_prompt = self.system_prompt.format( search_prompt = self.system_prompt.format(