Load query from blob (#1095)

* Moved query loading from file into a helper function

* Added loading of parquet files from blob storage to the helper (a standalone sketch of the idea follows the commit metadata below)

* Resolved adlfs async error

* Debugging cleanup and small fixes

* Added connection string support

* Semversioner and ruff fixes

* Completed testing for merge with main

* More ruff changes

* Fixed unbound-variable warnings

* Rewrote the function to use the storage utils

* Removed unused variables

---------

Co-authored-by: Kenny Zhang <zhangken@microsoft.com>
Committed by KennyZhang1 on 2024-09-05 18:17:22 -04:00 (merged via GitHub)
parent 044516f538 · commit 27c5468a8b
2 changed files with 76 additions and 27 deletions
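
For context, "querying from blob" here means the query CLI no longer assumes the pipeline's output parquet files sit in a local data directory; reads now go through the storage abstraction so they also work against Azure Blob Storage. Below is a minimal standalone sketch of the underlying mechanism, not code from this PR: the container, path, and connection string are hypothetical placeholders, and it assumes pandas plus adlfs are installed so fsspec can resolve abfs:// URLs.

import pandas as pd

# Hypothetical connection string and blob path, for illustration only.
CONNECTION_STRING = "DefaultEndpointsProtocol=https;AccountName=myaccount;AccountKey=...;EndpointSuffix=core.windows.net"

# pandas hands abfs:// URLs to fsspec, which adlfs serves via
# AzureBlobFileSystem; storage_options is forwarded to that filesystem.
final_nodes = pd.read_parquet(
    "abfs://my-container/output/create_final_nodes.parquet",
    storage_options={"connection_string": CONNECTION_STRING},
)
print(final_nodes.shape)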


@ -0,0 +1,4 @@
{
"type": "patch",
"description": "add querying from azure blob storage"
}


@@ -9,8 +9,14 @@ from pathlib import Path
 
 import pandas as pd
 
-from graphrag.config import load_config, resolve_path
+from graphrag.config import (
+    GraphRagConfig,
+    load_config,
+    resolve_path,
+)
+from graphrag.index.create_pipeline_config import create_pipeline_config
 from graphrag.index.progress import PrintProgressReporter
+from graphrag.utils.storage import _create_storage, _load_table_from_storage
 
 from . import api
 
@@ -36,17 +42,21 @@ def run_global_search(
     if data_dir:
         config.storage.base_dir = str(resolve_path(data_dir, root))
 
-    data_path = Path(config.storage.base_dir).resolve()
-    final_nodes: pd.DataFrame = pd.read_parquet(
-        data_path / "create_final_nodes.parquet"
-    )
-    final_entities: pd.DataFrame = pd.read_parquet(
-        data_path / "create_final_entities.parquet"
-    )
-    final_community_reports: pd.DataFrame = pd.read_parquet(
-        data_path / "create_final_community_reports.parquet"
+    dataframe_dict = _resolve_parquet_files(
+        root_dir=root_dir,
+        config=config,
+        parquet_list=[
+            "create_final_nodes.parquet",
+            "create_final_entities.parquet",
+            "create_final_community_reports.parquet",
+        ],
+        optional_list=[],
     )
+    final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
+    final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
+    final_community_reports: pd.DataFrame = dataframe_dict[
+        "create_final_community_reports"
+    ]
-
 
     # call the Query API
     if streaming:
@@ -112,23 +122,26 @@ def run_local_search(
     if data_dir:
         config.storage.base_dir = str(resolve_path(data_dir, root))
 
-    data_path = Path(config.storage.base_dir).resolve()
-    final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet")
-    final_community_reports = pd.read_parquet(
-        data_path / "create_final_community_reports.parquet"
-    )
-    final_text_units = pd.read_parquet(data_path / "create_final_text_units.parquet")
-    final_relationships = pd.read_parquet(
-        data_path / "create_final_relationships.parquet"
-    )
-    final_entities = pd.read_parquet(data_path / "create_final_entities.parquet")
-    final_covariates_path = data_path / "create_final_covariates.parquet"
-    final_covariates = (
-        pd.read_parquet(final_covariates_path)
-        if final_covariates_path.exists()
-        else None
+    dataframe_dict = _resolve_parquet_files(
+        root_dir=root_dir,
+        config=config,
+        parquet_list=[
+            "create_final_nodes.parquet",
+            "create_final_community_reports.parquet",
+            "create_final_text_units.parquet",
+            "create_final_relationships.parquet",
+            "create_final_entities.parquet",
+        ],
+        optional_list=["create_final_covariates.parquet"],
     )
+    final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
+    final_community_reports: pd.DataFrame = dataframe_dict[
+        "create_final_community_reports"
+    ]
+    final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]
+    final_relationships: pd.DataFrame = dataframe_dict["create_final_relationships"]
+    final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
+    final_covariates: pd.DataFrame | None = dataframe_dict["create_final_covariates"]
-
 
     # call the Query API
     if streaming:
@@ -179,3 +192,35 @@ def run_local_search(
     # NOTE: we return the response and context data here purely as a complete demonstration of the API.
     # External users should use the API directly to get the response and context data.
     return response, context_data
+
+
+def _resolve_parquet_files(
+    root_dir: str,
+    config: GraphRagConfig,
+    parquet_list: list[str],
+    optional_list: list[str],
+) -> dict[str, pd.DataFrame]:
+    """Read parquet files to a dataframe dict."""
+    dataframe_dict = {}
+    pipeline_config = create_pipeline_config(config)
+    storage_obj = _create_storage(root_dir=root_dir, config=pipeline_config.storage)
+    for parquet_file in parquet_list:
+        df_key = parquet_file.split(".")[0]
+        df_value = asyncio.run(
+            _load_table_from_storage(name=parquet_file, storage=storage_obj)
+        )
+        dataframe_dict[df_key] = df_value
+
+    # for optional parquet files, set the dict entry to None instead of erroring out if it does not exist
+    for optional_file in optional_list:
+        file_exists = asyncio.run(storage_obj.has(optional_file))
+        df_key = optional_file.split(".")[0]
+        if file_exists:
+            df_value = asyncio.run(
+                _load_table_from_storage(name=optional_file, storage=storage_obj)
+            )
+            dataframe_dict[df_key] = df_value
+        else:
+            dataframe_dict[df_key] = None
+
+    return dataframe_dict
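
As a usage sketch of the new helper (hedged: the project root below is hypothetical, and the exact load_config call is assumed from the import in this diff), note that the keys of the returned dict follow the filename-minus-extension convention:

from pathlib import Path

# Hypothetical project root containing settings.yaml.
root = Path("./ragtest").resolve()
config = load_config(root)

dataframe_dict = _resolve_parquet_files(
    root_dir=str(root),
    config=config,
    parquet_list=["create_final_nodes.parquet"],
    optional_list=["create_final_covariates.parquet"],
)
final_nodes = dataframe_dict["create_final_nodes"]  # always a DataFrame
final_covariates = dataframe_dict["create_final_covariates"]  # DataFrame or None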