Push drift_k_followups through to prompt

This commit is contained in:
Nathan Evans 2026-01-12 18:00:21 -08:00
parent a380a58f4b
commit 780a46d1b1
5 changed files with 11 additions and 8 deletions

View File

@ -65,7 +65,7 @@ Pay close attention specifically to the Sources tables as it contains the most r
Add sections and commentary to the response as appropriate for the length and format.
Additionally provide a score between 0 and 100 representing how well the response addresses the overall research question: {global_query}. Based on your response, suggest up to five follow-up questions that could be asked to further explore the topic as it relates to the overall research question. Do not include scores or follow up questions in the 'response' field of the JSON, add them to the respective 'score' and 'follow_up_queries' keys of the JSON output. Format your response in JSON with the following keys and values:
Additionally provide a score between 0 and 100 representing how well the response addresses the overall research question: {global_query}. Based on your response, suggest up to {followups} follow-up questions that could be asked to further explore the topic as it relates to the overall research question. Do not include scores or follow up questions in the 'response' field of the JSON, add them to the respective 'score' and 'follow_up_queries' keys of the JSON output. Format your response in JSON with the following keys and values:
{{'response': str, Put your answer, formatted in markdown, here. Do not answer the global query in this section.
'score': int,

View File

@ -50,7 +50,7 @@ class DriftAction:
"""Check if the action is complete (i.e., an answer is available)."""
return self.answer is not None
async def search(self, search_engine: Any, global_query: str, scorer: Any = None):
async def search(self, search_engine: Any, global_query: str, k_followups: int, scorer: Any = None):
"""
Execute an asynchronous search using the search engine, and update the action with the results.
@ -71,7 +71,9 @@ class DriftAction:
return self
search_result = await search_engine.search(
drift_query=global_query, query=self.query
query=self.query,
drift_query=global_query,
k_followups=k_followups,
)
# Do not launch exception as it will roll up with other steps

View File

@ -40,6 +40,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
def __init__(
self,
model: ChatModel,
config: DRIFTSearchConfig,
text_embedder: EmbeddingModel,
entities: list[Entity],
entity_text_embeddings: BaseVectorStore,
@ -49,14 +50,13 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
covariates: dict[str, list[Covariate]] | None = None,
tokenizer: Tokenizer | None = None,
embedding_vectorstore_key: str = EntityVectorStoreKey.ID,
config: DRIFTSearchConfig | None = None,
local_system_prompt: str | None = None,
local_mixed_context: LocalSearchMixedContext | None = None,
reduce_system_prompt: str | None = None,
response_type: str | None = None,
):
"""Initialize the DRIFT search context builder with necessary components."""
self.config = config or DRIFTSearchConfig()
self.config = config
self.model = model
self.text_embedder = text_embedder
self.tokenizer = tokenizer or get_tokenizer()

View File

@ -156,7 +156,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
raise ValueError(error_msg)
async def _search_step(
self, global_query: str, search_engine: LocalSearch, actions: list[DriftAction]
self, global_query: str, k_followups: int, search_engine: LocalSearch, actions: list[DriftAction]
) -> list[DriftAction]:
"""
Perform an asynchronous search step by executing each DriftAction asynchronously.
@ -171,7 +171,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
list[DriftAction]: The results from executing the search actions asynchronously.
"""
tasks = [
action.search(search_engine=search_engine, global_query=global_query)
action.search(search_engine=search_engine, global_query=global_query, k_followups=k_followups)
for action in actions
]
return await tqdm_asyncio.gather(*tasks, leave=False)
@ -241,7 +241,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
)
# Process actions
results = await self._search_step(
global_query=query, search_engine=self.local_search, actions=actions
global_query=query, k_followups=self.context_builder.config.drift_k_followups, search_engine=self.local_search, actions=actions
)
# Update query state

View File

@ -76,6 +76,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
context_data=context_result.context_chunks,
response_type=self.response_type,
global_query=drift_query,
followups=kwargs.get("k_followups", 0),
)
else:
search_prompt = self.system_prompt.format(