From 780a46d1b1b36c52033467cbbe2c0bfbb370cc23 Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Mon, 12 Jan 2026 18:00:21 -0800 Subject: [PATCH] Push drift_k_followups through to prompt --- .../graphrag/prompts/query/drift_search_system_prompt.py | 2 +- .../graphrag/query/structured_search/drift_search/action.py | 6 ++++-- .../query/structured_search/drift_search/drift_context.py | 4 ++-- .../graphrag/query/structured_search/drift_search/search.py | 6 +++--- .../graphrag/query/structured_search/local_search/search.py | 1 + 5 files changed, 11 insertions(+), 8 deletions(-) diff --git a/packages/graphrag/graphrag/prompts/query/drift_search_system_prompt.py b/packages/graphrag/graphrag/prompts/query/drift_search_system_prompt.py index 3faae89a..4cb220f6 100644 --- a/packages/graphrag/graphrag/prompts/query/drift_search_system_prompt.py +++ b/packages/graphrag/graphrag/prompts/query/drift_search_system_prompt.py @@ -65,7 +65,7 @@ Pay close attention specifically to the Sources tables as it contains the most r Add sections and commentary to the response as appropriate for the length and format. -Additionally provide a score between 0 and 100 representing how well the response addresses the overall research question: {global_query}. Based on your response, suggest up to five follow-up questions that could be asked to further explore the topic as it relates to the overall research question. Do not include scores or follow up questions in the 'response' field of the JSON, add them to the respective 'score' and 'follow_up_queries' keys of the JSON output. Format your response in JSON with the following keys and values: +Additionally provide a score between 0 and 100 representing how well the response addresses the overall research question: {global_query}. Based on your response, suggest up to {followups} follow-up questions that could be asked to further explore the topic as it relates to the overall research question. Do not include scores or follow up questions in the 'response' field of the JSON, add them to the respective 'score' and 'follow_up_queries' keys of the JSON output. Format your response in JSON with the following keys and values: {{'response': str, Put your answer, formatted in markdown, here. Do not answer the global query in this section. 'score': int, diff --git a/packages/graphrag/graphrag/query/structured_search/drift_search/action.py b/packages/graphrag/graphrag/query/structured_search/drift_search/action.py index 8f1da3d7..39ef11d1 100644 --- a/packages/graphrag/graphrag/query/structured_search/drift_search/action.py +++ b/packages/graphrag/graphrag/query/structured_search/drift_search/action.py @@ -50,7 +50,7 @@ class DriftAction: """Check if the action is complete (i.e., an answer is available).""" return self.answer is not None - async def search(self, search_engine: Any, global_query: str, scorer: Any = None): + async def search(self, search_engine: Any, global_query: str, k_followups: int, scorer: Any = None): """ Execute an asynchronous search using the search engine, and update the action with the results. @@ -71,7 +71,9 @@ class DriftAction: return self search_result = await search_engine.search( - drift_query=global_query, query=self.query + query=self.query, + drift_query=global_query, + k_followups=k_followups, ) # Do not launch exception as it will roll up with other steps diff --git a/packages/graphrag/graphrag/query/structured_search/drift_search/drift_context.py b/packages/graphrag/graphrag/query/structured_search/drift_search/drift_context.py index 4b4325ae..9e1e9c31 100644 --- a/packages/graphrag/graphrag/query/structured_search/drift_search/drift_context.py +++ b/packages/graphrag/graphrag/query/structured_search/drift_search/drift_context.py @@ -40,6 +40,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder): def __init__( self, model: ChatModel, + config: DRIFTSearchConfig, text_embedder: EmbeddingModel, entities: list[Entity], entity_text_embeddings: BaseVectorStore, @@ -49,14 +50,13 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder): covariates: dict[str, list[Covariate]] | None = None, tokenizer: Tokenizer | None = None, embedding_vectorstore_key: str = EntityVectorStoreKey.ID, - config: DRIFTSearchConfig | None = None, local_system_prompt: str | None = None, local_mixed_context: LocalSearchMixedContext | None = None, reduce_system_prompt: str | None = None, response_type: str | None = None, ): """Initialize the DRIFT search context builder with necessary components.""" - self.config = config or DRIFTSearchConfig() + self.config = config self.model = model self.text_embedder = text_embedder self.tokenizer = tokenizer or get_tokenizer() diff --git a/packages/graphrag/graphrag/query/structured_search/drift_search/search.py b/packages/graphrag/graphrag/query/structured_search/drift_search/search.py index 64a8e52b..150c3e39 100644 --- a/packages/graphrag/graphrag/query/structured_search/drift_search/search.py +++ b/packages/graphrag/graphrag/query/structured_search/drift_search/search.py @@ -156,7 +156,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]): raise ValueError(error_msg) async def _search_step( - self, global_query: str, search_engine: LocalSearch, actions: list[DriftAction] + self, global_query: str, k_followups: int, search_engine: LocalSearch, actions: list[DriftAction] ) -> list[DriftAction]: """ Perform an asynchronous search step by executing each DriftAction asynchronously. @@ -171,7 +171,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]): list[DriftAction]: The results from executing the search actions asynchronously. """ tasks = [ - action.search(search_engine=search_engine, global_query=global_query) + action.search(search_engine=search_engine, global_query=global_query, k_followups=k_followups) for action in actions ] return await tqdm_asyncio.gather(*tasks, leave=False) @@ -241,7 +241,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]): ) # Process actions results = await self._search_step( - global_query=query, search_engine=self.local_search, actions=actions + global_query=query, k_followups=self.context_builder.config.drift_k_followups, search_engine=self.local_search, actions=actions ) # Update query state diff --git a/packages/graphrag/graphrag/query/structured_search/local_search/search.py b/packages/graphrag/graphrag/query/structured_search/local_search/search.py index fdd72949..64fc8842 100644 --- a/packages/graphrag/graphrag/query/structured_search/local_search/search.py +++ b/packages/graphrag/graphrag/query/structured_search/local_search/search.py @@ -76,6 +76,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]): context_data=context_result.context_chunks, response_type=self.response_type, global_query=drift_query, + followups=kwargs.get("k_followups", 0), ) else: search_prompt = self.system_prompt.format(