mirror of https://github.com/microsoft/graphrag.git, synced 2026-01-14 00:57:23 +08:00

partially fixed merge conflicts

This commit is contained in:
  parent 82548e11f8
  commit 9ca67643b4

46 .semversioner/0.9.0.json Normal file
@@ -0,0 +1,46 @@
{
    "changes": [
        {
            "description": "Refactor graph creation.",
            "type": "minor"
        },
        {
            "description": "Dependency updates",
            "type": "patch"
        },
        {
            "description": "Fix Global Search with dynamic Community selection bug",
            "type": "patch"
        },
        {
            "description": "Fix question gen.",
            "type": "patch"
        },
        {
            "description": "Optimize Final Community Reports calculation and stabilize cache",
            "type": "patch"
        },
        {
            "description": "miscellaneous code cleanup and minor changes for better alignment of style across the codebase.",
            "type": "patch"
        },
        {
            "description": "replace llm package with fnllm",
            "type": "patch"
        },
        {
            "description": "replaced md5 hash with sha256",
            "type": "patch"
        },
        {
            "description": "replaced md5 hash with sha512",
            "type": "patch"
        },
        {
            "description": "update API and add a demonstration notebook",
            "type": "patch"
        }
    ],
    "created_at": "2024-12-06T20:12:30+00:00",
    "version": "0.9.0"
}

26 .semversioner/1.0.0.json Normal file
@@ -0,0 +1,26 @@
{
    "changes": [
        {
            "description": "Add Parent id to communities data model",
            "type": "patch"
        },
        {
            "description": "Add migration notebook.",
            "type": "patch"
        },
        {
            "description": "Create separate community workflow, collapse subflows.",
            "type": "patch"
        },
        {
            "description": "Dependency Updates",
            "type": "patch"
        },
        {
            "description": "cleanup and refactor factory classes.",
            "type": "patch"
        }
    ],
    "created_at": "2024-12-11T21:41:49+00:00",
    "version": "1.0.0"
}

@@ -0,0 +1,4 @@
{
    "type": "patch",
    "description": "Respect encoding_model option"
}

@@ -0,0 +1,4 @@
{
    "type": "patch",
    "description": "Fix exception on error callbacks"
}

@@ -0,0 +1,4 @@
{
    "type": "patch",
    "description": "Fix encoding model config parsing"
}

209 docs/examples_notebooks/api_overview.ipynb Normal file
@@ -0,0 +1,209 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (c) 2024 Microsoft Corporation.\n",
    "# Licensed under the MIT License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## API Overview\n",
    "\n",
    "This notebook provides a demonstration of how to interact with graphrag as a library using the API as opposed to the CLI. Note that graphrag's CLI actually connects to the library through this API for all operations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import graphrag.api as api\n",
    "from graphrag.index.typing import PipelineRunResult"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisite\n",
    "As a prerequisite to all API operations, a `GraphRagConfig` object is required. It is the primary means to control the behavior of graphrag and can be instantiated from a `settings.yaml` configuration file.\n",
    "\n",
    "Please refer to the [CLI docs](https://microsoft.github.io/graphrag/cli/#init) for more detailed information on how to generate the `settings.yaml` file.\n",
    "\n",
    "#### Load `settings.yaml` configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import yaml\n",
    "\n",
    "settings = yaml.safe_load(open(\"<project_directory>/settings.yaml\"))  # noqa: PTH123, SIM115"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At this point, you can modify the imported settings to align with your application's requirements. For example, if building a UI application, the application might need to change the input and/or storage destinations dynamically in order to enable users to build and query different indexes."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate a `GraphRagConfig` object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from graphrag.config.create_graphrag_config import create_graphrag_config\n",
    "\n",
    "graphrag_config = create_graphrag_config(\n",
    "    values=settings, root_dir=\"<project_directory>\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Indexing API\n",
    "\n",
    "*Indexing* is the process of ingesting raw text data and constructing a knowledge graph. GraphRAG currently supports plaintext (`.txt`) and `.csv` file formats."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build an index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "index_result: list[PipelineRunResult] = await api.build_index(config=graphrag_config)\n",
    "\n",
    "# index_result is a list of workflows that make up the indexing pipeline that was run\n",
    "for workflow_result in index_result:\n",
    "    status = f\"error\\n{workflow_result.errors}\" if workflow_result.errors else \"success\"\n",
    "    print(f\"Workflow Name: {workflow_result.workflow}\\tStatus: {status}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Query an index\n",
    "\n",
    "To query an index, several index files must first be read into memory and passed to the query API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "final_nodes = pd.read_parquet(\"<project_directory>/output/create_final_nodes.parquet\")\n",
    "final_entities = pd.read_parquet(\n",
    "    \"<project_directory>/output/create_final_entities.parquet\"\n",
    ")\n",
    "final_communities = pd.read_parquet(\n",
    "    \"<project_directory>/output/create_final_communities.parquet\"\n",
    ")\n",
    "final_community_reports = pd.read_parquet(\n",
    "    \"<project_directory>/output/create_final_community_reports.parquet\"\n",
    ")\n",
    "\n",
    "response, context = await api.global_search(\n",
    "    config=graphrag_config,\n",
    "    nodes=final_nodes,\n",
    "    entities=final_entities,\n",
    "    communities=final_communities,\n",
    "    community_reports=final_community_reports,\n",
    "    community_level=2,\n",
    "    dynamic_community_selection=False,\n",
    "    response_type=\"Multiple Paragraphs\",\n",
    "    query=\"Who is Scrooge and what are his main relationships?\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The response object is the official response from graphrag, while the context object holds various metadata regarding the querying process used to obtain the final response."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Digging into the context a bit more provides users with extremely granular information, such as what sources of data (down to the level of text chunks) were ultimately retrieved and used as part of the context sent to the LLM."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pprint import pprint\n",
    "\n",
    "pprint(context)  # noqa: T203"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "graphrag-venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}

263 docs/examples_notebooks/index_migration.ipynb Normal file
@@ -0,0 +1,263 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (c) 2024 Microsoft Corporation.\n",
    "# Licensed under the MIT License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Index Migration\n",
    "\n",
    "This notebook is used to maintain data model parity with older indexes for the latest versions of GraphRAG. If you have a pre-1.0 index and need to migrate without re-running the entire pipeline, you can use this notebook to only update the pieces necessary for alignment.\n",
    "\n",
    "NOTE: we recommend regenerating your settings.yml with the latest version of GraphRAG using `graphrag init`. Copy your LLM settings into it before running this notebook. This ensures your config is aligned with the latest version for the migration. This also ensures that you have a default vector store config, which is now required, or indexing will fail.\n",
    "\n",
    "WARNING: This will overwrite your parquet files; you may want to make a backup!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is the directory that has your settings.yml\n",
    "# NOTE: much older indexes may have been output with a timestamped directory\n",
    "# if this is the case, you will need to make sure the storage.base_dir in settings.yml points to it correctly\n",
    "PROJECT_DIRECTORY = \"<your project directory>\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "from graphrag.config.load_config import load_config\n",
    "from graphrag.config.resolve_path import resolve_paths\n",
    "from graphrag.index.create_pipeline_config import create_pipeline_config\n",
    "from graphrag.storage.factory import create_storage\n",
    "\n",
    "# This first block does some config loading, path resolution, and translation that is normally done by the CLI/API when running a full workflow\n",
    "config = load_config(Path(PROJECT_DIRECTORY))\n",
    "resolve_paths(config)\n",
    "pipeline_config = create_pipeline_config(config)\n",
    "storage = create_storage(pipeline_config.storage)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def remove_columns(df, columns):\n",
    "    \"\"\"Remove columns from a DataFrame, suppressing errors.\"\"\"\n",
    "    df.drop(labels=columns, axis=1, errors=\"ignore\", inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_community_parent(nodes):\n",
    "    \"\"\"Compute the parent community using the node membership as a lookup.\"\"\"\n",
    "    parent_mapping = nodes.loc[:, [\"level\", \"community\", \"title\"]]\n",
    "    nodes = nodes.loc[:, [\"level\", \"community\", \"title\"]]\n",
    "\n",
    "    # Create a parent mapping by adding 1 to the level column\n",
    "    parent_mapping[\"level\"] += 1  # Shift levels for parent relationship\n",
    "    parent_mapping.rename(columns={\"community\": \"parent\"}, inplace=True)\n",
    "\n",
    "    # Merge the parent information back into the base DataFrame\n",
    "    nodes = nodes.merge(parent_mapping, on=[\"level\", \"title\"], how=\"left\")\n",
    "\n",
    "    # Fill missing parents with -1 (default value)\n",
    "    nodes[\"parent\"] = nodes[\"parent\"].fillna(-1).astype(int)\n",
    "\n",
    "    join = (\n",
    "        nodes.groupby([\"community\", \"level\", \"parent\"])\n",
    "        .agg({\"title\": list})\n",
    "        .reset_index()\n",
    "    )\n",
    "    return join[join[\"community\"] > -1].loc[:, [\"community\", \"parent\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [],
   "source": [
    "from uuid import uuid4\n",
    "\n",
    "from graphrag.utils.storage import load_table_from_storage, write_table_to_storage\n",
    "\n",
    "# First we'll go through any parquet files that had model changes and update them\n",
    "# The new data model may have removed excess columns as well, but we will only make the minimal changes required for compatibility\n",
    "\n",
    "final_documents = await load_table_from_storage(\n",
    "    \"create_final_documents.parquet\", storage\n",
    ")\n",
    "final_text_units = await load_table_from_storage(\n",
    "    \"create_final_text_units.parquet\", storage\n",
    ")\n",
    "final_entities = await load_table_from_storage(\"create_final_entities.parquet\", storage)\n",
    "final_nodes = await load_table_from_storage(\"create_final_nodes.parquet\", storage)\n",
    "final_relationships = await load_table_from_storage(\n",
    "    \"create_final_relationships.parquet\", storage\n",
    ")\n",
    "final_communities = await load_table_from_storage(\n",
    "    \"create_final_communities.parquet\", storage\n",
    ")\n",
    "final_community_reports = await load_table_from_storage(\n",
    "    \"create_final_community_reports.parquet\", storage\n",
    ")\n",
    "\n",
    "\n",
    "# Documents renames raw_content for consistency\n",
    "if \"raw_content\" in final_documents.columns:\n",
    "    final_documents.rename(columns={\"raw_content\": \"text\"}, inplace=True)\n",
    "final_documents[\"human_readable_id\"] = final_documents.index + 1\n",
    "\n",
    "# Text units just get a human_readable_id for consistency\n",
    "final_text_units[\"human_readable_id\"] = final_text_units.index + 1\n",
    "\n",
    "# We renamed \"name\" to \"title\" for consistency with the rest of the tables\n",
    "if \"name\" in final_entities.columns:\n",
    "    final_entities.rename(columns={\"name\": \"title\"}, inplace=True)\n",
    "remove_columns(\n",
    "    final_entities, [\"mname_embedding\", \"graph_embedding\", \"description_embedding\"]\n",
    ")\n",
    "\n",
    "# Final nodes uses community for joins, which is now an int everywhere\n",
    "final_nodes[\"community\"] = final_nodes[\"community\"].fillna(-1)\n",
    "final_nodes[\"community\"] = final_nodes[\"community\"].astype(int)\n",
    "remove_columns(\n",
    "    final_nodes,\n",
    "    [\n",
    "        \"type\",\n",
    "        \"description\",\n",
    "        \"source_id\",\n",
    "        \"graph_embedding\",\n",
    "        \"entity_type\",\n",
    "        \"top_level_node_id\",\n",
    "        \"size\",\n",
    "    ],\n",
    ")\n",
    "\n",
    "# Relationships renames \"rank\" to \"combined_degree\" to be clear what the default ranking is\n",
    "if \"rank\" in final_relationships.columns:\n",
    "    final_relationships.rename(columns={\"rank\": \"combined_degree\"}, inplace=True)\n",
    "\n",
    "\n",
    "# Compute the parents for each community, to add to communities and reports\n",
    "parent_df = get_community_parent(final_nodes)\n",
    "\n",
    "# Communities previously used the \"id\" field for the Leiden id, but we've moved this to the community field and use a uuid for id like the others\n",
    "if \"community\" not in final_communities.columns:\n",
    "    final_communities[\"community\"] = final_communities[\"id\"].astype(int)\n",
    "    final_communities[\"human_readable_id\"] = final_communities[\"community\"]\n",
    "    final_communities[\"id\"] = [str(uuid4()) for _ in range(len(final_communities))]\n",
    "if \"parent\" not in final_communities.columns:\n",
    "    final_communities = final_communities.merge(parent_df, on=\"community\", how=\"left\")\n",
    "remove_columns(final_communities, [\"raw_community\"])\n",
    "\n",
    "# We need int for community and the human_readable_id copy for consistency\n",
    "final_community_reports[\"community\"] = final_community_reports[\"community\"].astype(int)\n",
    "final_community_reports[\"human_readable_id\"] = final_community_reports[\"community\"]\n",
    "if \"parent\" not in final_community_reports.columns:\n",
    "    final_community_reports = final_community_reports.merge(\n",
    "        parent_df, on=\"community\", how=\"left\"\n",
    "    )\n",
    "\n",
    "await write_table_to_storage(final_documents, \"create_final_documents.parquet\", storage)\n",
    "await write_table_to_storage(\n",
    "    final_text_units, \"create_final_text_units.parquet\", storage\n",
    ")\n",
    "await write_table_to_storage(final_entities, \"create_final_entities.parquet\", storage)\n",
    "await write_table_to_storage(final_nodes, \"create_final_nodes.parquet\", storage)\n",
    "await write_table_to_storage(\n",
    "    final_relationships, \"create_final_relationships.parquet\", storage\n",
    ")\n",
    "await write_table_to_storage(\n",
    "    final_communities, \"create_final_communities.parquet\", storage\n",
    ")\n",
    "await write_table_to_storage(\n",
    "    final_community_reports, \"create_final_community_reports.parquet\", storage\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datashaper import NoopVerbCallbacks\n",
    "\n",
    "from graphrag.cache.factory import create_cache\n",
    "from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings\n",
    "\n",
    "# We only need to re-run the embeddings workflow, to ensure that embeddings for all required search fields are in place\n",
    "# We'll construct the context and run this function flow directly to avoid everything else\n",
    "\n",
    "workflow = next(\n",
    "    (x for x in pipeline_config.workflows if x.name == \"generate_text_embeddings\"), None\n",
    ")\n",
    "config = workflow.config\n",
    "text_embed = config.get(\"text_embed\", {})\n",
    "embedded_fields = config.get(\"embedded_fields\", {})\n",
    "callbacks = NoopVerbCallbacks()\n",
    "cache = create_cache(pipeline_config.cache, PROJECT_DIRECTORY)\n",
    "\n",
    "await generate_text_embeddings(\n",
    "    final_documents=None,\n",
    "    final_relationships=None,\n",
    "    final_text_units=final_text_units,\n",
    "    final_entities=final_entities,\n",
    "    final_community_reports=final_community_reports,\n",
    "    callbacks=callbacks,\n",
    "    cache=cache,\n",
    "    storage=storage,\n",
    "    text_embed_config=text_embed,\n",
    "    embedded_fields=embedded_fields,\n",
    "    snapshot_embeddings_enabled=False,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}

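Since the migration notebook warns that it overwrites the parquet files in place, a small backup step can be run first. This is an illustrative sketch, not part of the commit; it assumes the default output layout of <project_directory>/output.

import shutil
from pathlib import Path

PROJECT_DIRECTORY = "<your project directory>"  # same placeholder used in the notebook

output_dir = Path(PROJECT_DIRECTORY) / "output"
backup_dir = Path(PROJECT_DIRECTORY) / "output_backup"

# Copy the existing index tables aside before the migration rewrites them.
shutil.copytree(output_dir, backup_dir, dirs_exist_ok=True)
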
4 graphrag/cache/__init__.py vendored Normal file
@@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A package containing cache implementations."""

57 graphrag/cache/factory.py vendored Normal file
@@ -0,0 +1,57 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing create_cache method definition."""

from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar

from graphrag.config.enums import CacheType
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
from graphrag.storage.file_pipeline_storage import FilePipelineStorage

if TYPE_CHECKING:
    from graphrag.cache.pipeline_cache import PipelineCache

from graphrag.cache.json_pipeline_cache import JsonPipelineCache
from graphrag.cache.memory_pipeline_cache import InMemoryCache
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache


class CacheFactory:
    """A factory class for cache implementations.

    Includes a method for users to register a custom cache implementation.
    """

    cache_types: ClassVar[dict[str, type]] = {}

    @classmethod
    def register(cls, cache_type: str, cache: type):
        """Register a custom cache implementation."""
        cls.cache_types[cache_type] = cache

    @classmethod
    def create_cache(
        cls, cache_type: CacheType | str | None, root_dir: str, kwargs: dict
    ) -> PipelineCache:
        """Create or get a cache from the provided type."""
        if not cache_type:
            return NoopPipelineCache()
        match cache_type:
            case CacheType.none:
                return NoopPipelineCache()
            case CacheType.memory:
                return InMemoryCache()
            case CacheType.file:
                return JsonPipelineCache(
                    FilePipelineStorage(root_dir=root_dir).child(kwargs["base_dir"])
                )
            case CacheType.blob:
                return JsonPipelineCache(BlobPipelineStorage(**kwargs))
            case _:
                if cache_type in cls.cache_types:
                    return cls.cache_types[cache_type](**kwargs)
                msg = f"Unknown cache type: {cache_type}"
                raise ValueError(msg)

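A minimal usage sketch of the factory above (illustrative, not part of the commit). The registered type name "shared_memory" is made up for the example; built-in types resolve through the match statement, and custom types through the registered cache_types lookup.

from graphrag.cache.factory import CacheFactory
from graphrag.cache.memory_pipeline_cache import InMemoryCache
from graphrag.config.enums import CacheType

# Register an existing implementation under a custom type name.
CacheFactory.register("shared_memory", InMemoryCache)

# A built-in file cache rooted under <root_dir>/<base_dir>.
file_cache = CacheFactory.create_cache(CacheType.file, root_dir="./project", kwargs={"base_dir": "cache"})

# The custom type falls through to the registry and is constructed from kwargs.
custom_cache = CacheFactory.create_cache("shared_memory", root_dir="./project", kwargs={})
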
65 graphrag/cache/json_pipeline_cache.py vendored Normal file
@@ -0,0 +1,65 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'JsonPipelineCache' model."""

import json
from typing import Any

from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.storage.pipeline_storage import PipelineStorage


class JsonPipelineCache(PipelineCache):
    """File pipeline cache class definition."""

    _storage: PipelineStorage
    _encoding: str

    def __init__(self, storage: PipelineStorage, encoding="utf-8"):
        """Init method definition."""
        self._storage = storage
        self._encoding = encoding

    async def get(self, key: str) -> str | None:
        """Get method definition."""
        if await self.has(key):
            try:
                data = await self._storage.get(key, encoding=self._encoding)
                data = json.loads(data)
            except UnicodeDecodeError:
                await self._storage.delete(key)
                return None
            except json.decoder.JSONDecodeError:
                await self._storage.delete(key)
                return None
            else:
                return data.get("result")

        return None

    async def set(self, key: str, value: Any, debug_data: dict | None = None) -> None:
        """Set method definition."""
        if value is None:
            return
        data = {"result": value, **(debug_data or {})}
        await self._storage.set(
            key, json.dumps(data, ensure_ascii=False), encoding=self._encoding
        )

    async def has(self, key: str) -> bool:
        """Has method definition."""
        return await self._storage.has(key)

    async def delete(self, key: str) -> None:
        """Delete method definition."""
        if await self.has(key):
            await self._storage.delete(key)

    async def clear(self) -> None:
        """Clear method definition."""
        await self._storage.clear()

    def child(self, name: str) -> "JsonPipelineCache":
        """Child method definition."""
        return JsonPipelineCache(self._storage.child(name), encoding=self._encoding)

78 graphrag/cache/memory_pipeline_cache.py vendored Normal file
@@ -0,0 +1,78 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'InMemoryCache' model."""

from typing import Any

from graphrag.cache.pipeline_cache import PipelineCache


class InMemoryCache(PipelineCache):
    """In memory cache class definition."""

    _cache: dict[str, Any]
    _name: str

    def __init__(self, name: str | None = None):
        """Init method definition."""
        self._cache = {}
        self._name = name or ""

    async def get(self, key: str) -> Any:
        """Get the value for the given key.

        Args:
            - key - The key to get the value for.
            - as_bytes - Whether or not to return the value as bytes.

        Returns
        -------
            - output - The value for the given key.
        """
        key = self._create_cache_key(key)
        return self._cache.get(key)

    async def set(self, key: str, value: Any, debug_data: dict | None = None) -> None:
        """Set the value for the given key.

        Args:
            - key - The key to set the value for.
            - value - The value to set.
        """
        key = self._create_cache_key(key)
        self._cache[key] = value

    async def has(self, key: str) -> bool:
        """Return True if the given key exists in the storage.

        Args:
            - key - The key to check for.

        Returns
        -------
            - output - True if the key exists in the storage, False otherwise.
        """
        key = self._create_cache_key(key)
        return key in self._cache

    async def delete(self, key: str) -> None:
        """Delete the given key from the storage.

        Args:
            - key - The key to delete.
        """
        key = self._create_cache_key(key)
        del self._cache[key]

    async def clear(self) -> None:
        """Clear the storage."""
        self._cache.clear()

    def child(self, name: str) -> PipelineCache:
        """Create a sub cache with the given name."""
        return InMemoryCache(name)

    def _create_cache_key(self, key: str) -> str:
        """Create a cache key for the given key."""
        return f"{self._name}{key}"

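An illustrative snippet (not part of the commit) showing the async cache contract in use; the key-prefixing and child() behaviour follow the class above.

import asyncio

from graphrag.cache.memory_pipeline_cache import InMemoryCache


async def main() -> None:
    cache = InMemoryCache(name="run1-")
    await cache.set("entity:42", {"description": "example"})
    print(await cache.has("entity:42"))  # True
    print(await cache.get("entity:42"))  # {'description': 'example'}

    # child() creates a fresh cache keyed by the new name, so earlier entries are not visible.
    child = cache.child("chunking-")
    print(await child.has("entity:42"))  # False


asyncio.run(main())
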
65 graphrag/cache/noop_pipeline_cache.py vendored Normal file
@@ -0,0 +1,65 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Module containing the NoopPipelineCache implementation."""

from typing import Any

from graphrag.cache.pipeline_cache import PipelineCache


class NoopPipelineCache(PipelineCache):
    """A no-op implementation of the pipeline cache, usually useful for testing."""

    async def get(self, key: str) -> Any:
        """Get the value for the given key.

        Args:
            - key - The key to get the value for.
            - as_bytes - Whether or not to return the value as bytes.

        Returns
        -------
            - output - The value for the given key.
        """
        return None

    async def set(
        self, key: str, value: str | bytes | None, debug_data: dict | None = None
    ) -> None:
        """Set the value for the given key.

        Args:
            - key - The key to set the value for.
            - value - The value to set.
        """

    async def has(self, key: str) -> bool:
        """Return True if the given key exists in the cache.

        Args:
            - key - The key to check for.

        Returns
        -------
            - output - True if the key exists in the cache, False otherwise.
        """
        return False

    async def delete(self, key: str) -> None:
        """Delete the given key from the cache.

        Args:
            - key - The key to delete.
        """

    async def clear(self) -> None:
        """Clear the cache."""

    def child(self, name: str) -> PipelineCache:
        """Create a child cache with the given name.

        Args:
            - name - The name to create the sub cache with.
        """
        return self

67 graphrag/cache/pipeline_cache.py vendored Normal file
@@ -0,0 +1,67 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'PipelineCache' model."""

from __future__ import annotations

from abc import ABCMeta, abstractmethod
from typing import Any


class PipelineCache(metaclass=ABCMeta):
    """Provide a cache interface for the pipeline."""

    @abstractmethod
    async def get(self, key: str) -> Any:
        """Get the value for the given key.

        Args:
            - key - The key to get the value for.
            - as_bytes - Whether or not to return the value as bytes.

        Returns
        -------
            - output - The value for the given key.
        """

    @abstractmethod
    async def set(self, key: str, value: Any, debug_data: dict | None = None) -> None:
        """Set the value for the given key.

        Args:
            - key - The key to set the value for.
            - value - The value to set.
        """

    @abstractmethod
    async def has(self, key: str) -> bool:
        """Return True if the given key exists in the cache.

        Args:
            - key - The key to check for.

        Returns
        -------
            - output - True if the key exists in the cache, False otherwise.
        """

    @abstractmethod
    async def delete(self, key: str) -> None:
        """Delete the given key from the cache.

        Args:
            - key - The key to delete.
        """

    @abstractmethod
    async def clear(self) -> None:
        """Clear the cache."""

    @abstractmethod
    def child(self, name: str) -> PipelineCache:
        """Create a child cache with the given name.

        Args:
            - name - The name to create the sub cache with.
        """

46 graphrag/callbacks/factory.py Normal file
@@ -0,0 +1,46 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Create a pipeline logger."""

from pathlib import Path
from typing import cast

from datashaper import WorkflowCallbacks

from graphrag.callbacks.blob_workflow_callbacks import BlobWorkflowCallbacks
from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
from graphrag.callbacks.file_workflow_callbacks import FileWorkflowCallbacks
from graphrag.config.enums import ReportingType
from graphrag.index.config.reporting import (
    PipelineBlobReportingConfig,
    PipelineFileReportingConfig,
    PipelineReportingConfig,
)


def create_pipeline_reporter(
    config: PipelineReportingConfig | None, root_dir: str | None
) -> WorkflowCallbacks:
    """Create a logger for the given pipeline config."""
    config = config or PipelineFileReportingConfig(base_dir="logs")

    match config.type:
        case ReportingType.file:
            config = cast("PipelineFileReportingConfig", config)
            return FileWorkflowCallbacks(
                str(Path(root_dir or "") / (config.base_dir or ""))
            )
        case ReportingType.console:
            return ConsoleWorkflowCallbacks()
        case ReportingType.blob:
            config = cast("PipelineBlobReportingConfig", config)
            return BlobWorkflowCallbacks(
                config.connection_string,
                config.container_name,
                base_dir=config.base_dir,
                storage_account_blob_url=config.storage_account_blob_url,
            )
        case _:
            msg = f"Unknown reporting type: {config.type}"
            raise ValueError(msg)

55 graphrag/index/exporter.py Normal file
@@ -0,0 +1,55 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""ParquetExporter module."""

import logging
import traceback

import pandas as pd
from pyarrow.lib import ArrowInvalid, ArrowTypeError

from graphrag.index.typing import ErrorHandlerFn
from graphrag.storage.pipeline_storage import PipelineStorage

log = logging.getLogger(__name__)


class ParquetExporter:
    """ParquetExporter class.

    A class that exports dataframes to a storage destination in .parquet file format.
    """

    _storage: PipelineStorage
    _on_error: ErrorHandlerFn

    def __init__(
        self,
        storage: PipelineStorage,
        on_error: ErrorHandlerFn,
    ):
        """Create a new ParquetExporter."""
        self._storage = storage
        self._on_error = on_error

    async def export(self, name: str, data: pd.DataFrame) -> None:
        """Export dataframe to storage."""
        filename = f"{name}.parquet"
        log.info("exporting parquet table %s", filename)
        try:
            await self._storage.set(filename, data.to_parquet())
        except ArrowTypeError as e:
            log.exception("Error while exporting parquet table")
            self._on_error(
                e,
                traceback.format_exc(),
                None,
            )
        except ArrowInvalid as e:
            log.exception("Error while exporting parquet table")
            self._on_error(
                e,
                traceback.format_exc(),
                None,
            )

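An illustrative sketch (not part of the commit) of wiring the exporter to file storage. The error-handler shape follows how _on_error is invoked above (exception, formatted traceback, optional details); FilePipelineStorage usage mirrors the cache factory in this changeset.

import asyncio

import pandas as pd

from graphrag.index.exporter import ParquetExporter
from graphrag.storage.file_pipeline_storage import FilePipelineStorage


def on_error(error, stack, details) -> None:
    # Minimal handler: just surface the failure.
    print(f"export failed: {error}\n{stack}")


async def main() -> None:
    storage = FilePipelineStorage(root_dir="./output")
    exporter = ParquetExporter(storage, on_error)
    await exporter.export("example_table", pd.DataFrame({"id": [1, 2], "title": ["a", "b"]}))


asyncio.run(main())
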
43 graphrag/index/flows/compute_communities.py Normal file
@@ -0,0 +1,43 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to create the base entity graph."""

from typing import Any

import pandas as pd

from graphrag.index.operations.cluster_graph import cluster_graph
from graphrag.index.operations.create_graph import create_graph
from graphrag.index.operations.snapshot import snapshot
from graphrag.storage.pipeline_storage import PipelineStorage


async def compute_communities(
    base_relationship_edges: pd.DataFrame,
    storage: PipelineStorage,
    clustering_strategy: dict[str, Any],
    snapshot_transient_enabled: bool = False,
) -> pd.DataFrame:
    """All the steps to create the base entity graph."""
    graph = create_graph(base_relationship_edges)

    communities = cluster_graph(
        graph,
        strategy=clustering_strategy,
    )

    base_communities = pd.DataFrame(
        communities, columns=pd.Index(["level", "community", "parent", "title"])
    ).explode("title")
    base_communities["community"] = base_communities["community"].astype(int)

    if snapshot_transient_enabled:
        await snapshot(
            base_communities,
            name="base_communities",
            storage=storage,
            formats=["parquet"],
        )

    return base_communities

147 graphrag/index/flows/extract_graph.py Normal file
@@ -0,0 +1,147 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to create the base entity graph."""

from typing import Any
from uuid import uuid4

import networkx as nx
import pandas as pd
from datashaper import (
    AsyncType,
    VerbCallbacks,
)

from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.index.operations.create_graph import create_graph
from graphrag.index.operations.extract_entities import extract_entities
from graphrag.index.operations.snapshot import snapshot
from graphrag.index.operations.snapshot_graphml import snapshot_graphml
from graphrag.index.operations.summarize_descriptions import (
    summarize_descriptions,
)
from graphrag.storage.pipeline_storage import PipelineStorage


async def extract_graph(
    text_units: pd.DataFrame,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    storage: PipelineStorage,
    extraction_strategy: dict[str, Any] | None = None,
    extraction_num_threads: int = 4,
    extraction_async_mode: AsyncType = AsyncType.AsyncIO,
    entity_types: list[str] | None = None,
    summarization_strategy: dict[str, Any] | None = None,
    summarization_num_threads: int = 4,
    snapshot_graphml_enabled: bool = False,
    snapshot_transient_enabled: bool = False,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """All the steps to create the base entity graph."""
    # this returns a graph for each text unit, to be merged later
    entity_dfs, relationship_dfs = await extract_entities(
        text_units,
        callbacks,
        cache,
        text_column="text",
        id_column="id",
        strategy=extraction_strategy,
        async_mode=extraction_async_mode,
        entity_types=entity_types,
        num_threads=extraction_num_threads,
    )

    merged_entities = _merge_entities(entity_dfs)
    merged_relationships = _merge_relationships(relationship_dfs)

    entity_summaries, relationship_summaries = await summarize_descriptions(
        merged_entities,
        merged_relationships,
        callbacks,
        cache,
        strategy=summarization_strategy,
        num_threads=summarization_num_threads,
    )

    base_relationship_edges = _prep_edges(merged_relationships, relationship_summaries)

    graph = create_graph(base_relationship_edges)

    base_entity_nodes = _prep_nodes(merged_entities, entity_summaries, graph)

    if snapshot_graphml_enabled:
        # todo: extract graphs at each level, and add in meta like descriptions
        await snapshot_graphml(
            graph,
            name="graph",
            storage=storage,
        )

    if snapshot_transient_enabled:
        await snapshot(
            base_entity_nodes,
            name="base_entity_nodes",
            storage=storage,
            formats=["parquet"],
        )
        await snapshot(
            base_relationship_edges,
            name="base_relationship_edges",
            storage=storage,
            formats=["parquet"],
        )

    return (base_entity_nodes, base_relationship_edges)


def _merge_entities(entity_dfs) -> pd.DataFrame:
    all_entities = pd.concat(entity_dfs, ignore_index=True)
    return (
        all_entities.groupby(["name", "type"], sort=False)
        .agg({"description": list, "source_id": list})
        .reset_index()
    )


def _merge_relationships(relationship_dfs) -> pd.DataFrame:
    all_relationships = pd.concat(relationship_dfs, ignore_index=False)
    return (
        all_relationships.groupby(["source", "target"], sort=False)
        .agg({"description": list, "source_id": list, "weight": "sum"})
        .reset_index()
    )


def _prep_nodes(entities, summaries, graph) -> pd.DataFrame:
    degrees_df = _compute_degree(graph)
    entities.drop(columns=["description"], inplace=True)
    nodes = (
        entities.merge(summaries, on="name", how="left")
        .merge(degrees_df, on="name")
        .drop_duplicates(subset="name")
        .rename(columns={"name": "title", "source_id": "text_unit_ids"})
    )
    nodes = nodes.loc[nodes["title"].notna()].reset_index()
    nodes["human_readable_id"] = nodes.index
    nodes["id"] = nodes["human_readable_id"].apply(lambda _x: str(uuid4()))
    return nodes


def _prep_edges(relationships, summaries) -> pd.DataFrame:
    edges = (
        relationships.drop(columns=["description"])
        .drop_duplicates(subset=["source", "target"])
        .merge(summaries, on=["source", "target"], how="left")
        .rename(columns={"source_id": "text_unit_ids"})
    )
    edges["human_readable_id"] = edges.index
    edges["id"] = edges["human_readable_id"].apply(lambda _x: str(uuid4()))
    return edges


def _compute_degree(graph: nx.Graph) -> pd.DataFrame:
    return pd.DataFrame([
        {"name": node, "degree": int(degree)}
        for node, degree in graph.degree  # type: ignore
    ])

80 graphrag/index/input/factory.py Normal file
@@ -0,0 +1,80 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing create_input method definition."""

import logging
from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import cast

import pandas as pd

from graphrag.config.enums import InputType
from graphrag.config.models.input_config import InputConfig
from graphrag.index.config.input import PipelineInputConfig
from graphrag.index.input.csv import input_type as csv
from graphrag.index.input.csv import load as load_csv
from graphrag.index.input.text import input_type as text
from graphrag.index.input.text import load as load_text
from graphrag.logger.base import ProgressLogger
from graphrag.logger.null_progress import NullProgressLogger
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
from graphrag.storage.file_pipeline_storage import FilePipelineStorage

log = logging.getLogger(__name__)
loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = {
    text: load_text,
    csv: load_csv,
}


async def create_input(
    config: PipelineInputConfig | InputConfig,
    progress_reporter: ProgressLogger | None = None,
    root_dir: str | None = None,
) -> pd.DataFrame:
    """Instantiate input data for a pipeline."""
    root_dir = root_dir or ""
    log.info("loading input from root_dir=%s", config.base_dir)
    progress_reporter = progress_reporter or NullProgressLogger()

    match config.type:
        case InputType.blob:
            log.info("using blob storage input")
            if config.container_name is None:
                msg = "Container name required for blob storage"
                raise ValueError(msg)
            if (
                config.connection_string is None
                and config.storage_account_blob_url is None
            ):
                msg = "Connection string or storage account blob url required for blob storage"
                raise ValueError(msg)
            storage = BlobPipelineStorage(
                connection_string=config.connection_string,
                storage_account_blob_url=config.storage_account_blob_url,
                container_name=config.container_name,
                path_prefix=config.base_dir,
            )
        case InputType.file:
            log.info("using file storage for input")
            storage = FilePipelineStorage(
                root_dir=str(Path(root_dir) / (config.base_dir or ""))
            )
        case _:
            log.info("using file storage for input")
            storage = FilePipelineStorage(
                root_dir=str(Path(root_dir) / (config.base_dir or ""))
            )

    if config.file_type in loaders:
        progress = progress_reporter.child(
            f"Loading Input ({config.file_type})", transient=False
        )
        loader = loaders[config.file_type]
        results = await loader(config, progress, storage)
        return cast("pd.DataFrame", results)

    msg = f"Unknown input type {config.file_type}"
    raise ValueError(msg)

49 graphrag/index/llm/mock_llm.py Normal file
@@ -0,0 +1,49 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A mock LLM that returns the given responses."""

from dataclasses import dataclass
from typing import Any, cast

from fnllm import ChatLLM, LLMInput, LLMOutput
from fnllm.types.generics import THistoryEntry, TJsonModel, TModelParameters
from pydantic import BaseModel
from typing_extensions import Unpack


@dataclass
class ContentResponse:
    """A mock content-only response."""

    content: str


class MockChatLLM(ChatLLM):
    """A mock LLM that returns the given responses."""

    def __init__(self, responses: list[str | BaseModel], json: bool = False):
        self.responses = responses
        self.response_index = 0

    async def __call__(
        self,
        prompt: str,
        **kwargs: Unpack[LLMInput[TJsonModel, THistoryEntry, TModelParameters]],
    ) -> LLMOutput[Any, TJsonModel, THistoryEntry]:
        """Return the next response in the list."""
        response = self.responses[self.response_index % len(self.responses)]
        self.response_index += 1

        parsed_json = response if isinstance(response, BaseModel) else None
        response = (
            response.model_dump_json() if isinstance(response, BaseModel) else response
        )

        return LLMOutput(
            output=ContentResponse(content=response),
            parsed_json=cast("TJsonModel", parsed_json),
        )

    def child(self, name):
        """Return a new mock LLM."""
        raise NotImplementedError

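For reference, an illustrative snippet (not part of the commit) exercising the mock. It relies only on the behaviour defined above: responses are cycled in order and echoed back as the output content; the exact shape of fnllm's LLMOutput accessors is assumed here.

import asyncio

from graphrag.index.llm.mock_llm import MockChatLLM


async def main() -> None:
    llm = MockChatLLM(["first canned answer", "second canned answer"])
    result = await llm("any prompt")
    print(result.output.content)  # "first canned answer"
    result = await llm("another prompt")
    print(result.output.content)  # "second canned answer"


asyncio.run(main())
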
12 graphrag/index/operations/create_graph.py Normal file
@@ -0,0 +1,12 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing create_graph definition."""

import networkx as nx
import pandas as pd


def create_graph(edges_df: pd.DataFrame) -> nx.Graph:
    """Create a networkx graph from an edges dataframe."""
    return nx.from_pandas_edgelist(edges_df)

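A small illustrative example (not part of the commit): nx.from_pandas_edgelist reads the default "source" and "target" columns, which is the shape of the base_relationship_edges table produced by the extract_graph flow.

import pandas as pd

from graphrag.index.operations.create_graph import create_graph

edges = pd.DataFrame({
    "source": ["SCROOGE", "SCROOGE", "MARLEY"],
    "target": ["BOB CRATCHIT", "MARLEY", "THE GHOST"],
})
graph = create_graph(edges)
print(graph.number_of_nodes(), graph.number_of_edges())  # 4 3
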
74 graphrag/index/workflows/v1/compute_communities.py Normal file
@@ -0,0 +1,74 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing build_steps method definition."""

from typing import Any, cast

import pandas as pd
from datashaper import (
    Table,
    verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result

from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
from graphrag.index.flows.compute_communities import compute_communities
from graphrag.storage.pipeline_storage import PipelineStorage

workflow_name = "compute_communities"


def build_steps(
    config: PipelineWorkflowConfig,
) -> list[PipelineWorkflowStep]:
    """
    Create the base communities from the graph edges.

    ## Dependencies
    * `workflow:extract_graph`
    """
    clustering_config = config.get(
        "cluster_graph",
        {"strategy": {"type": "leiden"}},
    )
    clustering_strategy = clustering_config.get("strategy")

    snapshot_transient = config.get("snapshot_transient", False) or False

    return [
        {
            "verb": workflow_name,
            "args": {
                "clustering_strategy": clustering_strategy,
                "snapshot_transient_enabled": snapshot_transient,
            },
            "input": ({"source": "workflow:extract_graph"}),
        },
    ]


@verb(
    name=workflow_name,
    treats_input_tables_as_immutable=True,
)
async def workflow(
    storage: PipelineStorage,
    runtime_storage: PipelineStorage,
    clustering_strategy: dict[str, Any],
    snapshot_transient_enabled: bool = False,
    **_kwargs: dict,
) -> VerbResult:
    """All the steps to create the base entity graph."""
    base_relationship_edges = await runtime_storage.get("base_relationship_edges")

    base_communities = await compute_communities(
        base_relationship_edges,
        storage,
        clustering_strategy=clustering_strategy,
        snapshot_transient_enabled=snapshot_transient_enabled,
    )

    await runtime_storage.set("base_communities", base_communities)

    return create_verb_result(cast("Table", pd.DataFrame()))

107 graphrag/index/workflows/v1/extract_graph.py Normal file
@@ -0,0 +1,107 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing build_steps method definition."""

from typing import Any, cast

import pandas as pd
from datashaper import (
    AsyncType,
    Table,
    VerbCallbacks,
    verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result

from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
from graphrag.index.flows.extract_graph import (
    extract_graph,
)
from graphrag.storage.pipeline_storage import PipelineStorage

workflow_name = "extract_graph"


def build_steps(
    config: PipelineWorkflowConfig,
) -> list[PipelineWorkflowStep]:
    """
    Create the base table for the entity graph.

    ## Dependencies
    * `workflow:create_base_text_units`
    """
    entity_extraction_config = config.get("entity_extract", {})
    async_mode = entity_extraction_config.get("async_mode", AsyncType.AsyncIO)
    extraction_strategy = entity_extraction_config.get("strategy")
    extraction_num_threads = entity_extraction_config.get("num_threads", 4)
    entity_types = entity_extraction_config.get("entity_types")

    summarize_descriptions_config = config.get("summarize_descriptions", {})
    summarization_strategy = summarize_descriptions_config.get("strategy")
    summarization_num_threads = summarize_descriptions_config.get("num_threads", 4)

    snapshot_graphml = config.get("snapshot_graphml", False) or False
    snapshot_transient = config.get("snapshot_transient", False) or False

    return [
        {
            "verb": workflow_name,
            "args": {
                "extraction_strategy": extraction_strategy,
                "extraction_num_threads": extraction_num_threads,
                "extraction_async_mode": async_mode,
                "entity_types": entity_types,
                "summarization_strategy": summarization_strategy,
                "summarization_num_threads": summarization_num_threads,
                "snapshot_graphml_enabled": snapshot_graphml,
                "snapshot_transient_enabled": snapshot_transient,
            },
            "input": ({"source": "workflow:create_base_text_units"}),
        },
    ]


@verb(
    name=workflow_name,
    treats_input_tables_as_immutable=True,
)
async def workflow(
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    storage: PipelineStorage,
    runtime_storage: PipelineStorage,
    extraction_strategy: dict[str, Any] | None,
    extraction_num_threads: int = 4,
    extraction_async_mode: AsyncType = AsyncType.AsyncIO,
    entity_types: list[str] | None = None,
    summarization_strategy: dict[str, Any] | None = None,
    summarization_num_threads: int = 4,
    snapshot_graphml_enabled: bool = False,
    snapshot_transient_enabled: bool = False,
    **_kwargs: dict,
) -> VerbResult:
    """All the steps to create the base entity graph."""
    text_units = await runtime_storage.get("base_text_units")

    base_entity_nodes, base_relationship_edges = await extract_graph(
        text_units,
        callbacks,
        cache,
        storage,
        extraction_strategy=extraction_strategy,
        extraction_num_threads=extraction_num_threads,
        extraction_async_mode=extraction_async_mode,
        entity_types=entity_types,
        summarization_strategy=summarization_strategy,
        summarization_num_threads=summarization_num_threads,
        snapshot_graphml_enabled=snapshot_graphml_enabled,
        snapshot_transient_enabled=snapshot_transient_enabled,
    )

    await runtime_storage.set("base_entity_nodes", base_entity_nodes)
    await runtime_storage.set("base_relationship_edges", base_relationship_edges)

    return create_verb_result(cast("Table", pd.DataFrame()))

4 graphrag/logger/__init__.py Normal file
@@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Logger utilities and implementations."""

69 graphrag/logger/base.py Normal file
@@ -0,0 +1,69 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Base classes for logging and progress reporting."""

from abc import ABC, abstractmethod
from typing import Any

from datashaper.progress.types import Progress


class StatusLogger(ABC):
    """Provides a way to log status updates from the pipeline."""

    @abstractmethod
    def error(self, message: str, details: dict[str, Any] | None = None):
        """Log an error."""

    @abstractmethod
    def warning(self, message: str, details: dict[str, Any] | None = None):
        """Log a warning."""

    @abstractmethod
    def log(self, message: str, details: dict[str, Any] | None = None):
        """Report a log."""


class ProgressLogger(ABC):
    """
    Abstract base class for progress loggers.

    This is used to report workflow processing progress via mechanisms like progress-bars.
    """

    @abstractmethod
    def __call__(self, update: Progress):
        """Update progress."""

    @abstractmethod
    def dispose(self):
        """Dispose of the progress logger."""

    @abstractmethod
    def child(self, prefix: str, transient=True) -> "ProgressLogger":
        """Create a child progress bar."""

    @abstractmethod
    def force_refresh(self) -> None:
        """Force a refresh."""

    @abstractmethod
    def stop(self) -> None:
        """Stop the progress logger."""

    @abstractmethod
    def error(self, message: str) -> None:
        """Log an error."""

    @abstractmethod
    def warning(self, message: str) -> None:
        """Log a warning."""

    @abstractmethod
    def info(self, message: str) -> None:
        """Log information."""

    @abstractmethod
    def success(self, message: str) -> None:
        """Log success."""

graphrag/logger/console.py  (new file, 28 lines)
@@ -0,0 +1,28 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Console Log."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from graphrag.logger.base import StatusLogger
|
||||
|
||||
|
||||
class ConsoleReporter(StatusLogger):
|
||||
"""A logger that writes to a console."""
|
||||
|
||||
def error(self, message: str, details: dict[str, Any] | None = None):
|
||||
"""Log an error."""
|
||||
print(message, details) # noqa T201
|
||||
|
||||
def warning(self, message: str, details: dict[str, Any] | None = None):
|
||||
"""Log a warning."""
|
||||
_print_warning(message)
|
||||
|
||||
def log(self, message: str, details: dict[str, Any] | None = None):
|
||||
"""Log a log."""
|
||||
print(message, details) # noqa T201
|
||||
|
||||
|
||||
def _print_warning(skk):
|
||||
print(f"\033[93m {skk}\033[00m") # noqa T201
|
||||
graphrag/logger/factory.py  (new file, 43 lines)
@@ -0,0 +1,43 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Factory functions for creating loggers."""
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
from graphrag.logger.base import ProgressLogger
|
||||
from graphrag.logger.null_progress import NullProgressLogger
|
||||
from graphrag.logger.print_progress import PrintProgressLogger
|
||||
from graphrag.logger.rich_progress import RichProgressLogger
|
||||
from graphrag.logger.types import LoggerType
|
||||
|
||||
|
||||
class LoggerFactory:
|
||||
"""A factory class for loggers."""
|
||||
|
||||
logger_types: ClassVar[dict[str, type]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, logger_type: str, logger: type):
|
||||
"""Register a custom logger implementation."""
|
||||
cls.logger_types[logger_type] = logger
|
||||
|
||||
@classmethod
|
||||
def create_logger(
|
||||
cls, logger_type: LoggerType | str, kwargs: dict | None = None
|
||||
) -> ProgressLogger:
|
||||
"""Create a logger based on the provided type."""
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
match logger_type:
|
||||
case LoggerType.RICH:
|
||||
return RichProgressLogger("GraphRAG Indexer ")
|
||||
case LoggerType.PRINT:
|
||||
return PrintProgressLogger("GraphRAG Indexer ")
|
||||
case LoggerType.NONE:
|
||||
return NullProgressLogger()
|
||||
case _:
|
||||
if logger_type in cls.logger_types:
|
||||
return cls.logger_types[logger_type](**kwargs)
|
||||
# default to null logger if no other logger is found
|
||||
return NullProgressLogger()
|
||||
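A minimal usage sketch of the factory above; the MyIndexLogger class and the "my-logger" key are hypothetical, while LoggerFactory, LoggerType, and NullProgressLogger are the modules added in this commit:

from graphrag.logger.factory import LoggerFactory
from graphrag.logger.null_progress import NullProgressLogger
from graphrag.logger.types import LoggerType

# Built-in loggers are resolved from the enum (or its string value).
print_logger = LoggerFactory.create_logger(LoggerType.PRINT)
silent_logger = LoggerFactory.create_logger("none")


class MyIndexLogger(NullProgressLogger):
    """Hypothetical logger that only echoes info messages."""

    def info(self, message: str) -> None:
        print(f"[my-logger] {message}")


# Custom implementations are registered under a key and created the same way.
LoggerFactory.register("my-logger", MyIndexLogger)
custom_logger = LoggerFactory.create_logger("my-logger")
custom_logger.info("indexing started")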
graphrag/logger/null_progress.py  (new file, 38 lines)
@@ -0,0 +1,38 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Null Progress Reporter."""
|
||||
|
||||
from graphrag.logger.base import Progress, ProgressLogger
|
||||
|
||||
|
||||
class NullProgressLogger(ProgressLogger):
|
||||
"""A progress logger that does nothing."""
|
||||
|
||||
def __call__(self, update: Progress) -> None:
|
||||
"""Update progress."""
|
||||
|
||||
def dispose(self) -> None:
|
||||
"""Dispose of the progress logger."""
|
||||
|
||||
def child(self, prefix: str, transient: bool = True) -> ProgressLogger:
|
||||
"""Create a child progress bar."""
|
||||
return self
|
||||
|
||||
def force_refresh(self) -> None:
|
||||
"""Force a refresh."""
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the progress logger."""
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
"""Log an error."""
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
"""Log a warning."""
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
"""Log information."""
|
||||
|
||||
def success(self, message: str) -> None:
|
||||
"""Log success."""
|
||||
graphrag/logger/print_progress.py  (new file, 50 lines)
@@ -0,0 +1,50 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Print Progress Logger."""
|
||||
|
||||
from graphrag.logger.base import Progress, ProgressLogger
|
||||
|
||||
|
||||
class PrintProgressLogger(ProgressLogger):
|
||||
"""A progress logger that prints progress to stdout."""
|
||||
|
||||
prefix: str
|
||||
|
||||
def __init__(self, prefix: str):
|
||||
"""Create a new progress logger."""
|
||||
self.prefix = prefix
|
||||
print(f"\n{self.prefix}", end="") # noqa T201
|
||||
|
||||
def __call__(self, update: Progress) -> None:
|
||||
"""Update progress."""
|
||||
print(".", end="") # noqa T201
|
||||
|
||||
def dispose(self) -> None:
|
||||
"""Dispose of the progress logger."""
|
||||
|
||||
def child(self, prefix: str, transient: bool = True) -> ProgressLogger:
|
||||
"""Create a child progress bar."""
|
||||
return PrintProgressLogger(prefix)
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the progress logger."""
|
||||
|
||||
def force_refresh(self) -> None:
|
||||
"""Force a refresh."""
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
"""Log an error."""
|
||||
print(f"\n{self.prefix}ERROR: {message}") # noqa T201
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
"""Log a warning."""
|
||||
print(f"\n{self.prefix}WARNING: {message}") # noqa T201
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
"""Log information."""
|
||||
print(f"\n{self.prefix}INFO: {message}") # noqa T201
|
||||
|
||||
def success(self, message: str) -> None:
|
||||
"""Log success."""
|
||||
print(f"\n{self.prefix}SUCCESS: {message}") # noqa T201
|
||||
graphrag/logger/rich_progress.py  (new file, 165 lines)
@@ -0,0 +1,165 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Rich-based progress logger for CLI use."""
|
||||
|
||||
# Print iterations progress
|
||||
import asyncio
|
||||
|
||||
from datashaper import Progress as DSProgress
|
||||
from rich.console import Console, Group
|
||||
from rich.live import Live
|
||||
from rich.progress import Progress, TaskID, TimeElapsedColumn
|
||||
from rich.spinner import Spinner
|
||||
from rich.tree import Tree
|
||||
|
||||
from graphrag.logger.base import ProgressLogger
|
||||
|
||||
|
||||
# https://stackoverflow.com/a/34325723
|
||||
class RichProgressLogger(ProgressLogger):
|
||||
"""A rich-based progress logger for CLI use."""
|
||||
|
||||
_console: Console
|
||||
_group: Group
|
||||
_tree: Tree
|
||||
_live: Live
|
||||
_task: TaskID | None = None
|
||||
_prefix: str
|
||||
_transient: bool
|
||||
_disposing: bool = False
|
||||
_progressbar: Progress
|
||||
_last_refresh: float = 0
|
||||
|
||||
def dispose(self) -> None:
|
||||
"""Dispose of the progress logger."""
|
||||
self._disposing = True
|
||||
self._live.stop()
|
||||
|
||||
@property
|
||||
def console(self) -> Console:
|
||||
"""Get the console."""
|
||||
return self._console
|
||||
|
||||
@property
|
||||
def group(self) -> Group:
|
||||
"""Get the group."""
|
||||
return self._group
|
||||
|
||||
@property
|
||||
def tree(self) -> Tree:
|
||||
"""Get the tree."""
|
||||
return self._tree
|
||||
|
||||
@property
|
||||
def live(self) -> Live:
|
||||
"""Get the live."""
|
||||
return self._live
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str,
|
||||
parent: "RichProgressLogger | None" = None,
|
||||
transient: bool = True,
|
||||
) -> None:
|
||||
"""Create a new rich-based progress logger."""
|
||||
self._prefix = prefix
|
||||
|
||||
if parent is None:
|
||||
console = Console()
|
||||
group = Group(Spinner("dots", prefix), fit=True)
|
||||
tree = Tree(group)
|
||||
live = Live(
|
||||
tree, console=console, refresh_per_second=1, vertical_overflow="crop"
|
||||
)
|
||||
live.start()
|
||||
|
||||
self._console = console
|
||||
self._group = group
|
||||
self._tree = tree
|
||||
self._live = live
|
||||
self._transient = False
|
||||
else:
|
||||
self._console = parent.console
|
||||
self._group = parent.group
|
||||
progress_columns = [*Progress.get_default_columns(), TimeElapsedColumn()]
|
||||
self._progressbar = Progress(
|
||||
*progress_columns, console=self._console, transient=transient
|
||||
)
|
||||
|
||||
tree = Tree(prefix)
|
||||
tree.add(self._progressbar)
|
||||
tree.hide_root = True
|
||||
|
||||
if parent is not None:
|
||||
parent_tree = parent.tree
|
||||
parent_tree.hide_root = False
|
||||
parent_tree.add(tree)
|
||||
|
||||
self._tree = tree
|
||||
self._live = parent.live
|
||||
self._transient = transient
|
||||
|
||||
self.refresh()
|
||||
|
||||
def refresh(self) -> None:
|
||||
"""Perform a debounced refresh."""
|
||||
now = asyncio.get_event_loop().time()
|
||||
duration = now - self._last_refresh
|
||||
if duration > 0.1:
|
||||
self._last_refresh = now
|
||||
self.force_refresh()
|
||||
|
||||
def force_refresh(self) -> None:
|
||||
"""Force a refresh."""
|
||||
self.live.refresh()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the progress logger."""
|
||||
self._live.stop()
|
||||
|
||||
def child(self, prefix: str, transient: bool = True) -> ProgressLogger:
|
||||
"""Create a child progress bar."""
|
||||
return RichProgressLogger(parent=self, prefix=prefix, transient=transient)
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
"""Log an error."""
|
||||
self._console.print(f"❌ [red]{message}[/red]")
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
"""Log a warning."""
|
||||
self._console.print(f"⚠️ [yellow]{message}[/yellow]")
|
||||
|
||||
def success(self, message: str) -> None:
|
||||
"""Log success."""
|
||||
self._console.print(f"🚀 [green]{message}[/green]")
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
"""Log information."""
|
||||
self._console.print(message)
|
||||
|
||||
def __call__(self, progress_update: DSProgress) -> None:
|
||||
"""Update progress."""
|
||||
if self._disposing:
|
||||
return
|
||||
progressbar = self._progressbar
|
||||
|
||||
if self._task is None:
|
||||
self._task = progressbar.add_task(self._prefix)
|
||||
|
||||
progress_description = ""
|
||||
if progress_update.description is not None:
|
||||
progress_description = f" - {progress_update.description}"
|
||||
|
||||
completed = progress_update.completed_items or progress_update.percent
|
||||
total = progress_update.total_items or 1
|
||||
progressbar.update(
|
||||
self._task,
|
||||
completed=completed,
|
||||
total=total,
|
||||
description=f"{self._prefix}{progress_description}",
|
||||
)
|
||||
if completed == total and self._transient:
|
||||
progressbar.update(self._task, visible=False)
|
||||
|
||||
self.refresh()
|
||||
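A short sketch of how the rich logger above might be driven; it assumes a terminal where rich can render, and the progress updates use the datashaper Progress type that this module already imports:

from datashaper import Progress

from graphrag.logger.rich_progress import RichProgressLogger

root = RichProgressLogger("GraphRAG Indexer ")
child = root.child("extract graph", transient=True)

# Workflow code reports progress by calling the logger directly.
for done in range(1, 11):
    child(Progress(total_items=10, completed_items=done, description="chunking"))

root.success("extraction finished")
root.stop()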
graphrag/logger/types.py  (new file, 22 lines)
@@ -0,0 +1,22 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Logging types.

This module defines the types of loggers that can be used.
"""

from enum import Enum


# Note: Code in this module was not included in the factory module because it negatively impacts the CLI experience.
class LoggerType(str, Enum):
    """The type of logger to use."""

    RICH = "rich"
    PRINT = "print"
    NONE = "none"

    def __str__(self):
        """Return a string representation of the enum value."""
        return self.value
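Because LoggerType is a str-valued enum, CLI or config values can be parsed and compared directly; a small sketch:

from graphrag.logger.types import LoggerType

assert str(LoggerType.RICH) == "rich"
assert LoggerType("print") is LoggerType.PRINT  # parse a CLI/config value
assert LoggerType.NONE == "none"                # str enums compare to their value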
graphrag/query/factory.py  (new file, 193 lines)
@@ -0,0 +1,193 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Query Factory methods to support CLI."""
|
||||
|
||||
import tiktoken
|
||||
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.model.community import Community
|
||||
from graphrag.model.community_report import CommunityReport
|
||||
from graphrag.model.covariate import Covariate
|
||||
from graphrag.model.entity import Entity
|
||||
from graphrag.model.relationship import Relationship
|
||||
from graphrag.model.text_unit import TextUnit
|
||||
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
|
||||
from graphrag.query.llm.get_client import get_llm, get_text_embedder
|
||||
from graphrag.query.structured_search.drift_search.drift_context import (
|
||||
DRIFTSearchContextBuilder,
|
||||
)
|
||||
from graphrag.query.structured_search.drift_search.search import DRIFTSearch
|
||||
from graphrag.query.structured_search.global_search.community_context import (
|
||||
GlobalCommunityContext,
|
||||
)
|
||||
from graphrag.query.structured_search.global_search.search import GlobalSearch
|
||||
from graphrag.query.structured_search.local_search.mixed_context import (
|
||||
LocalSearchMixedContext,
|
||||
)
|
||||
from graphrag.query.structured_search.local_search.search import LocalSearch
|
||||
from graphrag.vector_stores.base import BaseVectorStore
|
||||
|
||||
|
||||
def get_local_search_engine(
|
||||
config: GraphRagConfig,
|
||||
reports: list[CommunityReport],
|
||||
text_units: list[TextUnit],
|
||||
entities: list[Entity],
|
||||
relationships: list[Relationship],
|
||||
covariates: dict[str, list[Covariate]],
|
||||
response_type: str,
|
||||
description_embedding_store: BaseVectorStore,
|
||||
system_prompt: str | None = None,
|
||||
) -> LocalSearch:
|
||||
"""Create a local search engine based on data + configuration."""
|
||||
llm = get_llm(config)
|
||||
text_embedder = get_text_embedder(config)
|
||||
token_encoder = tiktoken.get_encoding(config.encoding_model)
|
||||
|
||||
ls_config = config.local_search
|
||||
|
||||
return LocalSearch(
|
||||
llm=llm,
|
||||
system_prompt=system_prompt,
|
||||
context_builder=LocalSearchMixedContext(
|
||||
community_reports=reports,
|
||||
text_units=text_units,
|
||||
entities=entities,
|
||||
relationships=relationships,
|
||||
covariates=covariates,
|
||||
entity_text_embeddings=description_embedding_store,
|
||||
embedding_vectorstore_key=EntityVectorStoreKey.ID, # if the vectorstore uses entity title as ids, set this to EntityVectorStoreKey.TITLE
|
||||
text_embedder=text_embedder,
|
||||
token_encoder=token_encoder,
|
||||
),
|
||||
token_encoder=token_encoder,
|
||||
llm_params={
|
||||
"max_tokens": ls_config.llm_max_tokens, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 1000=1500)
|
||||
"temperature": ls_config.temperature,
|
||||
"top_p": ls_config.top_p,
|
||||
"n": ls_config.n,
|
||||
},
|
||||
context_builder_params={
|
||||
"text_unit_prop": ls_config.text_unit_prop,
|
||||
"community_prop": ls_config.community_prop,
|
||||
"conversation_history_max_turns": ls_config.conversation_history_max_turns,
|
||||
"conversation_history_user_turns_only": True,
|
||||
"top_k_mapped_entities": ls_config.top_k_entities,
|
||||
"top_k_relationships": ls_config.top_k_relationships,
|
||||
"include_entity_rank": True,
|
||||
"include_relationship_weight": True,
|
||||
"include_community_rank": False,
|
||||
"return_candidate_context": False,
|
||||
"embedding_vectorstore_key": EntityVectorStoreKey.ID, # set this to EntityVectorStoreKey.TITLE if the vectorstore uses entity title as ids
|
||||
"max_tokens": ls_config.max_tokens, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000)
|
||||
},
|
||||
response_type=response_type,
|
||||
)
|
||||
|
||||
|
||||
def get_global_search_engine(
|
||||
config: GraphRagConfig,
|
||||
reports: list[CommunityReport],
|
||||
entities: list[Entity],
|
||||
communities: list[Community],
|
||||
response_type: str,
|
||||
dynamic_community_selection: bool = False,
|
||||
map_system_prompt: str | None = None,
|
||||
reduce_system_prompt: str | None = None,
|
||||
general_knowledge_inclusion_prompt: str | None = None,
|
||||
) -> GlobalSearch:
|
||||
"""Create a global search engine based on data + configuration."""
|
||||
token_encoder = tiktoken.get_encoding(config.encoding_model)
|
||||
gs_config = config.global_search
|
||||
|
||||
dynamic_community_selection_kwargs = {}
|
||||
if dynamic_community_selection:
|
||||
# TODO: Allow for another llm definition only for Global Search to leverage -mini models
|
||||
|
||||
dynamic_community_selection_kwargs.update({
|
||||
"llm": get_llm(config),
|
||||
"token_encoder": tiktoken.encoding_for_model(config.llm.model),
|
||||
"keep_parent": gs_config.dynamic_search_keep_parent,
|
||||
"num_repeats": gs_config.dynamic_search_num_repeats,
|
||||
"use_summary": gs_config.dynamic_search_use_summary,
|
||||
"concurrent_coroutines": gs_config.dynamic_search_concurrent_coroutines,
|
||||
"threshold": gs_config.dynamic_search_threshold,
|
||||
"max_level": gs_config.dynamic_search_max_level,
|
||||
})
|
||||
|
||||
return GlobalSearch(
|
||||
llm=get_llm(config),
|
||||
map_system_prompt=map_system_prompt,
|
||||
reduce_system_prompt=reduce_system_prompt,
|
||||
general_knowledge_inclusion_prompt=general_knowledge_inclusion_prompt,
|
||||
context_builder=GlobalCommunityContext(
|
||||
community_reports=reports,
|
||||
communities=communities,
|
||||
entities=entities,
|
||||
token_encoder=token_encoder,
|
||||
dynamic_community_selection=dynamic_community_selection,
|
||||
dynamic_community_selection_kwargs=dynamic_community_selection_kwargs,
|
||||
),
|
||||
token_encoder=token_encoder,
|
||||
max_data_tokens=gs_config.data_max_tokens,
|
||||
map_llm_params={
|
||||
"max_tokens": gs_config.map_max_tokens,
|
||||
"temperature": gs_config.temperature,
|
||||
"top_p": gs_config.top_p,
|
||||
"n": gs_config.n,
|
||||
},
|
||||
reduce_llm_params={
|
||||
"max_tokens": gs_config.reduce_max_tokens,
|
||||
"temperature": gs_config.temperature,
|
||||
"top_p": gs_config.top_p,
|
||||
"n": gs_config.n,
|
||||
},
|
||||
allow_general_knowledge=False,
|
||||
json_mode=False,
|
||||
context_builder_params={
|
||||
"use_community_summary": False,
|
||||
"shuffle_data": True,
|
||||
"include_community_rank": True,
|
||||
"min_community_rank": 0,
|
||||
"community_rank_name": "rank",
|
||||
"include_community_weight": True,
|
||||
"community_weight_name": "occurrence weight",
|
||||
"normalize_community_weight": True,
|
||||
"max_tokens": gs_config.max_tokens,
|
||||
"context_name": "Reports",
|
||||
},
|
||||
concurrent_coroutines=gs_config.concurrency,
|
||||
response_type=response_type,
|
||||
)
|
||||
|
||||
|
||||
def get_drift_search_engine(
|
||||
config: GraphRagConfig,
|
||||
reports: list[CommunityReport],
|
||||
text_units: list[TextUnit],
|
||||
entities: list[Entity],
|
||||
relationships: list[Relationship],
|
||||
description_embedding_store: BaseVectorStore,
|
||||
local_system_prompt: str | None = None,
|
||||
) -> DRIFTSearch:
|
||||
"""Create a local search engine based on data + configuration."""
|
||||
llm = get_llm(config)
|
||||
text_embedder = get_text_embedder(config)
|
||||
token_encoder = tiktoken.get_encoding(config.encoding_model)
|
||||
|
||||
return DRIFTSearch(
|
||||
llm=llm,
|
||||
context_builder=DRIFTSearchContextBuilder(
|
||||
chat_llm=llm,
|
||||
text_embedder=text_embedder,
|
||||
entities=entities,
|
||||
relationships=relationships,
|
||||
reports=reports,
|
||||
entity_text_embeddings=description_embedding_store,
|
||||
text_units=text_units,
|
||||
local_system_prompt=local_system_prompt,
|
||||
config=config.drift_search,
|
||||
),
|
||||
token_encoder=token_encoder,
|
||||
)
|
||||
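A hedged sketch of how a caller might use the local search factory above. The artifact lists, config, and vector store are assumed to be loaded elsewhere (for example from the indexer's parquet outputs), and the asearch/response names follow the structured-search classes rather than anything shown in this diff:

from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.query.factory import get_local_search_engine
from graphrag.vector_stores.base import BaseVectorStore


async def answer_locally(
    config: GraphRagConfig,
    artifacts: dict,
    embedding_store: BaseVectorStore,
    question: str,
) -> str:
    """Hypothetical helper: build a local search engine from pre-loaded artifacts and run one query."""
    engine = get_local_search_engine(
        config=config,
        reports=artifacts["reports"],
        text_units=artifacts["text_units"],
        entities=artifacts["entities"],
        relationships=artifacts["relationships"],
        covariates=artifacts.get("covariates", {}),
        response_type="multiple paragraphs",
        description_embedding_store=embedding_store,
    )
    result = await engine.asearch(question)
    return result.response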
graphrag/storage/__init__.py  (new file, 4 lines)
@@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""The storage package root."""
graphrag/storage/blob_pipeline_storage.py  (new file, 375 lines)
@@ -0,0 +1,375 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Azure Blob Storage implementation of PipelineStorage."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from azure.identity import DefaultAzureCredential
|
||||
from azure.storage.blob import BlobServiceClient
|
||||
from datashaper import Progress
|
||||
|
||||
from graphrag.logger.base import ProgressLogger
|
||||
from graphrag.storage.pipeline_storage import PipelineStorage
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BlobPipelineStorage(PipelineStorage):
|
||||
"""The Blob-Storage implementation."""
|
||||
|
||||
_connection_string: str | None
|
||||
_container_name: str
|
||||
_path_prefix: str
|
||||
_encoding: str
|
||||
_storage_account_blob_url: str | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_string: str | None,
|
||||
container_name: str,
|
||||
encoding: str = "utf-8",
|
||||
path_prefix: str | None = None,
|
||||
storage_account_blob_url: str | None = None,
|
||||
):
|
||||
"""Create a new BlobStorage instance."""
|
||||
if connection_string:
|
||||
self._blob_service_client = BlobServiceClient.from_connection_string(
|
||||
connection_string
|
||||
)
|
||||
else:
|
||||
if storage_account_blob_url is None:
|
||||
msg = "Either connection_string or storage_account_blob_url must be provided."
|
||||
raise ValueError(msg)
|
||||
|
||||
self._blob_service_client = BlobServiceClient(
|
||||
account_url=storage_account_blob_url,
|
||||
credential=DefaultAzureCredential(),
|
||||
)
|
||||
self._encoding = encoding
|
||||
self._container_name = container_name
|
||||
self._connection_string = connection_string
|
||||
self._path_prefix = path_prefix or ""
|
||||
self._storage_account_blob_url = storage_account_blob_url
|
||||
self._storage_account_name = (
|
||||
storage_account_blob_url.split("//")[1].split(".")[0]
|
||||
if storage_account_blob_url
|
||||
else None
|
||||
)
|
||||
log.info(
|
||||
"creating blob storage at container=%s, path=%s",
|
||||
self._container_name,
|
||||
self._path_prefix,
|
||||
)
|
||||
self.create_container()
|
||||
|
||||
def create_container(self) -> None:
|
||||
"""Create the container if it does not exist."""
|
||||
if not self.container_exists():
|
||||
container_name = self._container_name
|
||||
container_names = [
|
||||
container.name
|
||||
for container in self._blob_service_client.list_containers()
|
||||
]
|
||||
if container_name not in container_names:
|
||||
self._blob_service_client.create_container(container_name)
|
||||
|
||||
def delete_container(self) -> None:
|
||||
"""Delete the container."""
|
||||
if self.container_exists():
|
||||
self._blob_service_client.delete_container(self._container_name)
|
||||
|
||||
def container_exists(self) -> bool:
|
||||
"""Check if the container exists."""
|
||||
container_name = self._container_name
|
||||
container_names = [
|
||||
container.name for container in self._blob_service_client.list_containers()
|
||||
]
|
||||
return container_name in container_names
|
||||
|
||||
def find(
|
||||
self,
|
||||
file_pattern: re.Pattern[str],
|
||||
base_dir: str | None = None,
|
||||
progress: ProgressLogger | None = None,
|
||||
file_filter: dict[str, Any] | None = None,
|
||||
max_count=-1,
|
||||
) -> Iterator[tuple[str, dict[str, Any]]]:
|
||||
"""Find blobs in a container using a file pattern, as well as a custom filter function.
|
||||
|
||||
Params:
|
||||
base_dir: The name of the base container.
|
||||
file_pattern: The file pattern to use.
|
||||
file_filter: A dictionary of key-value pairs to filter the blobs.
|
||||
max_count: The maximum number of blobs to return. If -1, all blobs are returned.
|
||||
|
||||
Returns
|
||||
-------
|
||||
An iterator of blob names and their corresponding regex matches.
|
||||
"""
|
||||
base_dir = base_dir or ""
|
||||
|
||||
log.info(
|
||||
"search container %s for files matching %s",
|
||||
self._container_name,
|
||||
file_pattern.pattern,
|
||||
)
|
||||
|
||||
def blobname(blob_name: str) -> str:
|
||||
if blob_name.startswith(self._path_prefix):
|
||||
blob_name = blob_name.replace(self._path_prefix, "", 1)
|
||||
if blob_name.startswith("/"):
|
||||
blob_name = blob_name[1:]
|
||||
return blob_name
|
||||
|
||||
def item_filter(item: dict[str, Any]) -> bool:
|
||||
if file_filter is None:
|
||||
return True
|
||||
|
||||
return all(re.match(value, item[key]) for key, value in file_filter.items())
|
||||
|
||||
try:
|
||||
container_client = self._blob_service_client.get_container_client(
|
||||
self._container_name
|
||||
)
|
||||
all_blobs = list(container_client.list_blobs())
|
||||
|
||||
num_loaded = 0
|
||||
num_total = len(all_blobs)
|
||||
num_filtered = 0
|
||||
for blob in all_blobs:
|
||||
match = file_pattern.match(blob.name)
|
||||
if match and blob.name.startswith(base_dir):
|
||||
group = match.groupdict()
|
||||
if item_filter(group):
|
||||
yield (blobname(blob.name), group)
|
||||
num_loaded += 1
|
||||
if max_count > 0 and num_loaded >= max_count:
|
||||
break
|
||||
else:
|
||||
num_filtered += 1
|
||||
else:
|
||||
num_filtered += 1
|
||||
if progress is not None:
|
||||
progress(
|
||||
_create_progress_status(num_loaded, num_filtered, num_total)
|
||||
)
|
||||
except Exception:
|
||||
log.exception(
|
||||
"Error finding blobs: base_dir=%s, file_pattern=%s, file_filter=%s",
|
||||
base_dir,
|
||||
file_pattern,
|
||||
file_filter,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get(
|
||||
self, key: str, as_bytes: bool | None = False, encoding: str | None = None
|
||||
) -> Any:
|
||||
"""Get a value from the cache."""
|
||||
try:
|
||||
key = self._keyname(key)
|
||||
container_client = self._blob_service_client.get_container_client(
|
||||
self._container_name
|
||||
)
|
||||
blob_client = container_client.get_blob_client(key)
|
||||
blob_data = blob_client.download_blob().readall()
|
||||
if not as_bytes:
|
||||
coding = encoding or self._encoding
|
||||
blob_data = blob_data.decode(coding)
|
||||
except Exception:
|
||||
log.exception("Error getting key %s", key)
|
||||
return None
|
||||
else:
|
||||
return blob_data
|
||||
|
||||
async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
|
||||
"""Set a value in the cache."""
|
||||
try:
|
||||
key = self._keyname(key)
|
||||
container_client = self._blob_service_client.get_container_client(
|
||||
self._container_name
|
||||
)
|
||||
blob_client = container_client.get_blob_client(key)
|
||||
if isinstance(value, bytes):
|
||||
blob_client.upload_blob(value, overwrite=True)
|
||||
else:
|
||||
coding = encoding or self._encoding
|
||||
blob_client.upload_blob(value.encode(coding), overwrite=True)
|
||||
except Exception:
|
||||
log.exception("Error setting key %s: %s", key)
|
||||
|
||||
def set_df_json(self, key: str, dataframe: Any) -> None:
|
||||
"""Set a json dataframe."""
|
||||
if self._connection_string is None and self._storage_account_name:
|
||||
dataframe.to_json(
|
||||
self._abfs_url(key),
|
||||
storage_options={
|
||||
"account_name": self._storage_account_name,
|
||||
"credential": DefaultAzureCredential(),
|
||||
},
|
||||
orient="records",
|
||||
lines=True,
|
||||
force_ascii=False,
|
||||
)
|
||||
else:
|
||||
dataframe.to_json(
|
||||
self._abfs_url(key),
|
||||
storage_options={"connection_string": self._connection_string},
|
||||
orient="records",
|
||||
lines=True,
|
||||
force_ascii=False,
|
||||
)
|
||||
|
||||
def set_df_parquet(self, key: str, dataframe: Any) -> None:
|
||||
"""Set a parquet dataframe."""
|
||||
if self._connection_string is None and self._storage_account_name:
|
||||
dataframe.to_parquet(
|
||||
self._abfs_url(key),
|
||||
storage_options={
|
||||
"account_name": self._storage_account_name,
|
||||
"credential": DefaultAzureCredential(),
|
||||
},
|
||||
)
|
||||
else:
|
||||
dataframe.to_parquet(
|
||||
self._abfs_url(key),
|
||||
storage_options={"connection_string": self._connection_string},
|
||||
)
|
||||
|
||||
async def has(self, key: str) -> bool:
|
||||
"""Check if a key exists in the cache."""
|
||||
key = self._keyname(key)
|
||||
container_client = self._blob_service_client.get_container_client(
|
||||
self._container_name
|
||||
)
|
||||
blob_client = container_client.get_blob_client(key)
|
||||
return blob_client.exists()
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
"""Delete a key from the cache."""
|
||||
key = self._keyname(key)
|
||||
container_client = self._blob_service_client.get_container_client(
|
||||
self._container_name
|
||||
)
|
||||
blob_client = container_client.get_blob_client(key)
|
||||
blob_client.delete_blob()
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Clear the cache."""
|
||||
|
||||
def child(self, name: str | None) -> "PipelineStorage":
|
||||
"""Create a child storage instance."""
|
||||
if name is None:
|
||||
return self
|
||||
path = str(Path(self._path_prefix) / name)
|
||||
return BlobPipelineStorage(
|
||||
self._connection_string,
|
||||
self._container_name,
|
||||
self._encoding,
|
||||
path,
|
||||
self._storage_account_blob_url,
|
||||
)
|
||||
|
||||
def keys(self) -> list[str]:
|
||||
"""Return the keys in the storage."""
|
||||
msg = "Blob storage does yet not support listing keys."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def _keyname(self, key: str) -> str:
|
||||
"""Get the key name."""
|
||||
return str(Path(self._path_prefix) / key)
|
||||
|
||||
def _abfs_url(self, key: str) -> str:
|
||||
"""Get the ABFS URL."""
|
||||
path = str(Path(self._container_name) / self._path_prefix / key)
|
||||
return f"abfs://{path}"
|
||||
|
||||
|
||||
def create_blob_storage(
|
||||
connection_string: str | None,
|
||||
storage_account_blob_url: str | None,
|
||||
container_name: str,
|
||||
base_dir: str | None,
|
||||
) -> PipelineStorage:
|
||||
"""Create a blob based storage."""
|
||||
log.info("Creating blob storage at %s", container_name)
|
||||
if container_name is None:
|
||||
msg = "No container name provided for blob storage."
|
||||
raise ValueError(msg)
|
||||
if connection_string is None and storage_account_blob_url is None:
|
||||
msg = "No storage account blob url provided for blob storage."
|
||||
raise ValueError(msg)
|
||||
return BlobPipelineStorage(
|
||||
connection_string=connection_string,
|
||||
container_name=container_name,
|
||||
path_prefix=base_dir,
|
||||
storage_account_blob_url=storage_account_blob_url,
|
||||
)
|
||||
|
||||
|
||||
def validate_blob_container_name(container_name: str):
    """
    Check if the provided blob container name is valid based on Azure rules.

    - A blob container name must be between 3 and 63 characters in length.
    - Start with a letter or number.
    - All letters used in blob container names must be lowercase.
    - Contain only letters, numbers, or the hyphen.
    - Consecutive hyphens are not permitted.
    - Cannot end with a hyphen.

    Args:
    -----
    container_name (str)
        The blob container name to be validated.

    Returns
    -------
        True if the name is valid; raises ValueError otherwise.
    """
    # Check the length of the name
    if len(container_name) < 3 or len(container_name) > 63:
        msg = f"Container name must be between 3 and 63 characters in length. Name provided was {len(container_name)} characters long."
        raise ValueError(msg)

    # Check if the name starts with a letter or number
    if not container_name[0].isalnum():
        msg = f"Container name must start with a letter or number. Starting character was {container_name[0]}."
        raise ValueError(msg)

    # Check for valid characters (letters, numbers, hyphen) and lowercase letters
    if not re.match(r"^[a-z0-9-]+$", container_name):
        msg = f"Container name must only contain:\n- lowercase letters\n- numbers\n- or hyphens\nName provided was {container_name}."
        raise ValueError(msg)

    # Check for consecutive hyphens
    if "--" in container_name:
        msg = f"Container name cannot contain consecutive hyphens. Name provided was {container_name}."
        raise ValueError(msg)

    # Check for hyphens at the end of the name
    if container_name[-1] == "-":
        msg = f"Container name cannot end with a hyphen. Name provided was {container_name}."
        raise ValueError(msg)

    return True
|
||||
|
||||
|
||||
def _create_progress_status(
|
||||
num_loaded: int, num_filtered: int, num_total: int
|
||||
) -> Progress:
|
||||
return Progress(
|
||||
total_items=num_total,
|
||||
completed_items=num_loaded + num_filtered,
|
||||
description=f"{num_loaded} files loaded ({num_filtered} filtered)",
|
||||
)
|
||||
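A short sketch of exercising the blob storage; the account URL and container name are placeholders, and running it requires real Azure credentials resolvable by DefaultAzureCredential:

import asyncio

from graphrag.storage.blob_pipeline_storage import create_blob_storage


async def main() -> None:
    storage = create_blob_storage(
        connection_string=None,
        storage_account_blob_url="https://<account>.blob.core.windows.net",  # placeholder
        container_name="graphrag-output",
        base_dir="run-1",
    )
    await storage.set("hello.txt", "hello world")
    print(await storage.get("hello.txt"))    # -> "hello world"
    print(await storage.has("missing.txt"))  # -> False


asyncio.run(main())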
graphrag/storage/cosmosdb_pipeline_storage.py  (new file, 308 lines)
@@ -0,0 +1,308 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Azure CosmosDB Storage implementation of PipelineStorage."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Iterator
|
||||
from io import BytesIO, StringIO
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
from azure.cosmos import CosmosClient
|
||||
from azure.cosmos.partition_key import PartitionKey
|
||||
from azure.identity import DefaultAzureCredential
|
||||
from datashaper import Progress
|
||||
|
||||
from graphrag.logger.base import ProgressLogger
|
||||
|
||||
from .pipeline_storage import PipelineStorage
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
class CosmosDBPipelineStorage(PipelineStorage):
|
||||
"""The CosmosDB-Storage Implementation."""
|
||||
|
||||
_cosmosdb_account_url: str | None
|
||||
_connection_string: str | None
|
||||
_database_name: str
|
||||
_current_container: str | None
|
||||
_encoding: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cosmosdb_account_url: str | None,
|
||||
connection_string: str | None,
|
||||
database_name: str,
|
||||
encoding: str | None = None,
|
||||
current_container: str | None = None,
|
||||
):
|
||||
"""Initialize the CosmosDB-Storage."""
|
||||
if connection_string:
|
||||
self._cosmos_client = CosmosClient.from_connection_string(
|
||||
connection_string
|
||||
)
|
||||
else:
|
||||
if cosmosdb_account_url is None:
|
||||
msg = "Either connection_string or cosmosdb_accoun_url must be provided."
|
||||
raise ValueError(msg)
|
||||
|
||||
self._cosmos_client = CosmosClient(
|
||||
url=cosmosdb_account_url,
|
||||
credential=DefaultAzureCredential(),
|
||||
)
|
||||
|
||||
self._encoding = encoding or "utf-8"
|
||||
self._database_name = database_name
|
||||
self._connection_string = connection_string
|
||||
self._cosmosdb_account_url = cosmosdb_account_url
|
||||
self._current_container = current_container or None
|
||||
self._cosmosdb_account_name = (
|
||||
cosmosdb_account_url.split("//")[1].split(".")[0]
|
||||
if cosmosdb_account_url
|
||||
else None
|
||||
)
|
||||
self._database_client = self._cosmos_client.get_database_client(
|
||||
self._database_name
|
||||
)
|
||||
|
||||
log.info(
|
||||
"creating cosmosdb storage with account: %s and database: %s",
|
||||
self._cosmosdb_account_name,
|
||||
self._database_name,
|
||||
)
|
||||
self.create_database()
|
||||
if self._current_container is not None:
|
||||
self.create_container()
|
||||
|
||||
def create_database(self) -> None:
|
||||
"""Create the database if it doesn't exist."""
|
||||
database_name = self._database_name
|
||||
self._cosmos_client.create_database_if_not_exists(id=database_name)
|
||||
|
||||
def delete_database(self) -> None:
|
||||
"""Delete the database if it exists."""
|
||||
if self.database_exists():
|
||||
self._cosmos_client.delete_database(self._database_name)
|
||||
|
||||
def database_exists(self) -> bool:
|
||||
"""Check if the database exists."""
|
||||
database_name = self._database_name
|
||||
database_names = [
|
||||
database["id"] for database in self._cosmos_client.list_databases()
|
||||
]
|
||||
return database_name in database_names
|
||||
|
||||
def find(
|
||||
self,
|
||||
file_pattern: re.Pattern[str],
|
||||
base_dir: str | None = None,
|
||||
progress: ProgressLogger | None = None,
|
||||
file_filter: dict[str, Any] | None = None,
|
||||
max_count=-1,
|
||||
) -> Iterator[tuple[str, dict[str, Any]]]:
|
||||
"""Find documents in a Cosmos DB container using a file pattern and custom filter function.
|
||||
|
||||
Params:
|
||||
base_dir: The name of the base directory (not used in Cosmos DB context).
|
||||
file_pattern: The file pattern to use.
|
||||
file_filter: A dictionary of key-value pairs to filter the documents.
|
||||
max_count: The maximum number of documents to return. If -1, all documents are returned.
|
||||
|
||||
Returns
|
||||
-------
|
||||
An iterator of document IDs and their corresponding regex matches.
|
||||
"""
|
||||
base_dir = base_dir or ""
|
||||
|
||||
log.info(
|
||||
"search container %s for documents matching %s",
|
||||
self._current_container,
|
||||
file_pattern.pattern,
|
||||
)
|
||||
|
||||
def item_filter(item: dict[str, Any]) -> bool:
|
||||
if file_filter is None:
|
||||
return True
|
||||
return all(re.match(value, item.get(key, "")) for key, value in file_filter.items())
|
||||
|
||||
try:
|
||||
database_client = self._database_client
|
||||
container_client = database_client.get_container_client(str(self._current_container))
|
||||
query = "SELECT * FROM c WHERE CONTAINS(c.id, @pattern)"
|
||||
parameters: list[dict[str, Any]] = [{"name": "@pattern", "value": file_pattern.pattern}]
|
||||
if file_filter:
|
||||
for key, value in file_filter.items():
|
||||
query += f" AND c.{key} = @{key}"
|
||||
parameters.append({"name": f"@{key}", "value": value})
|
||||
items = list(container_client.query_items(query=query, parameters=parameters, enable_cross_partition_query=True))
|
||||
num_loaded = 0
|
||||
num_total = len(items)
|
||||
num_filtered = 0
|
||||
for item in items:
|
||||
match = file_pattern.match(item["id"])
|
||||
if match:
|
||||
group = match.groupdict()
|
||||
if item_filter(group):
|
||||
yield (item["id"], group)
|
||||
num_loaded += 1
|
||||
if max_count > 0 and num_loaded >= max_count:
|
||||
break
|
||||
else:
|
||||
num_filtered += 1
|
||||
else:
|
||||
num_filtered += 1
|
||||
if progress is not None:
|
||||
progress(
|
||||
_create_progress_status(num_loaded, num_filtered, num_total)
|
||||
)
|
||||
except Exception:
|
||||
log.exception("An error occurred while searching for documents in Cosmos DB.")
|
||||
|
||||
async def get(
|
||||
self, key: str, as_bytes: bool | None = None, encoding: str | None = None
|
||||
) -> Any:
|
||||
"""Get a file in the database for the given key."""
|
||||
try:
|
||||
database_client = self._database_client
|
||||
if self._current_container is not None:
|
||||
container_client = database_client.get_container_client(
|
||||
self._current_container
|
||||
)
|
||||
item = container_client.read_item(item=key, partition_key=key)
|
||||
item_body = item.get("body")
|
||||
item_json_str = json.dumps(item_body)
|
||||
if as_bytes:
|
||||
item_df = pd.read_json(
|
||||
StringIO(item_json_str),
|
||||
orient="records",
|
||||
lines=False
|
||||
)
|
||||
return item_df.to_parquet()
|
||||
return item_json_str
|
||||
except Exception:
|
||||
log.exception("Error reading item %s", key)
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
|
||||
"""Set a file in the database for the given key."""
|
||||
try:
|
||||
database_client = self._database_client
|
||||
if self._current_container is not None:
|
||||
container_client = database_client.get_container_client(
|
||||
self._current_container
|
||||
)
|
||||
if isinstance(value, bytes):
|
||||
value_df = pd.read_parquet(BytesIO(value))
|
||||
value_json = value_df.to_json(
|
||||
orient="records",
|
||||
lines=False,
|
||||
force_ascii=False
|
||||
)
|
||||
cosmos_db_item = {
|
||||
"id": key,
|
||||
"body": json.loads(value_json)
|
||||
}
|
||||
else:
|
||||
cosmos_db_item = {
|
||||
"id": key,
|
||||
"body": json.loads(value)
|
||||
}
|
||||
container_client.upsert_item(body=cosmos_db_item)
|
||||
except Exception:
|
||||
log.exception("Error writing item %s", key)
|
||||
|
||||
async def has(self, key: str) -> bool:
|
||||
"""Check if the given file exists in the cosmosdb storage."""
|
||||
database_client = self._database_client
|
||||
if self._current_container is not None:
|
||||
container_client = database_client.get_container_client(
|
||||
self._current_container
|
||||
)
|
||||
item_names = [
|
||||
item["id"]
|
||||
for item in container_client.read_all_items()
|
||||
]
|
||||
return key in item_names
|
||||
return False
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
"""Delete the given file from the cosmosdb storage."""
|
||||
database_client = self._database_client
|
||||
if self._current_container is not None:
|
||||
container_client = database_client.get_container_client(
|
||||
self._current_container
|
||||
)
|
||||
container_client.delete_item(item=key, partition_key=key)
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Clear the cosmosdb storage."""
|
||||
|
||||
def keys(self) -> list[str]:
|
||||
"""Return the keys in the storage."""
|
||||
msg = "CosmosDB storage does yet not support listing keys."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def child(self, name: str | None) -> "PipelineStorage":
|
||||
"""Create a child storage instance."""
|
||||
return self
|
||||
|
||||
def create_container(self) -> None:
|
||||
"""Create a container for the current container name if it doesn't exist."""
|
||||
database_client = self._database_client
|
||||
if self._current_container is not None:
|
||||
partition_key = PartitionKey(path="/id", kind="Hash")
|
||||
database_client.create_container_if_not_exists(
|
||||
id=self._current_container,
|
||||
partition_key=partition_key,
|
||||
)
|
||||
|
||||
|
||||
def delete_container(self) -> None:
|
||||
"""Delete the container with the current container name if it exists."""
|
||||
database_client = self._database_client
|
||||
if self.container_exists() and self._current_container is not None:
|
||||
database_client.delete_container(self._current_container)
|
||||
|
||||
def container_exists(self) -> bool:
|
||||
"""Check if the container with the current container name exists."""
|
||||
database_client = self._database_client
|
||||
container_names = [
|
||||
container["id"]
|
||||
for container in database_client.list_containers()
|
||||
]
|
||||
return self._current_container in container_names
|
||||
|
||||
def create_cosmosdb_storage(
|
||||
cosmosdb_account_url: str | None,
|
||||
connection_string: str | None,
|
||||
base_dir: str,
|
||||
container_name: str | None,
|
||||
) -> PipelineStorage:
|
||||
"""Create a CosmosDB storage instance."""
|
||||
log.info("Creating cosmosdb storage")
|
||||
if base_dir is None:
|
||||
msg = "No base_dir provided for database name"
|
||||
raise ValueError(msg)
|
||||
if connection_string is None and cosmosdb_account_url is None:
|
||||
msg = "No cosmosdb account url provided"
|
||||
raise ValueError(msg)
|
||||
return CosmosDBPipelineStorage(
|
||||
cosmosdb_account_url=cosmosdb_account_url,
|
||||
connection_string=connection_string,
|
||||
database_name=base_dir,
|
||||
current_container=container_name,
|
||||
)
|
||||
|
||||
def _create_progress_status(
|
||||
num_loaded: int, num_filtered: int, num_total: int
|
||||
) -> Progress:
|
||||
return Progress(
|
||||
total_items=num_total,
|
||||
completed_items=num_loaded + num_filtered,
|
||||
description=f"{num_loaded} files loaded ({num_filtered} filtered)",
|
||||
)
|
||||
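For comparison, a minimal sketch of constructing the CosmosDB storage above; the account URL is a placeholder and base_dir doubles as the database name:

from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage

storage = create_cosmosdb_storage(
    cosmosdb_account_url="https://<account>.documents.azure.com:443/",  # placeholder
    connection_string=None,
    base_dir="graphrag",      # used as the database name
    container_name="outputs",
)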
graphrag/storage/factory.py  (new file, 48 lines)
@@ -0,0 +1,48 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Factory functions for creating storage."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
from graphrag.config.enums import StorageType
|
||||
from graphrag.storage.blob_pipeline_storage import create_blob_storage
|
||||
from graphrag.storage.file_pipeline_storage import create_file_storage
|
||||
from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphrag.storage.pipeline_storage import PipelineStorage
|
||||
|
||||
|
||||
class StorageFactory:
|
||||
"""A factory class for storage implementations.
|
||||
|
||||
Includes a method for users to register a custom storage implementation.
|
||||
"""
|
||||
|
||||
storage_types: ClassVar[dict[str, type]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, storage_type: str, storage: type):
|
||||
"""Register a custom storage implementation."""
|
||||
cls.storage_types[storage_type] = storage
|
||||
|
||||
@classmethod
|
||||
def create_storage(
|
||||
cls, storage_type: StorageType | str, kwargs: dict
|
||||
) -> PipelineStorage:
|
||||
"""Create or get a storage object from the provided type."""
|
||||
match storage_type:
|
||||
case StorageType.blob:
|
||||
return create_blob_storage(**kwargs)
|
||||
case StorageType.file:
|
||||
return create_file_storage(**kwargs)
|
||||
case StorageType.memory:
|
||||
return MemoryPipelineStorage()
|
||||
case _:
|
||||
if storage_type in cls.storage_types:
|
||||
return cls.storage_types[storage_type](**kwargs)
|
||||
msg = f"Unknown storage type: {storage_type}"
|
||||
raise ValueError(msg)
|
||||
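A sketch of registering a custom backend with the factory above; the AuditedMemoryStorage class and its key are hypothetical:

from typing import Any

from graphrag.storage.factory import StorageFactory
from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage


class AuditedMemoryStorage(MemoryPipelineStorage):
    """Hypothetical storage that counts writes."""

    def __init__(self) -> None:
        super().__init__()
        self.writes = 0

    async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
        self.writes += 1
        await super().set(key, value, encoding)


StorageFactory.register("audited_memory", AuditedMemoryStorage)
storage = StorageFactory.create_storage("audited_memory", kwargs={})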
graphrag/storage/file_pipeline_storage.py  (new file, 169 lines)
@@ -0,0 +1,169 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing 'FileStorage' and 'FilePipelineStorage' models."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
import aiofiles
|
||||
from aiofiles.os import remove
|
||||
from aiofiles.ospath import exists
|
||||
from datashaper import Progress
|
||||
|
||||
from graphrag.logger.base import ProgressLogger
|
||||
from graphrag.storage.pipeline_storage import PipelineStorage
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FilePipelineStorage(PipelineStorage):
|
||||
"""File storage class definition."""
|
||||
|
||||
_root_dir: str
|
||||
_encoding: str
|
||||
|
||||
def __init__(self, root_dir: str = "", encoding: str = "utf-8"):
|
||||
"""Init method definition."""
|
||||
self._root_dir = root_dir
|
||||
self._encoding = encoding
|
||||
Path(self._root_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def find(
|
||||
self,
|
||||
file_pattern: re.Pattern[str],
|
||||
base_dir: str | None = None,
|
||||
progress: ProgressLogger | None = None,
|
||||
file_filter: dict[str, Any] | None = None,
|
||||
max_count=-1,
|
||||
) -> Iterator[tuple[str, dict[str, Any]]]:
|
||||
"""Find files in the storage using a file pattern, as well as a custom filter function."""
|
||||
|
||||
def item_filter(item: dict[str, Any]) -> bool:
|
||||
if file_filter is None:
|
||||
return True
|
||||
|
||||
return all(re.match(value, item[key]) for key, value in file_filter.items())
|
||||
|
||||
search_path = Path(self._root_dir) / (base_dir or "")
|
||||
log.info("search %s for files matching %s", search_path, file_pattern.pattern)
|
||||
all_files = list(search_path.rglob("**/*"))
|
||||
num_loaded = 0
|
||||
num_total = len(all_files)
|
||||
num_filtered = 0
|
||||
for file in all_files:
|
||||
match = file_pattern.match(f"{file}")
|
||||
if match:
|
||||
group = match.groupdict()
|
||||
if item_filter(group):
|
||||
filename = f"{file}".replace(self._root_dir, "")
|
||||
if filename.startswith(os.sep):
|
||||
filename = filename[1:]
|
||||
yield (filename, group)
|
||||
num_loaded += 1
|
||||
if max_count > 0 and num_loaded >= max_count:
|
||||
break
|
||||
else:
|
||||
num_filtered += 1
|
||||
else:
|
||||
num_filtered += 1
|
||||
if progress is not None:
|
||||
progress(_create_progress_status(num_loaded, num_filtered, num_total))
|
||||
|
||||
async def get(
|
||||
self, key: str, as_bytes: bool | None = False, encoding: str | None = None
|
||||
) -> Any:
|
||||
"""Get method definition."""
|
||||
file_path = join_path(self._root_dir, key)
|
||||
|
||||
if await self.has(key):
|
||||
return await self._read_file(file_path, as_bytes, encoding)
|
||||
if await exists(key):
|
||||
# Look up the key directly, as it is presumably a new file loaded from inputs
|
||||
# and not yet written to storage
|
||||
return await self._read_file(key, as_bytes, encoding)
|
||||
|
||||
return None
|
||||
|
||||
async def _read_file(
|
||||
self,
|
||||
path: str | Path,
|
||||
as_bytes: bool | None = False,
|
||||
encoding: str | None = None,
|
||||
) -> Any:
|
||||
"""Read the contents of a file."""
|
||||
read_type = "rb" if as_bytes else "r"
|
||||
encoding = None if as_bytes else (encoding or self._encoding)
|
||||
|
||||
async with aiofiles.open(
|
||||
path,
|
||||
cast("Any", read_type),
|
||||
encoding=encoding,
|
||||
) as f:
|
||||
return await f.read()
|
||||
|
||||
async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
|
||||
"""Set method definition."""
|
||||
is_bytes = isinstance(value, bytes)
|
||||
write_type = "wb" if is_bytes else "w"
|
||||
encoding = None if is_bytes else encoding or self._encoding
|
||||
async with aiofiles.open(
|
||||
join_path(self._root_dir, key),
|
||||
cast("Any", write_type),
|
||||
encoding=encoding,
|
||||
) as f:
|
||||
await f.write(value)
|
||||
|
||||
async def has(self, key: str) -> bool:
|
||||
"""Has method definition."""
|
||||
return await exists(join_path(self._root_dir, key))
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
"""Delete method definition."""
|
||||
if await self.has(key):
|
||||
await remove(join_path(self._root_dir, key))
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Clear method definition."""
|
||||
for file in Path(self._root_dir).glob("*"):
|
||||
if file.is_dir():
|
||||
shutil.rmtree(file)
|
||||
else:
|
||||
file.unlink()
|
||||
|
||||
def child(self, name: str | None) -> "PipelineStorage":
|
||||
"""Create a child storage instance."""
|
||||
if name is None:
|
||||
return self
|
||||
return FilePipelineStorage(str(Path(self._root_dir) / Path(name)))
|
||||
|
||||
def keys(self) -> list[str]:
|
||||
"""Return the keys in the storage."""
|
||||
return [item.name for item in Path(self._root_dir).iterdir() if item.is_file()]
|
||||
|
||||
|
||||
def join_path(file_path: str, file_name: str) -> Path:
|
||||
"""Join a path and a file. Independent of the OS."""
|
||||
return Path(file_path) / Path(file_name).parent / Path(file_name).name
|
||||
|
||||
|
||||
def create_file_storage(**kwargs: Any) -> PipelineStorage:
|
||||
"""Create a file based storage."""
|
||||
base_dir = kwargs["base_dir"]
|
||||
log.info("Creating file storage at %s", base_dir)
|
||||
return FilePipelineStorage(root_dir=base_dir)
|
||||
|
||||
|
||||
def _create_progress_status(
|
||||
num_loaded: int, num_filtered: int, num_total: int
|
||||
) -> Progress:
|
||||
return Progress(
|
||||
total_items=num_total,
|
||||
completed_items=num_loaded + num_filtered,
|
||||
description=f"{num_loaded} files loaded ({num_filtered} filtered)",
|
||||
)
|
||||
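A runnable sketch of the file storage above; the ./output directory is an arbitrary choice:

import asyncio

from graphrag.storage.file_pipeline_storage import FilePipelineStorage


async def main() -> None:
    storage = FilePipelineStorage(root_dir="./output")
    await storage.set("stats.json", '{"documents": 1}')
    print(await storage.get("stats.json"))
    print(storage.keys())  # files directly under ./output

    # child() nests a new storage under ./output/reports
    reports = storage.child("reports")
    await reports.set("community_0.txt", "report text")


asyncio.run(main())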
graphrag/storage/memory_pipeline_storage.py  (new file, 78 lines)
@@ -0,0 +1,78 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing 'InMemoryStorage' model."""
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphrag.storage.pipeline_storage import PipelineStorage
|
||||
|
||||
|
||||
class MemoryPipelineStorage(FilePipelineStorage):
|
||||
"""In memory storage class definition."""
|
||||
|
||||
_storage: dict[str, Any]
|
||||
|
||||
def __init__(self):
|
||||
"""Init method definition."""
|
||||
super().__init__()
|
||||
self._storage = {}
|
||||
|
||||
async def get(
|
||||
self, key: str, as_bytes: bool | None = None, encoding: str | None = None
|
||||
) -> Any:
|
||||
"""Get the value for the given key.
|
||||
|
||||
Args:
|
||||
- key - The key to get the value for.
|
||||
- as_bytes - Whether or not to return the value as bytes.
|
||||
|
||||
Returns
|
||||
-------
|
||||
- output - The value for the given key.
|
||||
"""
|
||||
return self._storage.get(key)
|
||||
|
||||
async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
|
||||
"""Set the value for the given key.
|
||||
|
||||
Args:
|
||||
- key - The key to set the value for.
|
||||
- value - The value to set.
|
||||
"""
|
||||
self._storage[key] = value
|
||||
|
||||
async def has(self, key: str) -> bool:
|
||||
"""Return True if the given key exists in the storage.
|
||||
|
||||
Args:
|
||||
- key - The key to check for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
- output - True if the key exists in the storage, False otherwise.
|
||||
"""
|
||||
return key in self._storage
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
"""Delete the given key from the storage.
|
||||
|
||||
Args:
|
||||
- key - The key to delete.
|
||||
"""
|
||||
del self._storage[key]
|
||||
|
||||
async def clear(self) -> None:
|
||||
"""Clear the storage."""
|
||||
self._storage.clear()
|
||||
|
||||
def child(self, name: str | None) -> "PipelineStorage":
|
||||
"""Create a child storage instance."""
|
||||
return MemoryPipelineStorage()
|
||||
|
||||
def keys(self) -> list[str]:
|
||||
"""Return the keys in the storage."""
|
||||
return list(self._storage.keys())
|
||||
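The in-memory variant behaves the same way without touching disk; a tiny sketch:

import asyncio

from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage


async def main() -> None:
    storage = MemoryPipelineStorage()
    await storage.set("base_text_units", [{"id": 1}])
    assert await storage.has("base_text_units")
    print(storage.keys())  # -> ["base_text_units"]


asyncio.run(main())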
graphrag/storage/pipeline_storage.py  (new file, 82 lines)
@@ -0,0 +1,82 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'PipelineStorage' model."""

import re
from abc import ABCMeta, abstractmethod
from collections.abc import Iterator
from typing import Any

from graphrag.logger.base import ProgressLogger


class PipelineStorage(metaclass=ABCMeta):
    """Provide a storage interface for the pipeline. This is where the pipeline will store its output data."""

    @abstractmethod
    def find(
        self,
        file_pattern: re.Pattern[str],
        base_dir: str | None = None,
        progress: ProgressLogger | None = None,
        file_filter: dict[str, Any] | None = None,
        max_count=-1,
    ) -> Iterator[tuple[str, dict[str, Any]]]:
        """Find files in the storage using a file pattern, as well as a custom filter function."""

    @abstractmethod
    async def get(
        self, key: str, as_bytes: bool | None = None, encoding: str | None = None
    ) -> Any:
        """Get the value for the given key.

        Args:
            - key - The key to get the value for.
            - as_bytes - Whether or not to return the value as bytes.

        Returns
        -------
            - output - The value for the given key.
        """

    @abstractmethod
    async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
        """Set the value for the given key.

        Args:
            - key - The key to set the value for.
            - value - The value to set.
        """

    @abstractmethod
    async def has(self, key: str) -> bool:
        """Return True if the given key exists in the storage.

        Args:
            - key - The key to check for.

        Returns
        -------
            - output - True if the key exists in the storage, False otherwise.
        """

    @abstractmethod
    async def delete(self, key: str) -> None:
        """Delete the given key from the storage.

        Args:
            - key - The key to delete.
        """

    @abstractmethod
    async def clear(self) -> None:
        """Clear the storage."""

    @abstractmethod
    def child(self, name: str | None) -> "PipelineStorage":
        """Create a child storage instance."""

    @abstractmethod
    def keys(self) -> list[str]:
        """List all keys in the storage."""
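Also not part of the diff: an illustrative sketch of what a custom backend has to provide to satisfy this interface. The dict-backed class below, its prefixing scheme for child(), and the regex-group handling in find() are assumptions made for the example, not behavior required by graphrag or shipped with it.

# Illustrative sketch only: a toy dict-backed PipelineStorage subclass.
import re
from collections.abc import Iterator
from typing import Any

from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage


class DictPipelineStorage(PipelineStorage):
    """Toy storage keeping everything in a shared dict, namespaced per child (assumed semantics)."""

    def __init__(self, data: dict[str, Any] | None = None, prefix: str = "") -> None:
        self._data = data if data is not None else {}
        self._prefix = prefix

    def _key(self, key: str) -> str:
        # Prepend the child prefix so child stores share the same backing dict.
        return f"{self._prefix}{key}"

    def find(
        self,
        file_pattern: re.Pattern[str],
        base_dir: str | None = None,
        progress: ProgressLogger | None = None,
        file_filter: dict[str, Any] | None = None,
        max_count=-1,
    ) -> Iterator[tuple[str, dict[str, Any]]]:
        # Yield keys matching the pattern, up to max_count (unlimited when negative).
        count = 0
        for key in list(self._data):
            match = file_pattern.search(key)
            if match:
                yield key, match.groupdict()
                count += 1
                if 0 < max_count <= count:
                    break

    async def get(
        self, key: str, as_bytes: bool | None = None, encoding: str | None = None
    ) -> Any:
        return self._data.get(self._key(key))

    async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
        self._data[self._key(key)] = value

    async def has(self, key: str) -> bool:
        return self._key(key) in self._data

    async def delete(self, key: str) -> None:
        del self._data[self._key(key)]

    async def clear(self) -> None:
        self._data.clear()

    def child(self, name: str | None) -> "PipelineStorage":
        prefix = f"{self._prefix}{name}/" if name else self._prefix
        return DictPipelineStorage(self._data, prefix)

    def keys(self) -> list[str]:
        return [k for k in self._data if k.startswith(self._prefix)]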
125
tests/unit/indexing/workflows/test_export.py
Normal file
@ -0,0 +1,125 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from typing import Any, cast

import pandas as pd
from datashaper import (
    Table,
    VerbInput,
    VerbResult,
    create_verb_result,
)

from graphrag.index.config.pipeline import PipelineWorkflowReference
from graphrag.index.run import run_pipeline
from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.storage.pipeline_storage import PipelineStorage


async def mock_verb(
    input: VerbInput, storage: PipelineStorage, **_kwargs
) -> VerbResult:
    source = cast("pd.DataFrame", input.get_input())

    output = source[["id"]]

    await storage.set("mock_write", source[["id"]])

    return create_verb_result(
        cast(
            "Table",
            output,
        )
    )


async def mock_no_return_verb(
    input: VerbInput, storage: PipelineStorage, **_kwargs
) -> VerbResult:
    source = cast("pd.DataFrame", input.get_input())

    # write some outputs to storage independent of the return
    await storage.set("empty_write", source[["name"]])

    return create_verb_result(
        cast(
            "Table",
            pd.DataFrame(),
        )
    )


async def test_normal_result_exports_parquet():
    mock_verbs: Any = {"mock_verb": mock_verb}
    mock_workflows: Any = {
        "mock_workflow": lambda _x: [
            {
                "verb": "mock_verb",
                "args": {
                    "column": "test",
                },
            }
        ]
    }
    workflows = [
        PipelineWorkflowReference(
            name="mock_workflow",
            config=None,
        )
    ]
    dataset = pd.DataFrame({"id": [1, 2, 3], "name": ["a", "b", "c"]})
    storage = MemoryPipelineStorage()
    pipeline_result = [
        gen
        async for gen in run_pipeline(
            workflows,
            dataset,
            storage=storage,
            additional_workflows=mock_workflows,
            additional_verbs=mock_verbs,
        )
    ]

    assert len(pipeline_result) == 1
    assert storage.keys() == ["stats.json", "mock_write", "mock_workflow.parquet"], (
        "Mock workflow output should be written to storage by the exporter when there is a non-empty data frame"
    )


async def test_empty_result_does_not_export_parquet():
    mock_verbs: Any = {"mock_no_return_verb": mock_no_return_verb}
    mock_workflows: Any = {
        "mock_workflow": lambda _x: [
            {
                "verb": "mock_no_return_verb",
                "args": {
                    "column": "test",
                },
            }
        ]
    }
    workflows = [
        PipelineWorkflowReference(
            name="mock_workflow",
            config=None,
        )
    ]
    dataset = pd.DataFrame({"id": [1, 2, 3], "name": ["a", "b", "c"]})
    storage = MemoryPipelineStorage()
    pipeline_result = [
        gen
        async for gen in run_pipeline(
            workflows,
            dataset,
            storage=storage,
            additional_workflows=mock_workflows,
            additional_verbs=mock_verbs,
        )
    ]

    assert len(pipeline_result) == 1
    assert storage.keys() == [
        "stats.json",
        "empty_write",
    ], "Mock workflow output should not be written to storage by the exporter"
BIN
tests/verbs/data/base_communities.parquet
Normal file
Binary file not shown.
BIN
tests/verbs/data/base_entity_nodes.parquet
Normal file
Binary file not shown.
BIN
tests/verbs/data/base_relationship_edges.parquet
Normal file
Binary file not shown.
52
tests/verbs/test_compute_communities.py
Normal file
@ -0,0 +1,52 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from graphrag.index.flows.compute_communities import (
    compute_communities,
)
from graphrag.index.run.utils import create_run_context
from graphrag.index.workflows.v1.compute_communities import (
    workflow_name,
)

from .util import (
    compare_outputs,
    get_config_for_workflow,
    load_test_table,
)


async def test_compute_communities():
    edges = load_test_table("base_relationship_edges")
    expected = load_test_table("base_communities")

    context = create_run_context(None, None, None)
    config = get_config_for_workflow(workflow_name)
    clustering_strategy = config["cluster_graph"]["strategy"]

    actual = await compute_communities(
        edges, storage=context.storage, clustering_strategy=clustering_strategy
    )

    columns = list(expected.columns.values)
    compare_outputs(actual, expected, columns)
    assert len(actual.columns) == len(expected.columns)


async def test_compute_communities_with_snapshots():
    edges = load_test_table("base_relationship_edges")

    context = create_run_context(None, None, None)
    config = get_config_for_workflow(workflow_name)
    clustering_strategy = config["cluster_graph"]["strategy"]

    await compute_communities(
        edges,
        storage=context.storage,
        clustering_strategy=clustering_strategy,
        snapshot_transient_enabled=True,
    )

    assert context.storage.keys() == [
        "base_communities.parquet",
    ], "Community snapshot keys differ"
159
tests/verbs/test_extract_graph.py
Normal file
@ -0,0 +1,159 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

import pytest

from graphrag.config.enums import LLMType
from graphrag.index.run.utils import create_run_context
from graphrag.index.workflows.v1.extract_graph import (
    build_steps,
    workflow_name,
)

from .util import (
    get_config_for_workflow,
    get_workflow_output,
    load_input_tables,
    load_test_table,
)

MOCK_LLM_ENTITY_RESPONSES = [
    """
("entity"<|>COMPANY_A<|>COMPANY<|>Company_A is a test company)
##
("entity"<|>COMPANY_B<|>COMPANY<|>Company_B owns Company_A and also shares an address with Company_A)
##
("entity"<|>PERSON_C<|>PERSON<|>Person_C is director of Company_A)
##
("relationship"<|>COMPANY_A<|>COMPANY_B<|>Company_A and Company_B are related because Company_A is 100% owned by Company_B and the two companies also share the same address)<|>2)
##
("relationship"<|>COMPANY_A<|>PERSON_C<|>Company_A and Person_C are related because Person_C is director of Company_A<|>1))
""".strip()
]

MOCK_LLM_ENTITY_CONFIG = {
    "type": LLMType.StaticResponse,
    "responses": MOCK_LLM_ENTITY_RESPONSES,
}

MOCK_LLM_SUMMARIZATION_RESPONSES = [
    """
This is a MOCK response for the LLM. It is summarized!
""".strip()
]

MOCK_LLM_SUMMARIZATION_CONFIG = {
    "type": LLMType.StaticResponse,
    "responses": MOCK_LLM_SUMMARIZATION_RESPONSES,
}


async def test_extract_graph():
    input_tables = load_input_tables([
        "workflow:create_base_text_units",
    ])

    nodes_expected = load_test_table("base_entity_nodes")
    edges_expected = load_test_table("base_relationship_edges")

    context = create_run_context(None, None, None)
    await context.runtime_storage.set(
        "base_text_units", input_tables["workflow:create_base_text_units"]
    )

    config = get_config_for_workflow(workflow_name)
    config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
    config["summarize_descriptions"]["strategy"]["llm"] = MOCK_LLM_SUMMARIZATION_CONFIG

    steps = build_steps(config)

    await get_workflow_output(
        input_tables,
        {
            "steps": steps,
        },
        context=context,
    )

    # graph construction creates transient tables for nodes, edges, and communities
    nodes_actual = await context.runtime_storage.get("base_entity_nodes")
    edges_actual = await context.runtime_storage.get("base_relationship_edges")

    assert len(nodes_actual.columns) == len(nodes_expected.columns), (
        "Nodes dataframe columns differ"
    )

    assert len(edges_actual.columns) == len(edges_expected.columns), (
        "Edges dataframe columns differ"
    )

    # TODO: with the combined verb we can't force summarization
    # this is because the mock responses always result in a single description, which is returned verbatim rather than summarized
    # we need to update the mocking to provide somewhat unique graphs so a true merge happens
    # the assertion should grab a node and ensure the description matches the mock description, not the original as we are doing below

    assert nodes_actual["description"].values[0] == "Company_A is a test company"

    assert len(context.storage.keys()) == 0, "Storage should be empty"


async def test_extract_graph_with_snapshots():
    input_tables = load_input_tables([
        "workflow:create_base_text_units",
    ])

    context = create_run_context(None, None, None)
    await context.runtime_storage.set(
        "base_text_units", input_tables["workflow:create_base_text_units"]
    )

    config = get_config_for_workflow(workflow_name)

    config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
    config["summarize_descriptions"]["strategy"]["llm"] = MOCK_LLM_SUMMARIZATION_CONFIG
    config["snapshot_graphml"] = True
    config["snapshot_transient"] = True
    config["embed_graph_enabled"] = True  # need this on in order to see the snapshot

    steps = build_steps(config)

    await get_workflow_output(
        input_tables,
        {
            "steps": steps,
        },
        context=context,
    )

    assert context.storage.keys() == [
        "graph.graphml",
        "base_entity_nodes.parquet",
        "base_relationship_edges.parquet",
    ], "Graph snapshot keys differ"


async def test_extract_graph_missing_llm_throws():
    input_tables = load_input_tables([
        "workflow:create_base_text_units",
    ])

    context = create_run_context(None, None, None)
    await context.runtime_storage.set(
        "base_text_units", input_tables["workflow:create_base_text_units"]
    )

    config = get_config_for_workflow(workflow_name)

    config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
    del config["summarize_descriptions"]["strategy"]["llm"]

    steps = build_steps(config)

    with pytest.raises(ValueError):  # noqa PT011
        await get_workflow_output(
            input_tables,
            {
                "steps": steps,
            },
            context=context,
        )