Formatting and ruff

This commit is contained in:
Alonso Guevara 2024-10-14 17:10:25 -06:00
parent 5afae11a0e
commit 60441985f2
6 changed files with 39 additions and 34 deletions

View File

@ -17,7 +17,6 @@ from graphrag.model import (
Relationship,
TextUnit,
)
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.oai.embedding import OpenAIEmbedding

View File

@ -86,7 +86,9 @@ class DriftAction:
generated_tokens = 0
else:
generated_tokens = num_tokens(self.answer, search_engine.token_encoder)
self.metadata.update({"token_ct": search_result.prompt_tokens + generated_tokens})
self.metadata.update({
"token_ct": search_result.prompt_tokens + generated_tokens
})
self.follow_ups = response.pop("follow_up_queries", [])
if not self.follow_ups:
@ -105,7 +107,9 @@ class DriftAction:
scorer (Any): The scorer to compute the score.
"""
score = scorer.compute_score(self.query, self.answer)
self.score = score if score is not None else float("-inf") # Default to -inf for sorting
self.score = (
score if score is not None else float("-inf")
) # Default to -inf for sorting
def serialize(self, include_follow_ups: bool = True) -> dict[str, Any]:
"""

View File

@ -76,7 +76,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
def init_local_context_builder(self) -> LocalSearchMixedContext:
"""
Initialize the local search mixed context builder.
Returns
-------
LocalSearchMixedContext: Initialized local context.
@ -97,7 +97,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
def convert_reports_to_df(reports: list[CommunityReport]) -> pd.DataFrame:
"""
Convert a list of CommunityReport objects to a pandas DataFrame.
Args
----
reports : list[CommunityReport]
@ -106,7 +106,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
Returns
-------
pd.DataFrame: DataFrame with report data.
Raises
------
ValueError: If some reports are missing full content or full content embeddings.
@ -114,7 +114,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
report_df = pd.DataFrame([report.__dict__ for report in reports])
missing_content_error = "Some reports are missing full content."
missing_embedding_error = "Some reports are missing full content embeddings."
if report_df["full_content"].isna().sum() > 0:
raise ValueError(missing_content_error)
if (
@ -128,14 +128,14 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
def check_query_doc_encodings(query_embedding: Any, embedding: Any) -> bool:
"""
Check if the embeddings are compatible.
Args
----
query_embedding : Any
Embedding of the query.
embedding : Any
Embedding to compare against.
Returns
-------
bool: True if embeddings match, otherwise False.
@ -149,7 +149,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
def build_context(self, query: str, **kwargs) -> pd.DataFrame:
"""
Build DRIFT search context.
Args
----
query : str
@ -158,14 +158,16 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
Returns
-------
pd.DataFrame: Top-k most similar documents.
Raises
------
ValueError: If no community reports are available, or embeddings
are incompatible.
"""
if self.reports is None:
missing_reports_error = "No community reports available. Please provide a list of reports."
missing_reports_error = (
"No community reports available. Please provide a list of reports."
)
raise ValueError(missing_reports_error)
query_processor = PrimerQueryProcessor(

View File

@ -61,13 +61,12 @@ class PrimerQueryProcessor:
tuple[str, int]: Expanded query text and the number of tokens used.
"""
token_ct = 0
template = secrets.choice(self.reports).full_content # nosec S311
template = secrets.choice(self.reports).full_content # nosec S311
prompt = f"""Create a hypothetical answer to the following query: {query}\n\n
Format it to follow the structure of the template below:\n\n
{template}\n"
Ensure that the hypothetical answer does not reference new named entities that are not present in the original query."""
messages = [{"role": "user", "content": prompt}]
@ -193,7 +192,11 @@ class DRIFTPrimer:
for i in range(primer_folds):
start_idx = i * num_reports // primer_folds
end_idx = num_reports if i == primer_folds - 1 else (i + 1) * num_reports // primer_folds
end_idx = (
num_reports
if i == primer_folds - 1
else (i + 1) * num_reports // primer_folds
)
fold = reports.iloc[start_idx:end_idx]
folds.append(fold)
return folds

View File

@ -124,16 +124,15 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
error_msg = "No intermediate answers found in primer response. Ensure that the primer response includes intermediate answers."
raise RuntimeError(error_msg)
intermediate_answer = "\n\n".join([
i["intermediate_answer"] for i in response if "intermediate_answer" in i
])
intermediate_answer = "\n\n".join(
[
i["intermediate_answer"]
for i in response
if "intermediate_answer" in i
]
)
follow_ups = [
fu
for i in response
if "follow_up_queries" in i and i["follow_up_queries"]
for fu in i["follow_up_queries"]
]
follow_ups = [fu for i in response for fu in i.get("follow_up_queries", [])]
if len(follow_ups) == 0:
error_msg = "No follow-up queries found in primer response. Ensure that the primer response includes follow-up queries."
raise RuntimeError(error_msg)

View File

@ -20,9 +20,7 @@ class QueryState:
def __init__(self):
self.graph = nx.MultiDiGraph()
def add_action(
self, action: DriftAction, metadata: dict[str, Any] | None = None
):
def add_action(self, action: DriftAction, metadata: dict[str, Any] | None = None):
"""Add an action to the graph with optional metadata."""
self.graph.add_node(action, **(metadata or {}))
@ -64,13 +62,13 @@ class QueryState:
for node in unanswered:
node.compute_score(scorer)
return sorted(
unanswered,
key=lambda node: node.score
if node.score is not None
else float("-inf"),
reverse=True,
)
unanswered,
key=lambda node: node.score
if node.score is not None
else float("-inf"),
reverse=True,
)
# shuffle the list if no scorer
random.shuffle(unanswered)
return list(unanswered)