mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-13 16:47:20 +08:00
Push drift_k_followups through to prompt
This commit is contained in:
parent
a380a58f4b
commit
780a46d1b1
@ -65,7 +65,7 @@ Pay close attention specifically to the Sources tables as it contains the most r
|
||||
|
||||
Add sections and commentary to the response as appropriate for the length and format.
|
||||
|
||||
Additionally provide a score between 0 and 100 representing how well the response addresses the overall research question: {global_query}. Based on your response, suggest up to five follow-up questions that could be asked to further explore the topic as it relates to the overall research question. Do not include scores or follow up questions in the 'response' field of the JSON, add them to the respective 'score' and 'follow_up_queries' keys of the JSON output. Format your response in JSON with the following keys and values:
|
||||
Additionally provide a score between 0 and 100 representing how well the response addresses the overall research question: {global_query}. Based on your response, suggest up to {followups} follow-up questions that could be asked to further explore the topic as it relates to the overall research question. Do not include scores or follow up questions in the 'response' field of the JSON, add them to the respective 'score' and 'follow_up_queries' keys of the JSON output. Format your response in JSON with the following keys and values:
|
||||
|
||||
{{'response': str, Put your answer, formatted in markdown, here. Do not answer the global query in this section.
|
||||
'score': int,
|
||||
|
||||
@ -50,7 +50,7 @@ class DriftAction:
|
||||
"""Check if the action is complete (i.e., an answer is available)."""
|
||||
return self.answer is not None
|
||||
|
||||
async def search(self, search_engine: Any, global_query: str, scorer: Any = None):
|
||||
async def search(self, search_engine: Any, global_query: str, k_followups: int, scorer: Any = None):
|
||||
"""
|
||||
Execute an asynchronous search using the search engine, and update the action with the results.
|
||||
|
||||
@ -71,7 +71,9 @@ class DriftAction:
|
||||
return self
|
||||
|
||||
search_result = await search_engine.search(
|
||||
drift_query=global_query, query=self.query
|
||||
query=self.query,
|
||||
drift_query=global_query,
|
||||
k_followups=k_followups,
|
||||
)
|
||||
|
||||
# Do not launch exception as it will roll up with other steps
|
||||
|
||||
@ -40,6 +40,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
def __init__(
|
||||
self,
|
||||
model: ChatModel,
|
||||
config: DRIFTSearchConfig,
|
||||
text_embedder: EmbeddingModel,
|
||||
entities: list[Entity],
|
||||
entity_text_embeddings: BaseVectorStore,
|
||||
@ -49,14 +50,13 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
covariates: dict[str, list[Covariate]] | None = None,
|
||||
tokenizer: Tokenizer | None = None,
|
||||
embedding_vectorstore_key: str = EntityVectorStoreKey.ID,
|
||||
config: DRIFTSearchConfig | None = None,
|
||||
local_system_prompt: str | None = None,
|
||||
local_mixed_context: LocalSearchMixedContext | None = None,
|
||||
reduce_system_prompt: str | None = None,
|
||||
response_type: str | None = None,
|
||||
):
|
||||
"""Initialize the DRIFT search context builder with necessary components."""
|
||||
self.config = config or DRIFTSearchConfig()
|
||||
self.config = config
|
||||
self.model = model
|
||||
self.text_embedder = text_embedder
|
||||
self.tokenizer = tokenizer or get_tokenizer()
|
||||
|
||||
@ -156,7 +156,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
raise ValueError(error_msg)
|
||||
|
||||
async def _search_step(
|
||||
self, global_query: str, search_engine: LocalSearch, actions: list[DriftAction]
|
||||
self, global_query: str, k_followups: int, search_engine: LocalSearch, actions: list[DriftAction]
|
||||
) -> list[DriftAction]:
|
||||
"""
|
||||
Perform an asynchronous search step by executing each DriftAction asynchronously.
|
||||
@ -171,7 +171,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
list[DriftAction]: The results from executing the search actions asynchronously.
|
||||
"""
|
||||
tasks = [
|
||||
action.search(search_engine=search_engine, global_query=global_query)
|
||||
action.search(search_engine=search_engine, global_query=global_query, k_followups=k_followups)
|
||||
for action in actions
|
||||
]
|
||||
return await tqdm_asyncio.gather(*tasks, leave=False)
|
||||
@ -241,7 +241,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
)
|
||||
# Process actions
|
||||
results = await self._search_step(
|
||||
global_query=query, search_engine=self.local_search, actions=actions
|
||||
global_query=query, k_followups=self.context_builder.config.drift_k_followups, search_engine=self.local_search, actions=actions
|
||||
)
|
||||
|
||||
# Update query state
|
||||
|
||||
@ -76,6 +76,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
|
||||
context_data=context_result.context_chunks,
|
||||
response_type=self.response_type,
|
||||
global_query=drift_query,
|
||||
followups=kwargs.get("k_followups", 0),
|
||||
)
|
||||
else:
|
||||
search_prompt = self.system_prompt.format(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user