mirror of https://github.com/microsoft/graphrag.git
drift search
This commit is contained in:
parent
fc9895f793
commit
05473aca45
101
graphrag/query/structured_search/drift_search/action.py
Normal file
@ -0,0 +1,101 @@
import json
import logging
from typing import Any, Dict, List, Optional

log = logging.getLogger(__name__)


class DriftAction:
    """
    Represents an action containing a query, an answer, a score, and follow-up actions.

    Encapsulates the action strings produced by the LLM in a clean, structured way.
    """

    def __init__(
        self,
        query: str,
        answer: str | None = None,
        follow_ups: List['DriftAction'] | List[str] | None = None,
    ):
        self.query = query
        self.answer: str | None = answer  # corresponds to an 'intermediate_answer'
        self.score: Optional[float] = None
        # Avoid a mutable default argument, which would be shared across instances.
        self.follow_ups: List['DriftAction'] | List[str] = (
            follow_ups if follow_ups is not None else []
        )
        self.metadata: Dict[str, Any] = {}
        # Should contain metadata explaining how to execute the action.
        # This will not always be local search in the future.

    @property
    def is_complete(self) -> bool:
        return self.answer is not None

    def search(self, search_engine: Any, scorer: Any = None):
        raise NotImplementedError(
            "Search method not implemented for DriftAction. Use asearch instead."
        )

    async def asearch(self, search_engine: Any, global_query: str, scorer: Any = None):
        # TODO: test that the graph stores actions by reference... This SHOULD update the graph object.
        search_result = await search_engine.asearch(
            drift_query=global_query, query=self.query
        )
        try:
            response = json.loads(search_result.response)
        except json.JSONDecodeError:
            raise ValueError(
                f"Failed to parse response: {search_result.response}. Ensure it is JSON serializable."
            )
        self.answer = response.pop('response', None)
        self.score = response.pop('score', float('-inf'))
        self.metadata.update({'context_data': search_result.context_data})

        self.follow_ups = response.pop('follow_up_queries', [])
        if not self.follow_ups:
            log.warning("No follow-up actions found for response: %s", response)

        if scorer:
            self.compute_score(scorer)
        return self

    def compute_score(self, scorer: Any):
        score = scorer.compute_score(self.query, self.answer)
        self.score = score if score is not None else float('-inf')  # -inf makes unscored actions sort last

    def serialize(self, include_follow_ups: bool = True) -> Dict[str, Any]:
        """
        Serializes the action to a dictionary.
        """
        data = {
            'query': self.query,
            'answer': self.answer,
            'score': self.score,
            'metadata': self.metadata,
        }
        if include_follow_ups:
            data['follow_ups'] = [action.serialize() for action in self.follow_ups]  # TODO: handle leftover string follow-ups
        return data

    @classmethod
    def deserialize(cls, data: Dict[str, Any]) -> 'DriftAction':
        """
        Deserializes the action from a dictionary.
        """
        action = cls(data['query'])
        action.answer = data.get('answer')
        action.score = data.get('score')
        action.metadata = data.get('metadata', {})
        # Read the same 'follow_ups' key that serialize() writes.
        action.follow_ups = [
            cls.deserialize(fu_data) for fu_data in data.get('follow_ups', [])
        ]
        return action

    @classmethod
    def from_primer_response(
        cls, query: str, response: str | Dict[str, Any] | List[Dict[str, Any]]
    ) -> 'DriftAction':
        """
        Creates a DriftAction from the DRIFTPrimer response.
        """
        if isinstance(response, dict):
            action = cls(
                query,
                answer=response.get('intermediate_answer'),
                follow_ups=response.get('follow_up_queries', []),
            )
            action.score = response.get('score')
            return action
        raise ValueError(
            f'Response must be a dictionary. Found: {type(response)}'
            f' with content: {response}'
        )

    def __hash__(self):
        # Necessary for storing actions in a networkx.MultiDiGraph. Assumes unique queries.
        return hash(self.query)

    def __eq__(self, other):
        return isinstance(other, DriftAction) and self.query == other.query
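For orientation, a minimal round-trip sketch of the serialization API above. The queries are made up, and the import assumes the graphrag package is installed; this is illustrative only, not part of the commit.

from graphrag.query.structured_search.drift_search.action import DriftAction

root = DriftAction(query="Who founded Company Y?")
root.answer = "Company Y was founded by Person X."
root.score = 82.0
root.follow_ups = [DriftAction(query="When was Company Y founded?")]

data = root.serialize()                # nested dict, follow-ups included
clone = DriftAction.deserialize(data)  # rebuilds the action tree

assert clone == root                   # __eq__ compares queries only
assert clone.follow_ups[0].query == "When was Company Y founded?"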
157
graphrag/query/structured_search/drift_search/drift_context.py
Normal file
@ -0,0 +1,157 @@
import logging
from typing import Any, List

import numpy as np
import pandas as pd
import tiktoken

from graphrag.config.models.drift_config import DRIFTSearchConfig
from graphrag.model import (
    CommunityReport,
    Covariate,
    Entity,
    Relationship,
    TextUnit,
)
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.structured_search.base import DRIFTContextBuilder
from graphrag.query.structured_search.drift_search.primer import (
    PrimerQueryProcessor,
)
from graphrag.query.structured_search.drift_search.system_prompt import (
    DRIFT_LOCAL_SYSTEM_PROMPT,
)
from graphrag.query.structured_search.local_search.mixed_context import (
    LocalSearchMixedContext,
)
from graphrag.vector_stores import BaseVectorStore

log = logging.getLogger(__name__)


class DRIFTSearchContextBuilder(DRIFTContextBuilder):
    def __init__(
        self,
        chat_llm: ChatOpenAI,
        text_embedder: BaseTextEmbedding,
        entities: List[Entity],
        entity_text_embeddings: BaseVectorStore,
        text_units: List[TextUnit] | None = None,
        reports: List[CommunityReport] | None = None,
        relationships: List[Relationship] | None = None,
        covariates: dict[str, List[Covariate]] | None = None,
        token_encoder: tiktoken.Encoding | None = None,
        embedding_vectorstore_key: str = EntityVectorStoreKey.ID,
        config: DRIFTSearchConfig | None = None,
        local_system_prompt: str = DRIFT_LOCAL_SYSTEM_PROMPT,
        local_mixed_context: LocalSearchMixedContext | None = None,
    ):
        self.config = config or DRIFTSearchConfig()

        self.chat_llm = chat_llm
        self.text_embedder = text_embedder
        self.token_encoder = token_encoder
        self.local_system_prompt = local_system_prompt

        self.drift_primer_context = None

        self.entities = entities
        self.entity_text_embeddings = entity_text_embeddings
        self.reports = reports
        self.text_units = text_units
        self.relationships = relationships
        self.covariates = covariates
        self.embedding_vectorstore_key = embedding_vectorstore_key

        self.local_mixed_context = (
            local_mixed_context or self.init_local_context_builder()
        )

    def init_local_context_builder(self) -> LocalSearchMixedContext:
        return LocalSearchMixedContext(
            community_reports=self.reports,
            text_units=self.text_units,
            entities=self.entities,
            relationships=self.relationships,
            covariates=self.covariates,
            entity_text_embeddings=self.entity_text_embeddings,
            embedding_vectorstore_key=self.embedding_vectorstore_key,
            text_embedder=self.text_embedder,
            token_encoder=self.token_encoder,
        )

    @staticmethod
    def convert_reports_to_df(
        reports: List[CommunityReport],
    ) -> pd.DataFrame:
        """
        Converts a list of CommunityReport objects to a DataFrame.
        """
        df = pd.DataFrame([report.__dict__ for report in reports])
        if df["full_content"].isnull().any():
            raise ValueError("Some reports are missing full content.")

        if (
            "full_content_embedding" not in df.columns
            or df["full_content_embedding"].isnull().any()
        ):
            raise ValueError("Some reports are missing full content embeddings.")

        return df

    @staticmethod
    def check_query_doc_encodings(query_embedding: Any, embedding: Any) -> bool:
        """
        Checks that the query and document embeddings have the same type and
        length and that their element types match.
        """
        return (
            isinstance(query_embedding, type(embedding))
            and len(query_embedding) == len(embedding)
            and isinstance(query_embedding[0], type(embedding[0]))
        )

    def build_primer_context(self, query: str, **kwargs) -> pd.DataFrame:
        if not self.reports:
            raise ValueError(
                "No community reports available. Please provide a list of reports."
            )

        query_processor = PrimerQueryProcessor(
            chat_llm=self.chat_llm,
            text_embedder=self.text_embedder,
            token_encoder=self.token_encoder,
            reports=self.reports,
        )

        query_embedding = query_processor(query)

        report_df = self.convert_reports_to_df(self.reports)

        if not self.check_query_doc_encodings(
            query_embedding, report_df["full_content_embedding"].iloc[0]
        ):
            raise ValueError(
                "Query and document embeddings are not compatible. "
                "Please ensure that the embeddings are of the same type and length."
                f" Query: {query_embedding}, Document: {report_df['full_content_embedding'].iloc[0]}"
            )

        # Cosine similarity between the query embedding and each report embedding.
        report_df["similarity"] = report_df["full_content_embedding"].apply(
            lambda x: np.dot(x, query_embedding)
            / (np.linalg.norm(x) * np.linalg.norm(query_embedding))
        )
        top_k = report_df.sort_values("similarity", ascending=False).head(
            self.config.search_primer_k
        )

        return top_k[["short_id", "community_id", "full_content"]]
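A standalone sketch of the primer ranking step in build_primer_context above: cosine similarity between the query embedding and each report's full_content_embedding, followed by a top-k cut. The embeddings and the k value are made-up stand-ins.

import numpy as np
import pandas as pd

report_df = pd.DataFrame({
    "short_id": ["0", "1", "2"],
    "full_content": ["report a", "report b", "report c"],
    "full_content_embedding": [
        np.array([1.0, 0.0]), np.array([0.6, 0.8]), np.array([0.0, 1.0]),
    ],
})
query_embedding = np.array([1.0, 0.0])

# Cosine similarity: dot product over the product of the norms.
report_df["similarity"] = report_df["full_content_embedding"].apply(
    lambda x: np.dot(x, query_embedding)
    / (np.linalg.norm(x) * np.linalg.norm(query_embedding))
)
top_k = report_df.sort_values("similarity", ascending=False).head(2)  # search_primer_k = 2
print(top_k[["short_id", "similarity"]])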
165
graphrag/query/structured_search/drift_search/primer.py
Normal file
@ -0,0 +1,165 @@
import json
import logging
import random
import time
from typing import Any, Dict, List, Tuple

import pandas as pd
import tiktoken
from tqdm.asyncio import tqdm_asyncio

from graphrag.config.models.drift_config import DRIFTSearchConfig
from graphrag.model import CommunityReport
from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.structured_search.base import SearchResult
from graphrag.query.structured_search.drift_search.system_prompt import (
    DRIFT_PRIMER_PROMPT,
)

log = logging.getLogger(__name__)


class PrimerQueryProcessor:
    def __init__(
        self,
        chat_llm: ChatOpenAI,
        text_embedder: BaseTextEmbedding,
        reports: List[CommunityReport],
        token_encoder: tiktoken.Encoding | None = None,
    ):
        self.chat_llm = chat_llm
        self.text_embedder = text_embedder
        self.token_encoder = token_encoder
        self.reports = reports

    def expand_query(self, query: str) -> Tuple[str, int]:
        token_ct = 0

        # Use a random community report as a template for the hypothetical answer.
        template = random.choice(self.reports).full_content

        prompt = (
            f"Create a hypothetical answer to the following query: {query}\n\n"
            "Format it to follow the structure of the template below:\n\n"
            f"{template}\n"
            "Ensure that the hypothetical answer does not reference new named entities that are not present in the original query."
        )
        messages = [{"role": "user", "content": prompt}]

        text = self.chat_llm.generate(messages)
        if self.token_encoder:
            token_ct = len(self.token_encoder.encode(text)) + len(
                self.token_encoder.encode(query)
            )
        if text == "":
            log.warning("Failed to generate expansion for query: %s", query)
            return query, token_ct
        return text, token_ct

    def __call__(self, query: str) -> List[float]:
        hyde_query, token_ct = self.expand_query(query)  # TODO: implement token counting
        log.info("Expanded query: %s", hyde_query)
        return self.text_embedder.embed(hyde_query)


class DRIFTPrimer:
    """
    Performs initial query decomposition using global guidance from information in community reports.
    """

    def __init__(
        self,
        config: DRIFTSearchConfig,
        chat_llm: ChatOpenAI,
        token_encoder: tiktoken.Encoding | None = None,  # TODO: implement token counting
    ):
        self.llm = chat_llm
        self.config = config
        self.token_encoder = token_encoder

    async def decompose_query(
        self, query: str, reports: pd.DataFrame
    ) -> Tuple[Dict, int]:
        """
        Decomposes the query into subqueries based on the fetched global structures.
        Returns a tuple of the parsed response and the number of tokens used.
        """
        community_reports = "\n\n".join(reports['full_content'].tolist())
        prompt = DRIFT_PRIMER_PROMPT.format(
            query=query, community_reports=community_reports
        )
        messages = [{"role": "user", "content": prompt}]

        prompt_tokens = self._count_message_tokens(messages)

        response = await self.llm.agenerate(
            messages, response_format={"type": "json_object"}
        )

        parsed_response = json.loads(response)

        return parsed_response, prompt_tokens

    async def asearch(
        self,
        query: str,
        top_k_reports: pd.DataFrame,
    ) -> SearchResult:
        """
        Asynchronous search method that processes the query and returns a SearchResult.
        """
        start_time = time.time()

        report_folds = self.split_reports(top_k_reports)

        tasks = []
        prompt_tokens = 0
        llm_calls = len(report_folds)

        for fold in report_folds:
            tasks.append(self.decompose_query(query, fold))

        results_with_tokens = await tqdm_asyncio.gather(*tasks)

        for result in results_with_tokens:
            # result is a tuple: (parsed_response, tokens_used)
            _, tokens_used = result
            prompt_tokens += tokens_used

        completion_time = time.time() - start_time

        return SearchResult(
            response=[response for response, _ in results_with_tokens],
            context_data={'top_k_reports': top_k_reports},
            context_text=str(top_k_reports),
            completion_time=completion_time,
            llm_calls=llm_calls,
            prompt_tokens=prompt_tokens,
        )

    def split_reports(self, reports: pd.DataFrame) -> List[pd.DataFrame]:
        """
        Splits the reports into folds to allow for parallel processing.
        """
        folds = []
        num_reports = len(reports)
        primer_folds = self.config.primer_folds or 1  # ensure at least one fold

        for i in range(primer_folds):
            start_idx = i * num_reports // primer_folds
            if i == primer_folds - 1:
                end_idx = num_reports
            else:
                end_idx = (i + 1) * num_reports // primer_folds
            folds.append(reports.iloc[start_idx:end_idx])
        return folds

    def _count_text_tokens(self, text: str) -> int:
        """
        Counts the number of tokens in a given text using the token encoder.
        """
        if self.token_encoder is None:
            raise ValueError("Token encoder is not initialized.")
        return len(self.token_encoder.encode(text))

    def _count_message_tokens(self, messages: List[Dict[str, str]]) -> int:
        """
        Counts the number of tokens in a list of messages.
        """
        total_tokens = 0
        for message in messages:
            for _, value in message.items():
                total_tokens += self._count_text_tokens(value)
        return total_tokens
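A standalone sketch of the fold arithmetic in split_reports above: integer division spreads the rows across primer_folds slices, with the last fold absorbing the remainder. The report contents and fold count are illustrative.

import pandas as pd

reports = pd.DataFrame({"full_content": [f"report {i}" for i in range(10)]})
primer_folds = 3
num_reports = len(reports)

folds = []
for i in range(primer_folds):
    start_idx = i * num_reports // primer_folds
    # The last fold runs to the end, picking up any remainder rows.
    end_idx = num_reports if i == primer_folds - 1 else (i + 1) * num_reports // primer_folds
    folds.append(reports.iloc[start_idx:end_idx])

print([len(f) for f in folds])  # [3, 3, 4] -- the last fold takes the remainder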
149
graphrag/query/structured_search/drift_search/search.py
Normal file
@ -0,0 +1,149 @@
import logging
from collections.abc import AsyncGenerator
from typing import Any, List

import tiktoken
from tqdm.asyncio import tqdm_asyncio

from graphrag.config.models.drift_config import DRIFTSearchConfig
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.structured_search.base import BaseSearch, SearchResult
from graphrag.query.structured_search.drift_search.action import DriftAction
from graphrag.query.structured_search.drift_search.drift_context import (
    DRIFTSearchContextBuilder,
)
from graphrag.query.structured_search.drift_search.primer import DRIFTPrimer
from graphrag.query.structured_search.drift_search.state import QueryState
from graphrag.query.structured_search.local_search.search import LocalSearch

log = logging.getLogger(__name__)


class DRIFTSearch(BaseSearch):
    def __init__(
        self,
        llm: ChatOpenAI,
        context_builder: DRIFTSearchContextBuilder,
        config: DRIFTSearchConfig | None = None,
        token_encoder: tiktoken.Encoding | None = None,  # TODO: implement token counting
        query_state: QueryState | None = None,
    ):
        super().__init__(llm, context_builder, token_encoder)

        self.config = config or DRIFTSearchConfig()
        self.context_builder = context_builder
        self.token_encoder = token_encoder
        self.query_state = query_state or QueryState()
        self.primer = DRIFTPrimer(
            config=self.config, chat_llm=llm, token_encoder=token_encoder
        )
        self.local_search = self.init_local_search()

    def init_local_search(self) -> LocalSearch:
        local_context_params = {
            "text_unit_prop": self.config.local_search_text_unit_prop,
            "community_prop": self.config.local_search_community_prop,
            "top_k_mapped_entities": self.config.local_search_top_k_mapped_entities,
            "top_k_relationships": self.config.local_search_top_k_relationships,
            "include_entity_rank": True,
            "include_relationship_weight": True,
            "include_community_rank": False,
            "return_candidate_context": False,
            "embedding_vectorstore_key": EntityVectorStoreKey.ID,  # set this to EntityVectorStoreKey.TITLE if the vector store uses entity titles as ids
            "max_tokens": self.config.local_search_max_data_tokens,
        }

        llm_params = {
            "max_tokens": self.config.local_search_llm_max_gen_tokens,
            "temperature": self.config.local_search_temperature,
            "response_format": {"type": "json_object"},
        }

        return LocalSearch(
            llm=self.llm,
            system_prompt=self.context_builder.local_system_prompt,
            context_builder=self.context_builder.local_mixed_context,
            token_encoder=self.token_encoder,
            llm_params=llm_params,
            context_builder_params=local_context_params,
            response_type="multiple paragraphs",  # this has no bearing on the object returned by OAI, only on the format of the response within that object
        )

    def _process_primer_results(
        self, query: str, search_results: SearchResult
    ) -> DriftAction:
        response = search_results.response
        if isinstance(response, list) and isinstance(response[0], dict):
            intermediate_answer = "\n\n".join(
                i['intermediate_answer'] for i in response
            )
            follow_ups = [fu for i in response for fu in i['follow_up_queries']]
            score = sum(i['score'] for i in response) / len(response)
            response = {
                'intermediate_answer': intermediate_answer,
                'follow_up_queries': follow_ups,
                'score': score,
            }
            return DriftAction.from_primer_response(query, response)
        raise ValueError(
            f'Response must be a list of dictionaries. Found: {type(response)}'
        )

    # Hardcoded to use local search, but will be updated to use a general meta search engine.
    async def asearch_step(
        self,
        global_query: str,
        search_engine: LocalSearch,
        actions: List[DriftAction],
    ) -> List[DriftAction]:
        tasks = [
            action.asearch(search_engine=search_engine, global_query=global_query)
            for action in actions
        ]
        return await tqdm_asyncio.gather(*tasks)

    async def asearch(
        self,
        query: str,
        conversation_history: Any = None,
        **kwargs,
    ) -> SearchResult:
        ## Check if the query state is empty. Are there follow-up actions?
        if not self.query_state.graph:
            # Prime the search with the primer
            primer_context = self.context_builder.build_primer_context(query)  # pd.DataFrame
            primer_response = await self.primer.asearch(
                query=query, top_k_reports=primer_context
            )
            ## Package the response into a DriftAction
            init_action = self._process_primer_results(query, primer_response)
            self.query_state.add_action(init_action)
            self.query_state.add_all_follow_ups(init_action, init_action.follow_ups)

        ## Main loop ##
        steps = 0

        while steps < self.config.n:
            ## Rank actions
            actions = self.query_state.rank_incomplete_actions()
            if not actions:
                log.info("No more actions to take. Exiting DRIFT loop.")
                break
            ## Take the top k actions
            actions = actions[: self.config.search_primer_k]
            ## Process actions
            results = await self.asearch_step(
                global_query=query, search_engine=self.local_search, actions=actions
            )  # currently hardcoded to local search, but will be updated to use a more general search engine

            ## Update the query state
            for action in results:
                self.query_state.add_action(action)
                self.query_state.add_all_follow_ups(action, action.follow_ups)
            steps += 1

        return SearchResult(
            response=self.query_state.serialize(),
            context_data='test',
            context_text='testing_only',
            completion_time=0,
            llm_calls=0,
            prompt_tokens=0,
        )

    def search(
        self,
        query: str,
        conversation_history: Any = None,
        **kwargs,
    ) -> SearchResult:
        raise NotImplementedError("Synchronous DRIFT search is not implemented.")

    def astream_search(
        self, query: str, conversation_history: ConversationHistory | None = None
    ) -> AsyncGenerator[str, None]:
        raise NotImplementedError("Streaming DRIFT search is not implemented.")
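An illustrative wiring sketch for DRIFTSearch, hedged: it is not runnable as-is, since chat_llm, text_embedder, entity_store, entities, and reports are placeholders that must come from a configured graphrag indexing pipeline.

import asyncio

from graphrag.query.structured_search.drift_search.drift_context import (
    DRIFTSearchContextBuilder,
)
from graphrag.query.structured_search.drift_search.search import DRIFTSearch

context_builder = DRIFTSearchContextBuilder(
    chat_llm=chat_llm,                    # a configured ChatOpenAI
    text_embedder=text_embedder,          # a BaseTextEmbedding
    entities=entities,                    # indexed Entity objects
    entity_text_embeddings=entity_store,  # a BaseVectorStore
    reports=reports,                      # CommunityReport objects; required by the primer
)
search = DRIFTSearch(llm=chat_llm, context_builder=context_builder)

result = asyncio.run(search.asearch("What do the communities say about topic Z?"))
print(result.response)  # serialized QueryState: {"nodes": [...], "edges": [...]}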
113
graphrag/query/structured_search/drift_search/state.py
Normal file
@ -0,0 +1,113 @@
import random
from typing import Any, Dict, List, Optional

import networkx as nx

from graphrag.query.structured_search.drift_search.action import DriftAction


class QueryState:
    """
    Manages the state of the query, including a graph of actions.
    """

    def __init__(self):
        self.graph = nx.MultiDiGraph()

    def add_action(self, action: DriftAction, metadata: Optional[Dict[str, Any]] = None):
        """
        Adds an action to the graph with optional metadata.
        """
        self.graph.add_node(action, **(metadata or {}))

    def relate_actions(self, parent: DriftAction, child: DriftAction, weight: float = 1.0):
        """
        Relates two actions in the graph.
        """
        self.graph.add_edge(parent, child, weight=weight)

    def add_all_follow_ups(
        self,
        action: DriftAction,
        follow_ups: List[DriftAction] | List[str],
        weight: float = 1.0,
    ):
        """
        Adds all follow-up actions and links them to the given action.
        """
        if len(follow_ups) == 0:
            raise ValueError(
                "No follow-up actions to add. Please provide a list of follow-up actions."
            )

        for follow_up in follow_ups:
            # Promote plain strings to DriftAction nodes.
            if isinstance(follow_up, str):
                follow_up = DriftAction(query=follow_up)

            self.add_action(follow_up)
            self.relate_actions(action, follow_up, weight)

    def find_incomplete_actions(self) -> List[DriftAction]:
        """
        Finds all unanswered actions in the graph.
        """
        return [node for node in self.graph.nodes if not node.is_complete]

    def rank_incomplete_actions(self, scorer: Any | None = None) -> List[DriftAction]:
        """
        Ranks all unanswered actions in the graph if a scorer is available;
        otherwise returns them in random order.
        """
        unanswered = self.find_incomplete_actions()
        if scorer:
            for node in unanswered:
                node.compute_score(scorer)
            return sorted(
                unanswered,
                key=lambda node: node.score if node.score is not None else float('-inf'),
                reverse=True,
            )
        # Shuffle the list if no scorer is provided.
        random.shuffle(unanswered)
        return unanswered

    def serialize(self) -> Dict[str, Any]:
        """
        Serializes the graph to a dictionary, including nodes and edges.
        """
        # Create a mapping from nodes to unique IDs
        node_to_id = {node: idx for idx, node in enumerate(self.graph.nodes())}

        # Serialize nodes
        nodes = []
        for node in self.graph.nodes():
            node_data = node.serialize(include_follow_ups=False)
            node_data['id'] = node_to_id[node]
            node_attributes = self.graph.nodes[node]
            if node_attributes:
                node_data.update(node_attributes)
            nodes.append(node_data)

        # Serialize edges
        edges = []
        for u, v, edge_data in self.graph.edges(data=True):
            edges.append({
                "source": node_to_id[u],
                "target": node_to_id[v],
                "weight": edge_data.get("weight", 1.0),
            })

        return {"nodes": nodes, "edges": edges}

    def deserialize(self, data: Dict[str, Any]):
        """
        Deserializes the dictionary back to a graph.
        """
        self.graph.clear()
        id_to_action = {}

        for node_data in data.get("nodes", []):
            node_id = node_data.pop('id')
            action = DriftAction.deserialize(node_data)
            self.add_action(action)
            id_to_action[node_id] = action

        for edge_data in data.get("edges", []):
            source_action = id_to_action.get(edge_data["source"])
            target_action = id_to_action.get(edge_data["target"])
            weight = edge_data.get("weight", 1.0)
            if source_action and target_action:
                self.relate_actions(source_action, target_action, weight)
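A small sketch of the QueryState lifecycle above: prime with one completed action, attach string follow-ups (promoted to DriftAction nodes), rank, and round-trip through serialize/deserialize. The queries are made up; the imports assume the graphrag package is installed.

from graphrag.query.structured_search.drift_search.action import DriftAction
from graphrag.query.structured_search.drift_search.state import QueryState

state = QueryState()
root = DriftAction(query="What is Company Y known for?")
root.answer = "Consumer hardware."  # setting an answer marks the root complete
state.add_action(root)
state.add_all_follow_ups(root, ["Who leads Company Y?", "Where is it based?"])

pending = state.rank_incomplete_actions()  # shuffled here, since no scorer is passed
print([a.query for a in pending])

snapshot = state.serialize()               # {"nodes": [...], "edges": [...]}
restored = QueryState()
restored.deserialize(snapshot)
assert len(restored.find_incomplete_actions()) == 2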
166
graphrag/query/structured_search/drift_search/system_prompt.py
Normal file
@ -0,0 +1,166 @@
DRIFT_LOCAL_SYSTEM_PROMPT = """
---Role---

You are a helpful assistant responding to questions about data in the tables provided.


---Goal---

Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.

If you don't know the answer, just say so. Do not make anything up.

Points supported by data should list their data references as follows:

"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]."

Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.

For example:

"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16)]."

where 15 and 16 represent the ids (not the indexes) of the relevant data records.

Pay close attention specifically to the Sources tables, as they contain the most relevant information for the user query. You will be rewarded for preserving the context of the sources in your response.

---Target response length and format---

{response_type}


---Data tables---

{context_data}


---Goal---

Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.

If you don't know the answer, just say so. Do not make anything up.

Points supported by data should list their data references as follows:

"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]."

Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.

For example:

"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16)]."

where 15 and 16 represent the ids (not the indexes) of the relevant data records.

Pay close attention specifically to the Sources tables, as they contain the most relevant information for the user query. You will be rewarded for preserving the context of the sources in your response.

---Target response length and format---

{response_type}

Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.

Additionally, provide a score for how well the response addresses the overall research question: {global_query}. Based on your response, suggest a few follow-up questions that could be asked to further explore the topic. Do not include scores or follow-up questions in the 'response' field of the JSON; add them to the respective 'score' and 'follow_up_queries' keys of the JSON output. Format your response in JSON with the following keys and values:

'response': str. Put your answer, formatted in markdown, here. Do not answer the global query in this section.
'score': int,
'follow_up_queries': List[str]

1. score: How well the intermediate answer addresses the query. A score of 0 indicates a poor, unfocused answer, while a score of 100 indicates a highly focused, relevant answer that addresses the query in its entirety.

2. follow_up_queries: Use the data provided to generate follow-up queries to help refine your search. Do not ask compound questions, for example: "What is the market cap of Apple and Microsoft?". Use your knowledge of the entity distribution to focus on entity types that will be useful for searching a broad area of the knowledge graph.
"""


DRIFT_REDUCE_PROMPT = """
---Role---

You are a helpful assistant responding to questions about data in the reports provided.


---Goal---

Generate a response of the target length and format that responds to the user's question, summarizing all information in the input reports appropriate for the response length and format, and incorporating any relevant general knowledge while being as specific, accurate, and concise as possible.

If you don't know the answer, just say so. Do not make anything up.

Points supported by data should list their data references as follows:

"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]."

Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.

For example:

"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (1, 5, 15)]."

Do not include information where the supporting evidence for it is not provided.

If you decide to use general knowledge, you should add a delimiter stating that the information is not supported by the data tables. For example:

"Person X is the owner of Company Y and subject to many allegations of wrongdoing. [Data: General Knowledge (href)]"

---Data Reports---

{context_data}

---Target response length and format---

Multiple paragraphs


---Goal---

Generate a response of the target length and format that responds to the user's question, summarizing all information in the input reports appropriate for the response length and format, and incorporating any relevant general knowledge while being as specific, accurate, and concise as possible.

If you don't know the answer, just say so. Do not make anything up.

Points supported by data should list their data references as follows:

"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]."

Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.

For example:

"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (1, 5, 15)]."

Do not include information where the supporting evidence for it is not provided.

If you decide to use general knowledge, you should add a delimiter stating that the information is not supported by the data tables. For example:

"Person X is the owner of Company Y and subject to many allegations of wrongdoing. [Data: General Knowledge (href)]"

Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. Now answer the following query using the data above:

{query}
"""


DRIFT_PRIMER_PROMPT = """You are a helpful agent designed to reason over a knowledge graph in response to a user query.
This is a unique knowledge graph where edges are freeform text rather than verb operators. You will begin your reasoning by looking at a summary of the content of the most relevant community and will provide:

1. score: How well the intermediate answer addresses the query. A score of 0 indicates a poor, unfocused answer, while a score of 100 indicates a highly focused, relevant answer that addresses the query in its entirety.

2. intermediate_answer: This answer should match the level of detail and length found in the community summaries. The intermediate answer should be exactly 2000 characters long. It must be formatted in markdown and must begin with a header that explains how the following text is related to the query.

Use this information to help you decide whether or not you need more information about the entities mentioned in the report. You may also use your general knowledge to think of entities which may help enrich your answer.

You will also provide a full answer from the content you have available. Use the data provided to generate follow-up queries to help refine your search. Do not ask compound questions, for example: "What is the market cap of Apple and Microsoft?". Use your knowledge of the entity distribution to focus on entity types that will be useful for searching a broad area of the knowledge graph.

For the query:

{query}

The top-ranked community summaries:

{community_reports}

Provide the intermediate answer and all scores in JSON with the following format:

{{'intermediate_answer': str,
'score': int,
'follow_up_queries': List[str]}}

Begin:
"""
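For reference, a sketch of how DRIFT_PRIMER_PROMPT is assembled before the LLM call, matching decompose_query in primer.py above; the query and report text are placeholders. Note that the literal braces in the JSON schema are escaped as {{...}} so that str.format only substitutes the two named fields.

from graphrag.query.structured_search.drift_search.system_prompt import (
    DRIFT_PRIMER_PROMPT,
)

community_reports = "\n\n".join([
    "# Community 0\nCompany Y dominates the hardware cluster...",
    "# Community 1\nPerson X appears across supplier relationships...",
])
prompt = DRIFT_PRIMER_PROMPT.format(
    query="How is Person X connected to Company Y?",
    community_reports=community_reports,
)
messages = [{"role": "user", "content": prompt}]  # sent via llm.agenerate(..., response_format={"type": "json_object"})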