From 8cafb557697da9fc48ea8f55179ea7bc236d2dc9 Mon Sep 17 00:00:00 2001
From: Alonso Guevara
Date: Mon, 14 Oct 2024 19:52:50 -0600
Subject: [PATCH] Ruff

---
 graphrag/query/structured_search/drift_search/action.py | 8 ++++----
 graphrag/query/structured_search/drift_search/primer.py | 2 +-
 graphrag/query/structured_search/drift_search/state.py  | 7 ++-----
 3 files changed, 7 insertions(+), 10 deletions(-)

diff --git a/graphrag/query/structured_search/drift_search/action.py b/graphrag/query/structured_search/drift_search/action.py
index 4514d4de..eb820603 100644
--- a/graphrag/query/structured_search/drift_search/action.py
+++ b/graphrag/query/structured_search/drift_search/action.py
@@ -86,9 +86,9 @@ class DriftAction:
             generated_tokens = 0
         else:
             generated_tokens = num_tokens(self.answer, search_engine.token_encoder)
-        self.metadata.update({
-            "token_ct": search_result.prompt_tokens + generated_tokens
-        })
+        self.metadata.update(
+            {"token_ct": search_result.prompt_tokens + generated_tokens}
+        )
 
         self.follow_ups = response.pop("follow_up_queries", [])
         if not self.follow_ups:
@@ -203,7 +203,7 @@ class DriftAction:
                 raise ValueError(error_message)
             except json.JSONDecodeError as e:
                 error_message = f"Failed to parse response string: {e}. Parsed response must be a dictionary."
-                raise ValueError(error_message)
+                raise ValueError(error_message) from e
 
         error_message = f"Unsupported response type: {type(response).__name__}. Expected a dictionary or JSON string."
         raise ValueError(error_message)
diff --git a/graphrag/query/structured_search/drift_search/primer.py b/graphrag/query/structured_search/drift_search/primer.py
index 53b28bac..1a3d7b27 100644
--- a/graphrag/query/structured_search/drift_search/primer.py
+++ b/graphrag/query/structured_search/drift_search/primer.py
@@ -170,7 +170,7 @@ class DRIFTPrimer:
         return SearchResult(
             response=[response for response, _ in results_with_tokens],
             context_data={"top_k_reports": top_k_reports},
-            context_text=top_k_reports.to_json(),
+            context_text=top_k_reports.to_json() or "",
             completion_time=completion_time,
             llm_calls=len(results_with_tokens),
             prompt_tokens=sum(tokens for _, tokens in results_with_tokens),
diff --git a/graphrag/query/structured_search/drift_search/state.py b/graphrag/query/structured_search/drift_search/state.py
index 21d594c4..17733f44 100644
--- a/graphrag/query/structured_search/drift_search/state.py
+++ b/graphrag/query/structured_search/drift_search/state.py
@@ -5,11 +5,8 @@
 
 import logging
 import random
-from typing import TYPE_CHECKING, Any
-
-if TYPE_CHECKING:
-    from typing import Callable
-
+from collections.abc import Callable
+from typing import Any
 
 import networkx as nx
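
Note (reviewer commentary, not part of the patch): beyond formatting, the substantive fixes are chaining the re-raised ValueError with `from e` (which satisfies Ruff's B904 raise-without-from-inside-except rule), falling back to `to_json() or ""` so `context_text` is never None, and importing `Callable` from `collections.abc`, its preferred runtime home since Python 3.9, instead of a TYPE_CHECKING-only `typing` import. The sketch below illustrates the effect of the `from e` chaining; `parse_payload` is a hypothetical helper for demonstration, not graphrag's actual DriftAction code.

    import json

    def parse_payload(raw: str) -> dict:
        """Parse a JSON string, mirroring the error-handling pattern in the patch."""
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError as e:
            msg = f"Failed to parse response string: {e}. Parsed response must be a dictionary."
            # `from e` records the JSONDecodeError as __cause__, so the traceback reads
            # "The above exception was the direct cause of the following exception"
            # rather than the implicit "During handling of the above exception,
            # another exception occurred", making the causal link explicit.
            raise ValueError(msg) from e
        if not isinstance(parsed, dict):
            msg = "Parsed response must be a dictionary."
            raise ValueError(msg)
        return parsed

For example, calling parse_payload("not json") raises a ValueError whose __cause__ is the underlying JSONDecodeError, so callers and logs still see exactly where parsing failed.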