From 44246e2fd9946fe76767d391028049c36d86111f Mon Sep 17 00:00:00 2001
From: Alonso Guevara
Date: Mon, 14 Oct 2024 19:33:55 -0600
Subject: [PATCH] Some algorithmic refactors

---
 graphrag/model/community_report.py                | 27 ++++++++++
 .../structured_search/drift_search/action.py      | 40 +++++++++++----
 .../drift_search/drift_context.py                 | 40 +++++++++------
 .../structured_search/drift_search/primer.py      | 25 +++-------
 .../structured_search/drift_search/search.py      | 14 ++----
 .../structured_search/drift_search/state.py       | 49 +++++++++++--------
 6 files changed, 124 insertions(+), 71 deletions(-)

diff --git a/graphrag/model/community_report.py b/graphrag/model/community_report.py
index 2666c0b5..a16e448c 100644
--- a/graphrag/model/community_report.py
+++ b/graphrag/model/community_report.py
@@ -62,3 +62,30 @@ class CommunityReport(Named):
             full_content_embedding=d.get(full_content_embedding_key),
             attributes=d.get(attributes_key),
         )
+
+    def to_dict(
+        self,
+        id_key: str = "id",
+        title_key: str = "title",
+        community_id_key: str = "community_id",
+        short_id_key: str = "short_id",
+        summary_key: str = "summary",
+        full_content_key: str = "full_content",
+        rank_key: str = "rank",
+        summary_embedding_key: str = "summary_embedding",
+        full_content_embedding_key: str = "full_content_embedding",
+        attributes_key: str = "attributes",
+    ) -> dict[str, Any]:
+        """Convert the community report to a dictionary."""
+        return {
+            id_key: self.id,
+            title_key: self.title,
+            community_id_key: self.community_id,
+            short_id_key: self.short_id,
+            summary_key: self.summary,
+            full_content_key: self.full_content,
+            rank_key: self.rank,
+            summary_embedding_key: self.summary_embedding,
+            full_content_embedding_key: self.full_content_embedding,
+            attributes_key: self.attributes,
+        }
diff --git a/graphrag/query/structured_search/drift_search/action.py b/graphrag/query/structured_search/drift_search/action.py
index 19c8c3c6..4514d4de 100644
--- a/graphrag/query/structured_search/drift_search/action.py
+++ b/graphrag/query/structured_search/drift_search/action.py
@@ -146,15 +146,21 @@ class DriftAction:
         DriftAction
             A deserialized instance of DriftAction.
         """
-        action = cls(data["query"])
+        # Ensure 'query' exists in the data; raise a ValueError if it is missing
+        query = data.get("query")
+        if query is None:
+            error_message = "Missing 'query' key in serialized data"
+            raise ValueError(error_message)
+
+        # Initialize the DriftAction
+        action = cls(query)
         action.answer = data.get("answer")
         action.score = data.get("score")
         action.metadata = data.get("metadata", {})
-        action.follow_ups = (
-            [cls.deserialize(fu_data) for fu_data in data.get("follow_up_queries", [])]
-            if "follow_ups" in data
-            else []
-        )
+
+        action.follow_ups = [
+            cls.deserialize(fu_data) for fu_data in data.get("follow_up_queries", [])
+        ]
         return action

     @classmethod
@@ -187,10 +193,22 @@ class DriftAction:
             action.score = response.get("score")
             return action

-        error_message = "Response must be a dictionary"
+        # If the response is a string, attempt to parse it as JSON
+        if isinstance(response, str):
+            try:
+                parsed_response = json.loads(response)
+                if isinstance(parsed_response, dict):
+                    return cls.from_primer_response(query, parsed_response)
+                error_message = "Parsed response must be a dictionary."
+                raise ValueError(error_message)
+            except json.JSONDecodeError as e:
+                error_message = f"Failed to parse response string: {e}. Parsed response must be a dictionary."
+                raise ValueError(error_message) from e
+
+        error_message = f"Unsupported response type: {type(response).__name__}. Expected a dictionary or JSON string."
         raise ValueError(error_message)

-    def __hash__(self):
+    def __hash__(self) -> int:
         """
         Allow DriftAction objects to be hashable for use in networkx.MultiDiGraph.

@@ -203,7 +221,7 @@ class DriftAction:
         """
         return hash(self.query)

-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
         """
         Check equality based on the query string.

@@ -215,4 +233,6 @@ class DriftAction:
         bool
             True if the other object is a DriftAction with the same query, False otherwise.
         """
-        return isinstance(other, DriftAction) and self.query == other.query
+        if not isinstance(other, DriftAction):
+            return False
+        return self.query == other.query
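The stricter `deserialize` contract above fails fast on payloads that lack a `query` key. A minimal round-trip sketch, with hypothetical payload values and assuming `DriftAction` is importable from this module:

```python
# Sketch only: exercises the 'query'-required contract of DriftAction.deserialize.
from graphrag.query.structured_search.drift_search.action import DriftAction

payload = {
    "query": "What themes connect community 5?",  # hypothetical query
    "answer": None,
    "score": 0.42,
    "metadata": {},
    "follow_up_queries": [{"query": "Which entities anchor community 5?"}],
}

action = DriftAction.deserialize(payload)
assert action.follow_ups[0].query == "Which entities anchor community 5?"

# A payload without 'query' now raises a descriptive ValueError
# instead of the bare KeyError the old cls(data["query"]) produced.
try:
    DriftAction.deserialize({"answer": "orphaned"})
except ValueError as err:
    print(err)  # Missing 'query' key in serialized data
```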
diff --git a/graphrag/query/structured_search/drift_search/drift_context.py b/graphrag/query/structured_search/drift_search/drift_context.py
index 07f0f2b0..806cbd49 100644
--- a/graphrag/query/structured_search/drift_search/drift_context.py
+++ b/graphrag/query/structured_search/drift_search/drift_context.py
@@ -111,12 +111,16 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
         ------
         ValueError: If some reports are missing full content or full content embeddings.
         """
-        report_df = pd.DataFrame([report.__dict__ for report in reports])
+        report_df = pd.DataFrame([report.to_dict() for report in reports])
         missing_content_error = "Some reports are missing full content."
         missing_embedding_error = "Some reports are missing full content embeddings."

-        if report_df["full_content"].isna().sum() > 0:
+        if (
+            "full_content" not in report_df.columns
+            or report_df["full_content"].isna().sum() > 0
+        ):
             raise ValueError(missing_content_error)
+
         if (
             "full_content_embedding" not in report_df.columns
             or report_df["full_content_embedding"].isna().sum() > 0
@@ -141,7 +145,9 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
             bool: True if embeddings match, otherwise False.
         """
         return (
-            isinstance(query_embedding, type(embedding))
+            query_embedding is not None
+            and embedding is not None
+            and isinstance(query_embedding, type(embedding))
             and len(query_embedding) == len(embedding)
             and isinstance(query_embedding[0], type(embedding[0]))
         )
@@ -182,21 +188,27 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):

         report_df = self.convert_reports_to_df(self.reports)

-        if self.check_query_doc_encodings(
+        # Check compatibility between query embedding and document embeddings
+        if not self.check_query_doc_encodings(
             query_embedding, report_df["full_content_embedding"].iloc[0]
         ):
-            report_df["similarity"] = report_df["full_content_embedding"].apply(
-                lambda x: np.dot(x, query_embedding)
-                / (np.linalg.norm(x) * np.linalg.norm(query_embedding))
-            )
-            top_k = report_df.sort_values("similarity", ascending=False).head(
-                self.config.drift_k_followups
-            )
-        else:
-            incompatible_embeddings_error = (
+            error_message = (
                 "Query and document embeddings are not compatible. "
                 "Please ensure that the embeddings are of the same type and length."
             )
-            raise ValueError(incompatible_embeddings_error)
+            raise ValueError(error_message)
+
+        # Vectorized cosine similarity computation
+        query_norm = np.linalg.norm(query_embedding)
+        document_norms = np.linalg.norm(
+            report_df["full_content_embedding"].to_list(), axis=1
+        )
+        dot_products = np.dot(
+            np.vstack(report_df["full_content_embedding"].to_list()), query_embedding
+        )
+        report_df["similarity"] = dot_products / (document_norms * query_norm)
+
+        # Sort by similarity and select top-k
+        top_k = report_df.nlargest(self.config.drift_k_followups, "similarity")

         return top_k.loc[:, ["short_id", "community_id", "full_content"]]
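Since the per-row `.apply()` lambda is replaced by one matrix-vector product, the similarity step is worth seeing in isolation. A self-contained sketch, with synthetic embeddings standing in for report embeddings and a hypothetical `drift_k_followups` value:

```python
# Standalone sketch of the vectorized cosine-similarity step; all data is synthetic.
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
report_df = pd.DataFrame({
    "short_id": ["r0", "r1", "r2"],
    "full_content_embedding": [rng.normal(size=8) for _ in range(3)],
})
query_embedding = rng.normal(size=8)

# Stack the per-report embeddings into one (n_reports, dim) matrix.
doc_matrix = np.vstack(report_df["full_content_embedding"].to_list())
document_norms = np.linalg.norm(doc_matrix, axis=1)
query_norm = np.linalg.norm(query_embedding)

# One matrix-vector product computes every dot product at once.
report_df["similarity"] = (doc_matrix @ query_embedding) / (document_norms * query_norm)

drift_k_followups = 2  # hypothetical config value
top_k = report_df.nlargest(drift_k_followups, "similarity")
print(top_k[["short_id", "similarity"]])
```

`nlargest` also avoids the full `sort_values(...).head(...)` pass when only the top k rows are needed.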
diff --git a/graphrag/query/structured_search/drift_search/primer.py b/graphrag/query/structured_search/drift_search/primer.py
index e61a3613..53b28bac 100644
--- a/graphrag/query/structured_search/drift_search/primer.py
+++ b/graphrag/query/structured_search/drift_search/primer.py
@@ -8,6 +8,7 @@ import logging
 import secrets
 import time

+import numpy as np
 import pandas as pd
 import tiktoken
 from tqdm.asyncio import tqdm_asyncio
@@ -159,19 +160,19 @@ class DRIFTPrimer:
         -------
         SearchResult: The search result containing the response and context data.
         """
-        start_time = time.time()
+        start_time = time.perf_counter()
         report_folds = self.split_reports(top_k_reports)
         tasks = [self.decompose_query(query, fold) for fold in report_folds]
         results_with_tokens = await tqdm_asyncio.gather(*tasks)
-        completion_time = time.time() - start_time
+        completion_time = time.perf_counter() - start_time

         return SearchResult(
             response=[response for response, _ in results_with_tokens],
             context_data={"top_k_reports": top_k_reports},
-            context_text=str(top_k_reports),
+            context_text=top_k_reports.to_json(),
             completion_time=completion_time,
-            llm_calls=2,
+            llm_calls=len(results_with_tokens),
             prompt_tokens=sum(tokens for _, tokens in results_with_tokens),
         )

@@ -186,17 +187,7 @@ class DRIFTPrimer:
         -------
         list[pd.DataFrame]: List of report folds.
         """
-        folds = []
-        num_reports = len(reports)
         primer_folds = self.config.primer_folds or 1  # Ensure at least one fold
-
-        for i in range(primer_folds):
-            start_idx = i * num_reports // primer_folds
-            end_idx = (
-                num_reports
-                if i == primer_folds - 1
-                else (i + 1) * num_reports // primer_folds
-            )
-            fold = reports.iloc[start_idx:end_idx]
-            folds.append(fold)
-        return folds
+        if primer_folds == 1:
+            return [reports]
+        return [pd.DataFrame(fold) for fold in np.array_split(reports, primer_folds)]
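The rewritten `split_reports` leans on `np.array_split`, which already yields folds whose sizes differ by at most one, matching the removed index arithmetic. A small sketch with a made-up report DataFrame:

```python
# Sketch of the fold-splitting behavior; the report contents are placeholders.
import numpy as np
import pandas as pd

reports = pd.DataFrame({"short_id": [f"r{i}" for i in range(10)]})
primer_folds = 3  # hypothetical config.primer_folds value

# np.array_split tolerates uneven divisions, unlike np.split.
folds = [pd.DataFrame(fold) for fold in np.array_split(reports, primer_folds)]
print([len(fold) for fold in folds])  # [4, 3, 3]: sizes differ by at most one
```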
diff --git a/graphrag/query/structured_search/drift_search/search.py b/graphrag/query/structured_search/drift_search/search.py
index a2143097..93ac5bba 100644
--- a/graphrag/query/structured_search/drift_search/search.py
+++ b/graphrag/query/structured_search/drift_search/search.py
@@ -124,13 +124,9 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
             error_msg = "No intermediate answers found in primer response. Ensure that the primer response includes intermediate answers."
             raise RuntimeError(error_msg)

-        intermediate_answer = "\n\n".join(
-            [
-                i["intermediate_answer"]
-                for i in response
-                if "intermediate_answer" in i
-            ]
-        )
+        intermediate_answer = "\n\n".join([
+            i["intermediate_answer"] for i in response if "intermediate_answer" in i
+        ])

         follow_ups = [fu for i in response for fu in i.get("follow_up_queries", [])]
         if len(follow_ups) == 0:
@@ -193,7 +189,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
             error_msg = "DRIFT Search query cannot be empty."
             raise ValueError(error_msg)

-        start_time = time.time()
+        start_time = time.perf_counter()

         primer_token_ct = 0
         context_token_ct = 0
@@ -233,7 +229,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
             self.query_state.add_all_follow_ups(action, action.follow_ups)
             epochs += 1

-        t_elapsed = time.time() - start_time
+        t_elapsed = time.perf_counter() - start_time

         # Calculate token usage
         total_tokens = (
diff --git a/graphrag/query/structured_search/drift_search/state.py b/graphrag/query/structured_search/drift_search/state.py
index b8a9edc6..21d594c4 100644
--- a/graphrag/query/structured_search/drift_search/state.py
+++ b/graphrag/query/structured_search/drift_search/state.py
@@ -5,7 +5,11 @@

 import logging
 import random
-from typing import Any
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
 import networkx as nx

@@ -55,7 +59,9 @@ class QueryState:
         """Find all unanswered actions in the graph."""
         return [node for node in self.graph.nodes if not node.is_complete]

-    def rank_incomplete_actions(self, scorer: Any | None = None) -> list[DriftAction]:
+    def rank_incomplete_actions(
+        self, scorer: Callable[[DriftAction], float] | None = None
+    ) -> list[DriftAction]:
         """Rank all unanswered actions in the graph if scorer available."""
         unanswered = self.find_incomplete_actions()
         if scorer:
@@ -63,9 +69,9 @@ class QueryState:
                 node.compute_score(scorer)
             return sorted(
                 unanswered,
-                key=lambda node: node.score
-                if node.score is not None
-                else float("-inf"),
+                key=lambda node: (
+                    node.score if node.score is not None else float("-inf")
+                ),
                 reverse=True,
             )

@@ -81,30 +87,31 @@ class QueryState:
         node_to_id = {node: idx for idx, node in enumerate(self.graph.nodes())}

         # Serialize nodes
-        nodes = []
-        for node in self.graph.nodes():
-            node_data = node.serialize(include_follow_ups=False)
-            node_data["id"] = node_to_id[node]
-            node_attributes = self.graph.nodes[node]
-            if node_attributes:
-                node_data.update(node_attributes)
-            nodes.append(node_data)
+        nodes: list[dict[str, Any]] = [
+            {
+                **node.serialize(include_follow_ups=False),
+                "id": node_to_id[node],
+                **self.graph.nodes[node],
+            }
+            for node in self.graph.nodes()
+        ]

         # Serialize edges
-        edges = []
-        for u, v, edge_data in self.graph.edges(data=True):
-            edge_info = {
+        edges: list[dict[str, Any]] = [
+            {
                 "source": node_to_id[u],
                 "target": node_to_id[v],
                 "weight": edge_data.get("weight", 1.0),
             }
-            edges.append(edge_info)
+            for u, v, edge_data in self.graph.edges(data=True)
+        ]

         if include_context:
-            context_data = {}
-            for node in nodes:
-                if node["metadata"].get("context_data") and node.get("query"):
-                    context_data[node["query"]] = node["metadata"]["context_data"]
+            context_data = {
+                node["query"]: node["metadata"]["context_data"]
+                for node in nodes
+                if node["metadata"].get("context_data") and node.get("query")
+            }
             context_text = str(context_data)
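The comprehension-based `QueryState.serialize` keeps the same output shape as the old loops: node dicts keyed by an integer id, and edges referencing those ids. A runnable miniature of the pattern, using plain strings as stand-ins for `DriftAction` nodes:

```python
# Miniature of the node/edge serialization pattern; nodes here are plain strings.
import networkx as nx

graph = nx.MultiDiGraph()
graph.add_edge("root query", "follow-up A", weight=0.8)
graph.add_edge("root query", "follow-up B")  # weight falls back to 1.0 below

# Map each node to a stable integer id, as serialize() does.
node_to_id = {node: idx for idx, node in enumerate(graph.nodes())}

nodes = [
    {"query": node, "id": node_to_id[node], **graph.nodes[node]}
    for node in graph.nodes()
]
edges = [
    {
        "source": node_to_id[u],
        "target": node_to_id[v],
        "weight": edge_data.get("weight", 1.0),
    }
    for u, v, edge_data in graph.edges(data=True)
]
print(nodes)
print(edges)  # integer ids keep the edge list compact and JSON-friendly
```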