mirror of
https://github.com/microsoft/graphrag.git
synced 2026-02-08 04:02:29 +08:00
Formatting and ruff
This commit is contained in:
parent
5afae11a0e
commit
60441985f2
@ -17,7 +17,6 @@ from graphrag.model import (
|
||||
Relationship,
|
||||
TextUnit,
|
||||
)
|
||||
|
||||
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
|
||||
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
|
||||
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
|
||||
|
||||
@ -86,7 +86,9 @@ class DriftAction:
|
||||
generated_tokens = 0
|
||||
else:
|
||||
generated_tokens = num_tokens(self.answer, search_engine.token_encoder)
|
||||
self.metadata.update({"token_ct": search_result.prompt_tokens + generated_tokens})
|
||||
self.metadata.update({
|
||||
"token_ct": search_result.prompt_tokens + generated_tokens
|
||||
})
|
||||
|
||||
self.follow_ups = response.pop("follow_up_queries", [])
|
||||
if not self.follow_ups:
|
||||
@ -105,7 +107,9 @@ class DriftAction:
|
||||
scorer (Any): The scorer to compute the score.
|
||||
"""
|
||||
score = scorer.compute_score(self.query, self.answer)
|
||||
self.score = score if score is not None else float("-inf") # Default to -inf for sorting
|
||||
self.score = (
|
||||
score if score is not None else float("-inf")
|
||||
) # Default to -inf for sorting
|
||||
|
||||
def serialize(self, include_follow_ups: bool = True) -> dict[str, Any]:
|
||||
"""
|
||||
|
||||
@ -76,7 +76,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
def init_local_context_builder(self) -> LocalSearchMixedContext:
|
||||
"""
|
||||
Initialize the local search mixed context builder.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
LocalSearchMixedContext: Initialized local context.
|
||||
@ -97,7 +97,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
def convert_reports_to_df(reports: list[CommunityReport]) -> pd.DataFrame:
|
||||
"""
|
||||
Convert a list of CommunityReport objects to a pandas DataFrame.
|
||||
|
||||
|
||||
Args
|
||||
----
|
||||
reports : list[CommunityReport]
|
||||
@ -106,7 +106,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame: DataFrame with report data.
|
||||
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError: If some reports are missing full content or full content embeddings.
|
||||
@ -114,7 +114,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
report_df = pd.DataFrame([report.__dict__ for report in reports])
|
||||
missing_content_error = "Some reports are missing full content."
|
||||
missing_embedding_error = "Some reports are missing full content embeddings."
|
||||
|
||||
|
||||
if report_df["full_content"].isna().sum() > 0:
|
||||
raise ValueError(missing_content_error)
|
||||
if (
|
||||
@ -128,14 +128,14 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
def check_query_doc_encodings(query_embedding: Any, embedding: Any) -> bool:
|
||||
"""
|
||||
Check if the embeddings are compatible.
|
||||
|
||||
|
||||
Args
|
||||
----
|
||||
query_embedding : Any
|
||||
Embedding of the query.
|
||||
embedding : Any
|
||||
Embedding to compare against.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool: True if embeddings match, otherwise False.
|
||||
@ -149,7 +149,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
def build_context(self, query: str, **kwargs) -> pd.DataFrame:
|
||||
"""
|
||||
Build DRIFT search context.
|
||||
|
||||
|
||||
Args
|
||||
----
|
||||
query : str
|
||||
@ -158,14 +158,16 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame: Top-k most similar documents.
|
||||
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError: If no community reports are available, or embeddings
|
||||
are incompatible.
|
||||
"""
|
||||
if self.reports is None:
|
||||
missing_reports_error = "No community reports available. Please provide a list of reports."
|
||||
missing_reports_error = (
|
||||
"No community reports available. Please provide a list of reports."
|
||||
)
|
||||
raise ValueError(missing_reports_error)
|
||||
|
||||
query_processor = PrimerQueryProcessor(
|
||||
|
||||
@ -61,13 +61,12 @@ class PrimerQueryProcessor:
|
||||
tuple[str, int]: Expanded query text and the number of tokens used.
|
||||
"""
|
||||
token_ct = 0
|
||||
template = secrets.choice(self.reports).full_content # nosec S311
|
||||
template = secrets.choice(self.reports).full_content # nosec S311
|
||||
|
||||
prompt = f"""Create a hypothetical answer to the following query: {query}\n\n
|
||||
Format it to follow the structure of the template below:\n\n
|
||||
{template}\n"
|
||||
Ensure that the hypothetical answer does not reference new named entities that are not present in the original query."""
|
||||
|
||||
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
@ -193,7 +192,11 @@ class DRIFTPrimer:
|
||||
|
||||
for i in range(primer_folds):
|
||||
start_idx = i * num_reports // primer_folds
|
||||
end_idx = num_reports if i == primer_folds - 1 else (i + 1) * num_reports // primer_folds
|
||||
end_idx = (
|
||||
num_reports
|
||||
if i == primer_folds - 1
|
||||
else (i + 1) * num_reports // primer_folds
|
||||
)
|
||||
fold = reports.iloc[start_idx:end_idx]
|
||||
folds.append(fold)
|
||||
return folds
|
||||
|
||||
@ -124,16 +124,15 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
error_msg = "No intermediate answers found in primer response. Ensure that the primer response includes intermediate answers."
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
intermediate_answer = "\n\n".join([
|
||||
i["intermediate_answer"] for i in response if "intermediate_answer" in i
|
||||
])
|
||||
intermediate_answer = "\n\n".join(
|
||||
[
|
||||
i["intermediate_answer"]
|
||||
for i in response
|
||||
if "intermediate_answer" in i
|
||||
]
|
||||
)
|
||||
|
||||
follow_ups = [
|
||||
fu
|
||||
for i in response
|
||||
if "follow_up_queries" in i and i["follow_up_queries"]
|
||||
for fu in i["follow_up_queries"]
|
||||
]
|
||||
follow_ups = [fu for i in response for fu in i.get("follow_up_queries", [])]
|
||||
if len(follow_ups) == 0:
|
||||
error_msg = "No follow-up queries found in primer response. Ensure that the primer response includes follow-up queries."
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
@ -20,9 +20,7 @@ class QueryState:
|
||||
def __init__(self):
|
||||
self.graph = nx.MultiDiGraph()
|
||||
|
||||
def add_action(
|
||||
self, action: DriftAction, metadata: dict[str, Any] | None = None
|
||||
):
|
||||
def add_action(self, action: DriftAction, metadata: dict[str, Any] | None = None):
|
||||
"""Add an action to the graph with optional metadata."""
|
||||
self.graph.add_node(action, **(metadata or {}))
|
||||
|
||||
@ -64,13 +62,13 @@ class QueryState:
|
||||
for node in unanswered:
|
||||
node.compute_score(scorer)
|
||||
return sorted(
|
||||
unanswered,
|
||||
key=lambda node: node.score
|
||||
if node.score is not None
|
||||
else float("-inf"),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
unanswered,
|
||||
key=lambda node: node.score
|
||||
if node.score is not None
|
||||
else float("-inf"),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
# shuffle the list if no scorer
|
||||
random.shuffle(unanswered)
|
||||
return list(unanswered)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user