Mirror of https://github.com/microsoft/graphrag.git (synced 2026-02-08 04:02:29 +08:00)
Some algorithmic refactors
This commit is contained in:
parent 60441985f2
commit 44246e2fd9
@@ -62,3 +62,30 @@ class CommunityReport(Named):
             full_content_embedding=d.get(full_content_embedding_key),
             attributes=d.get(attributes_key),
         )
+
+    def to_dict(
+        self,
+        id_key: str = "id",
+        title_key: str = "title",
+        community_id_key: str = "community_id",
+        short_id_key: str = "short_id",
+        summary_key: str = "summary",
+        full_content_key: str = "full_content",
+        rank_key: str = "rank",
+        summary_embedding_key: str = "summary_embedding",
+        full_content_embedding_key: str = "full_content_embedding",
+        attributes_key: str = "attributes",
+    ) -> dict[str, Any]:
+        """Convert the community report to a dictionary."""
+        return {
+            id_key: self.id,
+            title_key: self.title,
+            community_id_key: self.community_id,
+            short_id_key: self.short_id,
+            summary_key: self.summary,
+            full_content_key: self.full_content,
+            rank_key: self.rank,
+            summary_embedding_key: self.summary_embedding,
+            full_content_embedding_key: self.full_content_embedding,
+            attributes_key: self.attributes,
+        }
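The new `to_dict` mirrors the existing `from_dict`: every key name is a parameter, so callers can serialize to whatever column names a downstream store expects. A self-contained sketch of that key-parameterized pattern, using a stand-in class rather than the real `CommunityReport`:

```python
from dataclasses import dataclass
from typing import Any


# Stand-in class illustrating the pattern added above; not the real model.
@dataclass
class Report:
    id: str
    rank: float

    def to_dict(self, id_key: str = "id", rank_key: str = "rank") -> dict[str, Any]:
        return {id_key: self.id, rank_key: self.rank}


r = Report(id="r-1", rank=1.0)
print(r.to_dict())                       # -> {'id': 'r-1', 'rank': 1.0}
print(r.to_dict(rank_key="importance"))  # -> {'id': 'r-1', 'importance': 1.0}
```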
@@ -146,15 +146,21 @@ class DriftAction:
         DriftAction
             A deserialized instance of DriftAction.
         """
-        action = cls(data["query"])
+        # Ensure 'query' exists in the data, raise a ValueError if missing
+        query = data.get("query")
+        if query is None:
+            error_message = "Missing 'query' key in serialized data"
+            raise ValueError(error_message)
+
+        # Initialize the DriftAction
+        action = cls(query)
         action.answer = data.get("answer")
         action.score = data.get("score")
         action.metadata = data.get("metadata", {})
-        action.follow_ups = (
-            [cls.deserialize(fu_data) for fu_data in data.get("follow_up_queries", [])]
-            if "follow_ups" in data
-            else []
-        )
-
+        action.follow_ups = [
+            cls.deserialize(fu_data) for fu_data in data.get("follow_up_queries", [])
+        ]
+
         return action

     @classmethod
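Two behavioral fixes ride along in this hunk: a missing `query` now raises a descriptive `ValueError` instead of a bare `KeyError`, and follow-ups are read from `follow_up_queries` unconditionally; the old guard tested `"follow_ups" in data` while reading `"follow_up_queries"`, so serialized follow-ups were silently dropped. A standalone sketch of the validate-then-construct pattern (function name is illustrative):

```python
# Validate first, then construct: callers get a clear error, not a KeyError.
def read_query(data: dict) -> str:
    query = data.get("query")
    if query is None:
        error_message = "Missing 'query' key in serialized data"
        raise ValueError(error_message)
    return query


print(read_query({"query": "who leads the community?"}))  # ok
# read_query({}) raises ValueError with a clear message, not an opaque KeyError
```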
@@ -187,10 +193,22 @@ class DriftAction:
             action.score = response.get("score")
             return action

-        error_message = "Response must be a dictionary"
-        raise ValueError(error_message)
+        # If response is a string, attempt to parse as JSON
+        if isinstance(response, str):
+            try:
+                parsed_response = json.loads(response)
+                if isinstance(parsed_response, dict):
+                    return cls.from_primer_response(query, parsed_response)
+                error_message = "Parsed response must be a dictionary."
+                raise ValueError(error_message)
+            except json.JSONDecodeError as e:
+                error_message = f"Failed to parse response string: {e}. Parsed response must be a dictionary."
+                raise ValueError(error_message)
+
+        error_message = f"Unsupported response type: {type(response).__name__}. Expected a dictionary or JSON string."
+        raise ValueError(error_message)

-    def __hash__(self):
+    def __hash__(self) -> int:
         """
         Allow DriftAction objects to be hashable for use in networkx.MultiDiGraph.

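`from_primer_response` now accepts either a dict or a JSON string, parsing the string once and recursing into the dict path. The same parse-then-dispatch pattern in isolation (function name is illustrative):

```python
import json
from typing import Any


# Accept a dict directly, or a JSON string that decodes to a dict.
def coerce_to_dict(response: Any) -> dict:
    if isinstance(response, dict):
        return response
    if isinstance(response, str):
        try:
            parsed = json.loads(response)
        except json.JSONDecodeError as e:
            raise ValueError(f"Failed to parse response string: {e}") from e
        if isinstance(parsed, dict):
            return parsed
        raise ValueError("Parsed response must be a dictionary.")
    raise ValueError(f"Unsupported response type: {type(response).__name__}")


assert coerce_to_dict('{"score": 0.9}') == {"score": 0.9}
```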
@@ -203,7 +221,7 @@ class DriftAction:
         """
         return hash(self.query)

-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
         """
         Check equality based on the query string.

@@ -215,4 +233,6 @@ class DriftAction:
         bool
             True if the other object is a DriftAction with the same query, False otherwise.
         """
-        return isinstance(other, DriftAction) and self.query == other.query
+        if not isinstance(other, DriftAction):
+            return False
+        return self.query == other.query
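The typed `__eq__` with its `isinstance` guard preserves the hash/equality contract that lets `DriftAction` objects serve as `networkx` node keys: equal objects must hash equal, and comparison against a foreign type must return `False` rather than raise. A minimal sketch of that contract on a query-keyed node type:

```python
# Minimal query-keyed node type obeying the hash/eq contract.
class Node:
    def __init__(self, query: str) -> None:
        self.query = query

    def __hash__(self) -> int:
        return hash(self.query)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Node):
            return False
        return self.query == other.query


a, b = Node("q1"), Node("q1")
assert a == b and hash(a) == hash(b)  # contract holds
assert len({a, b}) == 1               # deduplicates in sets and graph node dicts
assert a != "q1"                      # foreign types compare unequal, no raise
```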
@@ -111,12 +111,16 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
         ------
         ValueError: If some reports are missing full content or full content embeddings.
         """
-        report_df = pd.DataFrame([report.__dict__ for report in reports])
+        report_df = pd.DataFrame([report.to_dict() for report in reports])
         missing_content_error = "Some reports are missing full content."
         missing_embedding_error = "Some reports are missing full content embeddings."

-        if report_df["full_content"].isna().sum() > 0:
+        if (
+            "full_content" not in report_df.columns
+            or report_df["full_content"].isna().sum() > 0
+        ):
             raise ValueError(missing_content_error)

         if (
-            report_df["full_content_embedding"].isna().sum() > 0
+            "full_content_embedding" not in report_df.columns
+            or report_df["full_content_embedding"].isna().sum() > 0
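Checking column membership before `isna()` matters: on a frame with no `full_content` column at all, `report_df["full_content"]` raises `KeyError` before the NaN check ever runs, while the short-circuiting `or` guarantees the caller sees the intended `ValueError`. A small sketch with illustrative data:

```python
import pandas as pd

df = pd.DataFrame([{"title": "t1"}])  # no "full_content" column at all

try:
    # Without the membership guard, df["full_content"] would raise KeyError.
    if "full_content" not in df.columns or df["full_content"].isna().sum() > 0:
        raise ValueError("Some reports are missing full content.")
except ValueError as err:
    print(err)  # -> Some reports are missing full content.
```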
@@ -141,7 +145,9 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
         bool: True if embeddings match, otherwise False.
         """
         return (
-            isinstance(query_embedding, type(embedding))
+            query_embedding is not None
+            and embedding is not None
+            and isinstance(query_embedding, type(embedding))
             and len(query_embedding) == len(embedding)
             and isinstance(query_embedding[0], type(embedding[0]))
         )
@@ -182,21 +188,27 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):

         report_df = self.convert_reports_to_df(self.reports)

-        if self.check_query_doc_encodings(
+        # Check compatibility between query embedding and document embeddings
+        if not self.check_query_doc_encodings(
             query_embedding, report_df["full_content_embedding"].iloc[0]
         ):
-            report_df["similarity"] = report_df["full_content_embedding"].apply(
-                lambda x: np.dot(x, query_embedding)
-                / (np.linalg.norm(x) * np.linalg.norm(query_embedding))
-            )
-            top_k = report_df.sort_values("similarity", ascending=False).head(
-                self.config.drift_k_followups
-            )
-        else:
-            incompatible_embeddings_error = (
+            error_message = (
                 "Query and document embeddings are not compatible. "
                 "Please ensure that the embeddings are of the same type and length."
             )
-            raise ValueError(incompatible_embeddings_error)
+            raise ValueError(error_message)
+
+        # Vectorized cosine similarity computation
+        query_norm = np.linalg.norm(query_embedding)
+        document_norms = np.linalg.norm(
+            report_df["full_content_embedding"].to_list(), axis=1
+        )
+        dot_products = np.dot(
+            np.vstack(report_df["full_content_embedding"].to_list()), query_embedding
+        )
+        report_df["similarity"] = dot_products / (document_norms * query_norm)
+
+        # Sort by similarity and select top-k
+        top_k = report_df.nlargest(self.config.drift_k_followups, "similarity")

         return top_k.loc[:, ["short_id", "community_id", "full_content"]]
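This is the core refactor of the hunk: a per-row `.apply` (one Python-level lambda call per report) becomes a single stacked matrix product. `np.dot` of an (n, d) matrix with a (d,) query vector yields all n dot products at once, and `nlargest` does a partial selection instead of sorting the whole frame. A self-contained sketch with synthetic embeddings:

```python
import numpy as np
import pandas as pd

# Synthetic embeddings; values are illustrative only.
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "short_id": ["0", "1", "2", "3"],
    "full_content_embedding": [rng.normal(size=8) for _ in range(4)],
})
query_embedding = rng.normal(size=8)

# One vectorized pass: norms of all rows, then all dot products at once.
query_norm = np.linalg.norm(query_embedding)
document_norms = np.linalg.norm(df["full_content_embedding"].to_list(), axis=1)
dot_products = np.dot(np.vstack(df["full_content_embedding"].to_list()), query_embedding)
df["similarity"] = dot_products / (document_norms * query_norm)

top_k = df.nlargest(2, "similarity")  # partial selection, no full sort needed
print(top_k[["short_id", "similarity"]])
```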
@@ -8,6 +8,7 @@ import logging
 import secrets
 import time

+import numpy as np
 import pandas as pd
 import tiktoken
 from tqdm.asyncio import tqdm_asyncio
@@ -159,19 +160,19 @@ class DRIFTPrimer:
         -------
         SearchResult: The search result containing the response and context data.
         """
-        start_time = time.time()
+        start_time = time.perf_counter()
         report_folds = self.split_reports(top_k_reports)
         tasks = [self.decompose_query(query, fold) for fold in report_folds]
         results_with_tokens = await tqdm_asyncio.gather(*tasks)

-        completion_time = time.time() - start_time
+        completion_time = time.perf_counter() - start_time

         return SearchResult(
             response=[response for response, _ in results_with_tokens],
             context_data={"top_k_reports": top_k_reports},
-            context_text=str(top_k_reports),
+            context_text=top_k_reports.to_json(),
             completion_time=completion_time,
-            llm_calls=2,
+            llm_calls=len(results_with_tokens),
             prompt_tokens=sum(tokens for _, tokens in results_with_tokens),
         )
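`time.perf_counter()` is monotonic and high-resolution, so the elapsed-time arithmetic cannot go negative or jump if the system clock is adjusted mid-search, which `time.time()` permits. The other two changes make the result more honest: `llm_calls` now reports the actual number of fold queries instead of a hard-coded 2, and `context_text` serializes the reports as JSON rather than relying on `str()` of a DataFrame. A tiny timing sketch:

```python
import time

start_time = time.perf_counter()
time.sleep(0.05)  # stand-in for the awaited decompose_query calls
completion_time = time.perf_counter() - start_time
print(f"{completion_time:.3f}s elapsed")  # immune to wall-clock adjustments
```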
@@ -186,17 +187,7 @@ class DRIFTPrimer:
         -------
         list[pd.DataFrame]: List of report folds.
         """
-        folds = []
-        num_reports = len(reports)
         primer_folds = self.config.primer_folds or 1  # Ensure at least one fold
-
-        for i in range(primer_folds):
-            start_idx = i * num_reports // primer_folds
-            end_idx = (
-                num_reports
-                if i == primer_folds - 1
-                else (i + 1) * num_reports // primer_folds
-            )
-            fold = reports.iloc[start_idx:end_idx]
-            folds.append(fold)
-        return folds
+        if primer_folds == 1:
+            return [reports]
+        return [pd.DataFrame(fold) for fold in np.array_split(reports, primer_folds)]
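`np.array_split` accepts a fold count that does not divide the row count evenly (plain `np.split` would raise), which is exactly what the hand-rolled start/end index arithmetic was emulating. A sketch mirroring the new implementation; splitting a DataFrame this way works, but newer pandas versions may emit a deprecation warning for it, so treat this as illustrative:

```python
import numpy as np
import pandas as pd

reports = pd.DataFrame({"report": [f"r{i}" for i in range(7)]})

# 7 rows into 3 folds: sizes differ by at most one, no index edge cases.
folds = [pd.DataFrame(fold) for fold in np.array_split(reports, 3)]
print([len(fold) for fold in folds])  # -> [3, 2, 2]
```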
@@ -124,13 +124,9 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
             error_msg = "No intermediate answers found in primer response. Ensure that the primer response includes intermediate answers."
             raise RuntimeError(error_msg)

-        intermediate_answer = "\n\n".join(
-            [
-                i["intermediate_answer"]
-                for i in response
-                if "intermediate_answer" in i
-            ]
-        )
+        intermediate_answer = "\n\n".join([
+            i["intermediate_answer"] for i in response if "intermediate_answer" in i
+        ])

         follow_ups = [fu for i in response for fu in i.get("follow_up_queries", [])]
         if len(follow_ups) == 0:
@@ -193,7 +189,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
             error_msg = "DRIFT Search query cannot be empty."
             raise ValueError(error_msg)

-        start_time = time.time()
+        start_time = time.perf_counter()
         primer_token_ct = 0
         context_token_ct = 0

@@ -233,7 +229,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
                 self.query_state.add_all_follow_ups(action, action.follow_ups)
             epochs += 1

-        t_elapsed = time.time() - start_time
+        t_elapsed = time.perf_counter() - start_time

         # Calculate token usage
         total_tokens = (
@@ -5,7 +5,11 @@

 import logging
 import random
-from typing import Any
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from typing import Callable
+

 import networkx as nx
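The `TYPE_CHECKING` guard makes `Callable` visible to static type checkers without importing it at runtime, avoiding import cost and cycle risk; the annotation only needs to resolve lazily. A sketch of the pattern (shown with `collections.abc.Callable`, the modern home of the type; the commit imports it from `typing`, which also works):

```python
from __future__ import annotations  # annotations are not evaluated at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Callable  # seen by mypy/pyright only


def rank(scorer: Callable[[str], float] | None = None) -> float:
    """Score with the callable if given, else fall back to a default."""
    return scorer("query") if scorer else float("-inf")


print(rank(len))  # any str -> number callable satisfies the annotation
```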
@@ -55,7 +59,9 @@ class QueryState:
         """Find all unanswered actions in the graph."""
         return [node for node in self.graph.nodes if not node.is_complete]

-    def rank_incomplete_actions(self, scorer: Any | None = None) -> list[DriftAction]:
+    def rank_incomplete_actions(
+        self, scorer: Callable[[DriftAction], float] | None = None
+    ) -> list[DriftAction]:
         """Rank all unanswered actions in the graph if scorer available."""
         unanswered = self.find_incomplete_actions()
         if scorer:
@@ -63,9 +69,9 @@ class QueryState:
                 node.compute_score(scorer)
         return sorted(
             unanswered,
-            key=lambda node: node.score
-            if node.score is not None
-            else float("-inf"),
+            key=lambda node: (
+                node.score if node.score is not None else float("-inf")
+            ),
             reverse=True,
         )
@@ -81,30 +87,31 @@ class QueryState:
         node_to_id = {node: idx for idx, node in enumerate(self.graph.nodes())}

         # Serialize nodes
-        nodes = []
-        for node in self.graph.nodes():
-            node_data = node.serialize(include_follow_ups=False)
-            node_data["id"] = node_to_id[node]
-            node_attributes = self.graph.nodes[node]
-            if node_attributes:
-                node_data.update(node_attributes)
-            nodes.append(node_data)
+        nodes: list[dict[str, Any]] = [
+            {
+                **node.serialize(include_follow_ups=False),
+                "id": node_to_id[node],
+                **self.graph.nodes[node],
+            }
+            for node in self.graph.nodes()
+        ]

         # Serialize edges
-        edges = []
-        for u, v, edge_data in self.graph.edges(data=True):
-            edge_info = {
+        edges: list[dict[str, Any]] = [
+            {
                 "source": node_to_id[u],
                 "target": node_to_id[v],
                 "weight": edge_data.get("weight", 1.0),
             }
-            edges.append(edge_info)
+            for u, v, edge_data in self.graph.edges(data=True)
+        ]

         if include_context:
-            context_data = {}
-            for node in nodes:
-                if node["metadata"].get("context_data") and node.get("query"):
-                    context_data[node["query"]] = node["metadata"]["context_data"]
+            context_data = {
+                node["query"]: node["metadata"]["context_data"]
+                for node in nodes
+                if node["metadata"].get("context_data") and node.get("query")
+            }

             context_text = str(context_data)

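The node comprehension leans on `**` merge order: keys from later mappings win, so attributes stored on the graph node override the serialized action fields on collision, exactly what the old `node_data.update(node_attributes)` did. A short demonstration:

```python
serialized = {"query": "q1", "score": 0.5}  # from node.serialize(...)
graph_attrs = {"score": 0.9}                # from self.graph.nodes[node]

# Later ** unpacking wins, so the graph attribute overrides the field.
node_data = {**serialized, "id": 0, **graph_attrs}
print(node_data)  # -> {'query': 'q1', 'score': 0.9, 'id': 0}
```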