diff --git a/graphrag/query/factories.py b/graphrag/query/factories.py index ff5b75a0..cae54eff 100644 --- a/graphrag/query/factories.py +++ b/graphrag/query/factories.py @@ -17,7 +17,6 @@ from graphrag.model import ( Relationship, TextUnit, ) - from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey from graphrag.query.llm.oai.chat_openai import ChatOpenAI from graphrag.query.llm.oai.embedding import OpenAIEmbedding diff --git a/graphrag/query/structured_search/drift_search/action.py b/graphrag/query/structured_search/drift_search/action.py index 7529d1a7..19c8c3c6 100644 --- a/graphrag/query/structured_search/drift_search/action.py +++ b/graphrag/query/structured_search/drift_search/action.py @@ -86,7 +86,9 @@ class DriftAction: generated_tokens = 0 else: generated_tokens = num_tokens(self.answer, search_engine.token_encoder) - self.metadata.update({"token_ct": search_result.prompt_tokens + generated_tokens}) + self.metadata.update({ + "token_ct": search_result.prompt_tokens + generated_tokens + }) self.follow_ups = response.pop("follow_up_queries", []) if not self.follow_ups: @@ -105,7 +107,9 @@ class DriftAction: scorer (Any): The scorer to compute the score. """ score = scorer.compute_score(self.query, self.answer) - self.score = score if score is not None else float("-inf") # Default to -inf for sorting + self.score = ( + score if score is not None else float("-inf") + ) # Default to -inf for sorting def serialize(self, include_follow_ups: bool = True) -> dict[str, Any]: """ diff --git a/graphrag/query/structured_search/drift_search/drift_context.py b/graphrag/query/structured_search/drift_search/drift_context.py index 434d8426..07f0f2b0 100644 --- a/graphrag/query/structured_search/drift_search/drift_context.py +++ b/graphrag/query/structured_search/drift_search/drift_context.py @@ -76,7 +76,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder): def init_local_context_builder(self) -> LocalSearchMixedContext: """ Initialize the local search mixed context builder. - + Returns ------- LocalSearchMixedContext: Initialized local context. @@ -97,7 +97,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder): def convert_reports_to_df(reports: list[CommunityReport]) -> pd.DataFrame: """ Convert a list of CommunityReport objects to a pandas DataFrame. - + Args ---- reports : list[CommunityReport] @@ -106,7 +106,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder): Returns ------- pd.DataFrame: DataFrame with report data. - + Raises ------ ValueError: If some reports are missing full content or full content embeddings. @@ -114,7 +114,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder): report_df = pd.DataFrame([report.__dict__ for report in reports]) missing_content_error = "Some reports are missing full content." missing_embedding_error = "Some reports are missing full content embeddings." - + if report_df["full_content"].isna().sum() > 0: raise ValueError(missing_content_error) if ( @@ -128,14 +128,14 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder): def check_query_doc_encodings(query_embedding: Any, embedding: Any) -> bool: """ Check if the embeddings are compatible. - + Args ---- query_embedding : Any Embedding of the query. embedding : Any Embedding to compare against. - + Returns ------- bool: True if embeddings match, otherwise False. @@ -149,7 +149,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder): def build_context(self, query: str, **kwargs) -> pd.DataFrame: """ Build DRIFT search context. - + Args ---- query : str @@ -158,14 +158,16 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder): Returns ------- pd.DataFrame: Top-k most similar documents. - + Raises ------ ValueError: If no community reports are available, or embeddings are incompatible. """ if self.reports is None: - missing_reports_error = "No community reports available. Please provide a list of reports." + missing_reports_error = ( + "No community reports available. Please provide a list of reports." + ) raise ValueError(missing_reports_error) query_processor = PrimerQueryProcessor( diff --git a/graphrag/query/structured_search/drift_search/primer.py b/graphrag/query/structured_search/drift_search/primer.py index 978daf30..e61a3613 100644 --- a/graphrag/query/structured_search/drift_search/primer.py +++ b/graphrag/query/structured_search/drift_search/primer.py @@ -61,13 +61,12 @@ class PrimerQueryProcessor: tuple[str, int]: Expanded query text and the number of tokens used. """ token_ct = 0 - template = secrets.choice(self.reports).full_content # nosec S311 + template = secrets.choice(self.reports).full_content # nosec S311 prompt = f"""Create a hypothetical answer to the following query: {query}\n\n Format it to follow the structure of the template below:\n\n {template}\n" Ensure that the hypothetical answer does not reference new named entities that are not present in the original query.""" - messages = [{"role": "user", "content": prompt}] @@ -193,7 +192,11 @@ class DRIFTPrimer: for i in range(primer_folds): start_idx = i * num_reports // primer_folds - end_idx = num_reports if i == primer_folds - 1 else (i + 1) * num_reports // primer_folds + end_idx = ( + num_reports + if i == primer_folds - 1 + else (i + 1) * num_reports // primer_folds + ) fold = reports.iloc[start_idx:end_idx] folds.append(fold) return folds diff --git a/graphrag/query/structured_search/drift_search/search.py b/graphrag/query/structured_search/drift_search/search.py index 47c60488..a2143097 100644 --- a/graphrag/query/structured_search/drift_search/search.py +++ b/graphrag/query/structured_search/drift_search/search.py @@ -124,16 +124,15 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]): error_msg = "No intermediate answers found in primer response. Ensure that the primer response includes intermediate answers." raise RuntimeError(error_msg) - intermediate_answer = "\n\n".join([ - i["intermediate_answer"] for i in response if "intermediate_answer" in i - ]) + intermediate_answer = "\n\n".join( + [ + i["intermediate_answer"] + for i in response + if "intermediate_answer" in i + ] + ) - follow_ups = [ - fu - for i in response - if "follow_up_queries" in i and i["follow_up_queries"] - for fu in i["follow_up_queries"] - ] + follow_ups = [fu for i in response for fu in i.get("follow_up_queries", [])] if len(follow_ups) == 0: error_msg = "No follow-up queries found in primer response. Ensure that the primer response includes follow-up queries." raise RuntimeError(error_msg) diff --git a/graphrag/query/structured_search/drift_search/state.py b/graphrag/query/structured_search/drift_search/state.py index 3ca284f4..b8a9edc6 100644 --- a/graphrag/query/structured_search/drift_search/state.py +++ b/graphrag/query/structured_search/drift_search/state.py @@ -20,9 +20,7 @@ class QueryState: def __init__(self): self.graph = nx.MultiDiGraph() - def add_action( - self, action: DriftAction, metadata: dict[str, Any] | None = None - ): + def add_action(self, action: DriftAction, metadata: dict[str, Any] | None = None): """Add an action to the graph with optional metadata.""" self.graph.add_node(action, **(metadata or {})) @@ -64,13 +62,13 @@ class QueryState: for node in unanswered: node.compute_score(scorer) return sorted( - unanswered, - key=lambda node: node.score - if node.score is not None - else float("-inf"), - reverse=True, - ) - + unanswered, + key=lambda node: node.score + if node.score is not None + else float("-inf"), + reverse=True, + ) + # shuffle the list if no scorer random.shuffle(unanswered) return list(unanswered)