partially fixed merge conflicts

Kenny Zhang 2024-12-18 15:10:45 -05:00
parent 82548e11f8
commit 9ca67643b4
44 changed files with 3517 additions and 0 deletions

46
.semversioner/0.9.0.json Normal file

@ -0,0 +1,46 @@
{
"changes": [
{
"description": "Refactor graph creation.",
"type": "minor"
},
{
"description": "Dependency updates",
"type": "patch"
},
{
"description": "Fix Global Search with dynamic Community selection bug",
"type": "patch"
},
{
"description": "Fix question gen.",
"type": "patch"
},
{
"description": "Optimize Final Community Reports calculation and stabilize cache",
"type": "patch"
},
{
"description": "miscellaneous code cleanup and minor changes for better alignment of style across the codebase.",
"type": "patch"
},
{
"description": "replace llm package with fnllm",
"type": "patch"
},
{
"description": "replaced md5 hash with sha256",
"type": "patch"
},
{
"description": "replaced md5 hash with sha512",
"type": "patch"
},
{
"description": "update API and add a demonstration notebook",
"type": "patch"
}
],
"created_at": "2024-12-06T20:12:30+00:00",
"version": "0.9.0"
}

26
.semversioner/1.0.0.json Normal file

@ -0,0 +1,26 @@
{
"changes": [
{
"description": "Add Parent id to communities data model",
"type": "patch"
},
{
"description": "Add migration notebook.",
"type": "patch"
},
{
"description": "Create separate community workflow, collapse subflows.",
"type": "patch"
},
{
"description": "Dependency Updates",
"type": "patch"
},
{
"description": "cleanup and refactor factory classes.",
"type": "patch"
}
],
"created_at": "2024-12-11T21:41:49+00:00",
"version": "1.0.0"
}


@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Respect encoding_model option"
}


@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Fix exception on error callbacks"
}


@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Fix encoding model config parsing"
}


@ -0,0 +1,209 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Copyright (c) 2024 Microsoft Corporation.\n",
"# Licensed under the MIT License."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## API Overview\n",
"\n",
"This notebook provides a demonstration of how to interact with graphrag as a library using the API as opposed to the CLI. Note that graphrag's CLI actually connects to the library through this API for all operations. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import graphrag.api as api\n",
"from graphrag.index.typing import PipelineRunResult"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prerequisite\n",
"As a prerequisite to all API operations, a `GraphRagConfig` object is required. It is the primary means to control the behavior of graphrag and can be instantiated from a `settings.yaml` configuration file.\n",
"\n",
"Please refer to the [CLI docs](https://microsoft.github.io/graphrag/cli/#init) for more detailed information on how to generate the `settings.yaml` file.\n",
"\n",
"#### Load `settings.yaml` configuration"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import yaml\n",
"\n",
"settings = yaml.safe_load(open(\"<project_directory>/settings.yaml\")) # noqa: PTH123, SIM115"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"At this point, you can modify the imported settings to align with your application's requirements. For example, if building a UI application, the application might need to change the input and/or storage destinations dynamically in order to enable users to build and query different indexes."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generate a `GraphRagConfig` object"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from graphrag.config.create_graphrag_config import create_graphrag_config\n",
"\n",
"graphrag_config = create_graphrag_config(\n",
" values=settings, root_dir=\"<project_directory>\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Indexing API\n",
"\n",
"*Indexing* is the process of ingesting raw text data and constructing a knowledge graph. GraphRAG currently supports plaintext (`.txt`) and `.csv` file formats."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Build an index"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"index_result: list[PipelineRunResult] = await api.build_index(config=graphrag_config)\n",
"\n",
"# index_result is a list of workflows that make up the indexing pipeline that was run\n",
"for workflow_result in index_result:\n",
" status = f\"error\\n{workflow_result.errors}\" if workflow_result.errors else \"success\"\n",
" print(f\"Workflow Name: {workflow_result.workflow}\\tStatus: {status}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Query an index\n",
"\n",
"To query an index, several index files must first be read into memory and passed to the query API. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"final_nodes = pd.read_parquet(\"<project_directory>/output/create_final_nodes.parquet\")\n",
"final_entities = pd.read_parquet(\n",
" \"<project_directory>/output/create_final_entities.parquet\"\n",
")\n",
"final_communities = pd.read_parquet(\n",
" \"<project_directory>/output/create_final_communities.parquet\"\n",
")\n",
"final_community_reports = pd.read_parquet(\n",
" \"<project_directory>/output/create_final_community_reports.parquet\"\n",
")\n",
"\n",
"response, context = await api.global_search(\n",
" config=graphrag_config,\n",
" nodes=final_nodes,\n",
" entities=final_entities,\n",
" communities=final_communities,\n",
" community_reports=final_community_reports,\n",
" community_level=2,\n",
" dynamic_community_selection=False,\n",
" response_type=\"Multiple Paragraphs\",\n",
" query=\"Who is Scrooge and what are his main relationships?\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The response object is the official reponse from graphrag while the context object holds various metadata regarding the querying process used to obtain the final response."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Digging into the context a bit more provides users with extremely granular information such as what sources of data (down to the level of text chunks) were ultimately retrieved and used as part of the context sent to the LLM model)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pprint import pprint\n",
"\n",
"pprint(context) # noqa: T203"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "graphrag-venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
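Where the overview above notes that the imported settings can be modified before building the config, a minimal sketch of that override step follows. It assumes a typical settings.yaml layout with top-level storage and input sections that each carry a base_dir; the per-user paths are hypothetical.

import yaml

from graphrag.config.create_graphrag_config import create_graphrag_config

with open("<project_directory>/settings.yaml") as f:
    settings = yaml.safe_load(f)

# Hypothetical per-user overrides; the "storage" and "input" keys are assumed
# to exist in settings.yaml with a "base_dir" entry each.
settings["storage"]["base_dir"] = "output/user-123"
settings["input"]["base_dir"] = "input/user-123"

graphrag_config = create_graphrag_config(
    values=settings, root_dir="<project_directory>"
)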


@ -0,0 +1,263 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Copyright (c) 2024 Microsoft Corporation.\n",
"# Licensed under the MIT License."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Index Migration\n",
"\n",
"This notebook is used to maintain data model parity with older indexes for the latest versions of GraphRAG. If you have a pre-1.0 index and need to migrate without re-running the entire pipeline, you can use this notebook to only update the pieces necessary for alignment.\n",
"\n",
"NOTE: we recommend regenerating your settings.yml with the latest version of GraphRAG using `graphrag init`. Copy your LLM settings into it before running this notebook. This ensures your config is aligned with the latest version for the migration. This also ensures that you have default vector store config, which is now required or indexing will fail.\n",
"\n",
"WARNING: This will overwrite your parquet files, you may want to make a backup!"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# This is the directory that has your settings.yml\n",
"# NOTE: much older indexes may have been output with a timestamped directory\n",
"# if this is the case, you will need to make sure the storage.base_dir in settings.yml points to it correctly\n",
"PROJECT_DIRECTORY = \"<your project directory>\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"from graphrag.config.load_config import load_config\n",
"from graphrag.config.resolve_path import resolve_paths\n",
"from graphrag.index.create_pipeline_config import create_pipeline_config\n",
"from graphrag.storage.factory import create_storage\n",
"\n",
"# This first block does some config loading, path resolution, and translation that is normally done by the CLI/API when running a full workflow\n",
"config = load_config(Path(PROJECT_DIRECTORY))\n",
"resolve_paths(config)\n",
"pipeline_config = create_pipeline_config(config)\n",
"storage = create_storage(pipeline_config.storage)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def remove_columns(df, columns):\n",
" \"\"\"Remove columns from a DataFrame, suppressing errors.\"\"\"\n",
" df.drop(labels=columns, axis=1, errors=\"ignore\", inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"def get_community_parent(nodes):\n",
" \"\"\"Compute the parent community using the node membership as a lookup.\"\"\"\n",
" parent_mapping = nodes.loc[:, [\"level\", \"community\", \"title\"]]\n",
" nodes = nodes.loc[:, [\"level\", \"community\", \"title\"]]\n",
"\n",
" # Create a parent mapping by adding 1 to the level column\n",
" parent_mapping[\"level\"] += 1 # Shift levels for parent relationship\n",
" parent_mapping.rename(columns={\"community\": \"parent\"}, inplace=True)\n",
"\n",
" # Merge the parent information back into the base DataFrame\n",
" nodes = nodes.merge(parent_mapping, on=[\"level\", \"title\"], how=\"left\")\n",
"\n",
" # Fill missing parents with -1 (default value)\n",
" nodes[\"parent\"] = nodes[\"parent\"].fillna(-1).astype(int)\n",
"\n",
" join = (\n",
" nodes.groupby([\"community\", \"level\", \"parent\"])\n",
" .agg({\"title\": list})\n",
" .reset_index()\n",
" )\n",
" return join[join[\"community\"] > -1].loc[:, [\"community\", \"parent\"]]"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"from uuid import uuid4\n",
"\n",
"from graphrag.utils.storage import load_table_from_storage, write_table_to_storage\n",
"\n",
"# First we'll go through any parquet files that had model changes and update them\n",
"# The new data model may have removed excess columns as well, but we will only make the minimal changes required for compatibility\n",
"\n",
"final_documents = await load_table_from_storage(\n",
" \"create_final_documents.parquet\", storage\n",
")\n",
"final_text_units = await load_table_from_storage(\n",
" \"create_final_text_units.parquet\", storage\n",
")\n",
"final_entities = await load_table_from_storage(\"create_final_entities.parquet\", storage)\n",
"final_nodes = await load_table_from_storage(\"create_final_nodes.parquet\", storage)\n",
"final_relationships = await load_table_from_storage(\n",
" \"create_final_relationships.parquet\", storage\n",
")\n",
"final_communities = await load_table_from_storage(\n",
" \"create_final_communities.parquet\", storage\n",
")\n",
"final_community_reports = await load_table_from_storage(\n",
" \"create_final_community_reports.parquet\", storage\n",
")\n",
"\n",
"\n",
"# Documents renames raw_content for consistency\n",
"if \"raw_content\" in final_documents.columns:\n",
" final_documents.rename(columns={\"raw_content\": \"text\"}, inplace=True)\n",
"final_documents[\"human_readable_id\"] = final_documents.index + 1\n",
"\n",
"# Text units just get a human_readable_id or consistency\n",
"final_text_units[\"human_readable_id\"] = final_text_units.index + 1\n",
"\n",
"# We renamed \"name\" to \"title\" for consistency with the rest of the tables\n",
"if \"name\" in final_entities.columns:\n",
" final_entities.rename(columns={\"name\": \"title\"}, inplace=True)\n",
"remove_columns(\n",
" final_entities, [\"mname_embedding\", \"graph_embedding\", \"description_embedding\"]\n",
")\n",
"\n",
"# Final nodes uses community for joins, which is now an int everywhere\n",
"final_nodes[\"community\"] = final_nodes[\"community\"].fillna(-1)\n",
"final_nodes[\"community\"] = final_nodes[\"community\"].astype(int)\n",
"remove_columns(\n",
" final_nodes,\n",
" [\n",
" \"type\",\n",
" \"description\",\n",
" \"source_id\",\n",
" \"graph_embedding\",\n",
" \"entity_type\",\n",
" \"top_level_node_id\",\n",
" \"size\",\n",
" ],\n",
")\n",
"\n",
"# Relationships renames \"rank\" to \"combined_degree\" to be clear what the default ranking is\n",
"if \"rank\" in final_relationships.columns:\n",
" final_relationships.rename(columns={\"rank\": \"combined_degree\"}, inplace=True)\n",
"\n",
"\n",
"# Compute the parents for each community, to add to communities and reports\n",
"parent_df = get_community_parent(final_nodes)\n",
"\n",
"# Communities previously used the \"id\" field for the Leiden id, but we've moved this to the community field and use a uuid for id like the others\n",
"if \"community\" not in final_communities.columns:\n",
" final_communities[\"community\"] = final_communities[\"id\"].astype(int)\n",
" final_communities[\"human_readable_id\"] = final_communities[\"community\"]\n",
" final_communities[\"id\"] = [str(uuid4()) for _ in range(len(final_communities))]\n",
"if \"parent\" not in final_communities.columns:\n",
" final_communities = final_communities.merge(parent_df, on=\"community\", how=\"left\")\n",
"remove_columns(final_communities, [\"raw_community\"])\n",
"\n",
"# We need int for community and the human_readable_id copy for consistency\n",
"final_community_reports[\"community\"] = final_community_reports[\"community\"].astype(int)\n",
"final_community_reports[\"human_readable_id\"] = final_community_reports[\"community\"]\n",
"if \"parent\" not in final_community_reports.columns:\n",
" final_community_reports = final_community_reports.merge(\n",
" parent_df, on=\"community\", how=\"left\"\n",
" )\n",
"\n",
"await write_table_to_storage(final_documents, \"create_final_documents.parquet\", storage)\n",
"await write_table_to_storage(\n",
" final_text_units, \"create_final_text_units.parquet\", storage\n",
")\n",
"await write_table_to_storage(final_entities, \"create_final_entities.parquet\", storage)\n",
"await write_table_to_storage(final_nodes, \"create_final_nodes.parquet\", storage)\n",
"await write_table_to_storage(\n",
" final_relationships, \"create_final_relationships.parquet\", storage\n",
")\n",
"await write_table_to_storage(\n",
" final_communities, \"create_final_communities.parquet\", storage\n",
")\n",
"await write_table_to_storage(\n",
" final_community_reports, \"create_final_community_reports.parquet\", storage\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from datashaper import NoopVerbCallbacks\n",
"\n",
"from graphrag.cache.factory import create_cache\n",
"from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings\n",
"\n",
"# We only need to re-run the embeddings workflow, to ensure that embeddings for all required search fields are in place\n",
"# We'll construct the context and run this function flow directly to avoid everything else\n",
"\n",
"workflow = next(\n",
" (x for x in pipeline_config.workflows if x.name == \"generate_text_embeddings\"), None\n",
")\n",
"config = workflow.config\n",
"text_embed = config.get(\"text_embed\", {})\n",
"embedded_fields = config.get(\"embedded_fields\", {})\n",
"callbacks = NoopVerbCallbacks()\n",
"cache = create_cache(pipeline_config.cache, PROJECT_DIRECTORY)\n",
"\n",
"await generate_text_embeddings(\n",
" final_documents=None,\n",
" final_relationships=None,\n",
" final_text_units=final_text_units,\n",
" final_entities=final_entities,\n",
" final_community_reports=final_community_reports,\n",
" callbacks=callbacks,\n",
" cache=cache,\n",
" storage=storage,\n",
" text_embed_config=text_embed,\n",
" embedded_fields=embedded_fields,\n",
" snapshot_embeddings_enabled=False,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
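Because the migration overwrites parquet files in place, a small backup sketch such as the following can be run first. It assumes the index output lives in the default output folder under the project directory; adjust the paths if your storage.base_dir differs.

import shutil
from pathlib import Path

PROJECT_DIRECTORY = "<your project directory>"

output_dir = Path(PROJECT_DIRECTORY) / "output"        # assumed storage.base_dir
backup_dir = Path(PROJECT_DIRECTORY) / "output_backup"
backup_dir.mkdir(exist_ok=True)

# Copy every parquet table to the backup folder before running the migration.
for parquet_file in output_dir.glob("*.parquet"):
    shutil.copy2(parquet_file, backup_dir / parquet_file.name)
    print(f"backed up {parquet_file.name}")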

4
graphrag/cache/__init__.py vendored Normal file

@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A package containing cache implementations."""

57
graphrag/cache/factory.py vendored Normal file

@ -0,0 +1,57 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing create_cache method definition."""
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from graphrag.config.enums import CacheType
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
if TYPE_CHECKING:
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.cache.json_pipeline_cache import JsonPipelineCache
from graphrag.cache.memory_pipeline_cache import InMemoryCache
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
class CacheFactory:
"""A factory class for cache implementations.
Includes a method for users to register a custom cache implementation.
"""
cache_types: ClassVar[dict[str, type]] = {}
@classmethod
def register(cls, cache_type: str, cache: type):
"""Register a custom cache implementation."""
cls.cache_types[cache_type] = cache
@classmethod
def create_cache(
cls, cache_type: CacheType | str | None, root_dir: str, kwargs: dict
) -> PipelineCache:
"""Create or get a cache from the provided type."""
if not cache_type:
return NoopPipelineCache()
match cache_type:
case CacheType.none:
return NoopPipelineCache()
case CacheType.memory:
return InMemoryCache()
case CacheType.file:
return JsonPipelineCache(
FilePipelineStorage(root_dir=root_dir).child(kwargs["base_dir"])
)
case CacheType.blob:
return JsonPipelineCache(BlobPipelineStorage(**kwargs))
case _:
if cache_type in cls.cache_types:
return cls.cache_types[cache_type](**kwargs)
msg = f"Unknown cache type: {cache_type}"
raise ValueError(msg)
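A minimal sketch of how this factory might be exercised, using only the modules shown in this commit; the "tagged" registration and TaggedCache class are hypothetical examples of a custom implementation.

from graphrag.cache.factory import CacheFactory
from graphrag.cache.memory_pipeline_cache import InMemoryCache
from graphrag.config.enums import CacheType

# Built-in cache types are resolved by the match statement above.
file_cache = CacheFactory.create_cache(
    CacheType.file, root_dir=".", kwargs={"base_dir": "cache"}
)
memory_cache = CacheFactory.create_cache(CacheType.memory, root_dir=".", kwargs={})


class TaggedCache(InMemoryCache):
    """Hypothetical custom cache that reuses the in-memory implementation."""

    def __init__(self, tag: str = ""):
        super().__init__(name=tag)


# Register the custom type under a new name, then create it through the factory.
CacheFactory.register("tagged", TaggedCache)
custom_cache = CacheFactory.create_cache("tagged", root_dir=".", kwargs={"tag": "exp-1"})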

65
graphrag/cache/json_pipeline_cache.py vendored Normal file

@ -0,0 +1,65 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'JsonPipelineCache' model."""
import json
from typing import Any
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.storage.pipeline_storage import PipelineStorage
class JsonPipelineCache(PipelineCache):
"""File pipeline cache class definition."""
_storage: PipelineStorage
_encoding: str
def __init__(self, storage: PipelineStorage, encoding="utf-8"):
"""Init method definition."""
self._storage = storage
self._encoding = encoding
async def get(self, key: str) -> str | None:
"""Get method definition."""
if await self.has(key):
try:
data = await self._storage.get(key, encoding=self._encoding)
data = json.loads(data)
except UnicodeDecodeError:
await self._storage.delete(key)
return None
except json.decoder.JSONDecodeError:
await self._storage.delete(key)
return None
else:
return data.get("result")
return None
async def set(self, key: str, value: Any, debug_data: dict | None = None) -> None:
"""Set method definition."""
if value is None:
return
data = {"result": value, **(debug_data or {})}
await self._storage.set(
key, json.dumps(data, ensure_ascii=False), encoding=self._encoding
)
async def has(self, key: str) -> bool:
"""Has method definition."""
return await self._storage.has(key)
async def delete(self, key: str) -> None:
"""Delete method definition."""
if await self.has(key):
await self._storage.delete(key)
async def clear(self) -> None:
"""Clear method definition."""
await self._storage.clear()
def child(self, name: str) -> "JsonPipelineCache":
"""Child method definition."""
return JsonPipelineCache(self._storage.child(name), encoding=self._encoding)
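A short sketch of exercising this cache over file storage, assuming the FilePipelineStorage constructor shown elsewhere in this commit; cache entries land on disk as JSON documents with a "result" field.

import asyncio
from pathlib import Path

from graphrag.cache.json_pipeline_cache import JsonPipelineCache
from graphrag.storage.file_pipeline_storage import FilePipelineStorage


async def main() -> None:
    Path("cache").mkdir(exist_ok=True)  # ensure the storage root exists
    cache = JsonPipelineCache(FilePipelineStorage(root_dir="cache"))

    await cache.set("chat-abc123", {"content": "hello"}, debug_data={"model": "gpt-4"})
    print(await cache.has("chat-abc123"))  # True
    print(await cache.get("chat-abc123"))  # {'content': 'hello'}


asyncio.run(main())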

78
graphrag/cache/memory_pipeline_cache.py vendored Normal file

@ -0,0 +1,78 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'InMemoryCache' model."""
from typing import Any
from graphrag.cache.pipeline_cache import PipelineCache
class InMemoryCache(PipelineCache):
"""In memory cache class definition."""
_cache: dict[str, Any]
_name: str
def __init__(self, name: str | None = None):
"""Init method definition."""
self._cache = {}
self._name = name or ""
async def get(self, key: str) -> Any:
"""Get the value for the given key.
Args:
- key - The key to get the value for.
- as_bytes - Whether or not to return the value as bytes.
Returns
-------
- output - The value for the given key.
"""
key = self._create_cache_key(key)
return self._cache.get(key)
async def set(self, key: str, value: Any, debug_data: dict | None = None) -> None:
"""Set the value for the given key.
Args:
- key - The key to set the value for.
- value - The value to set.
"""
key = self._create_cache_key(key)
self._cache[key] = value
async def has(self, key: str) -> bool:
"""Return True if the given key exists in the storage.
Args:
- key - The key to check for.
Returns
-------
- output - True if the key exists in the storage, False otherwise.
"""
key = self._create_cache_key(key)
return key in self._cache
async def delete(self, key: str) -> None:
"""Delete the given key from the storage.
Args:
- key - The key to delete.
"""
key = self._create_cache_key(key)
del self._cache[key]
async def clear(self) -> None:
"""Clear the storage."""
self._cache.clear()
def child(self, name: str) -> PipelineCache:
"""Create a sub cache with the given name."""
return InMemoryCache(name)
def _create_cache_key(self, key: str) -> str:
"""Create a cache key for the given key."""
return f"{self._name}{key}"

65
graphrag/cache/noop_pipeline_cache.py vendored Normal file

@ -0,0 +1,65 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Module containing the NoopPipelineCache implementation."""
from typing import Any
from graphrag.cache.pipeline_cache import PipelineCache
class NoopPipelineCache(PipelineCache):
"""A no-op implementation of the pipeline cache, usually useful for testing."""
async def get(self, key: str) -> Any:
"""Get the value for the given key.
Args:
- key - The key to get the value for.
- as_bytes - Whether or not to return the value as bytes.
Returns
-------
- output - The value for the given key.
"""
return None
async def set(
self, key: str, value: str | bytes | None, debug_data: dict | None = None
) -> None:
"""Set the value for the given key.
Args:
- key - The key to set the value for.
- value - The value to set.
"""
async def has(self, key: str) -> bool:
"""Return True if the given key exists in the cache.
Args:
- key - The key to check for.
Returns
-------
- output - True if the key exists in the cache, False otherwise.
"""
return False
async def delete(self, key: str) -> None:
"""Delete the given key from the cache.
Args:
- key - The key to delete.
"""
async def clear(self) -> None:
"""Clear the cache."""
def child(self, name: str) -> PipelineCache:
"""Create a child cache with the given name.
Args:
- name - The name to create the sub cache with.
"""
return self

67
graphrag/cache/pipeline_cache.py vendored Normal file

@ -0,0 +1,67 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'PipelineCache' model."""
from __future__ import annotations
from abc import ABCMeta, abstractmethod
from typing import Any
class PipelineCache(metaclass=ABCMeta):
"""Provide a cache interface for the pipeline."""
@abstractmethod
async def get(self, key: str) -> Any:
"""Get the value for the given key.
Args:
- key - The key to get the value for.
- as_bytes - Whether or not to return the value as bytes.
Returns
-------
- output - The value for the given key.
"""
@abstractmethod
async def set(self, key: str, value: Any, debug_data: dict | None = None) -> None:
"""Set the value for the given key.
Args:
- key - The key to set the value for.
- value - The value to set.
"""
@abstractmethod
async def has(self, key: str) -> bool:
"""Return True if the given key exists in the cache.
Args:
- key - The key to check for.
Returns
-------
- output - True if the key exists in the cache, False otherwise.
"""
@abstractmethod
async def delete(self, key: str) -> None:
"""Delete the given key from the cache.
Args:
- key - The key to delete.
"""
@abstractmethod
async def clear(self) -> None:
"""Clear the cache."""
@abstractmethod
def child(self, name: str) -> PipelineCache:
"""Create a child cache with the given name.
Args:
- name - The name to create the sub cache with.
"""


@ -0,0 +1,46 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Create a pipeline logger."""
from pathlib import Path
from typing import cast
from datashaper import WorkflowCallbacks
from graphrag.callbacks.blob_workflow_callbacks import BlobWorkflowCallbacks
from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
from graphrag.callbacks.file_workflow_callbacks import FileWorkflowCallbacks
from graphrag.config.enums import ReportingType
from graphrag.index.config.reporting import (
PipelineBlobReportingConfig,
PipelineFileReportingConfig,
PipelineReportingConfig,
)
def create_pipeline_reporter(
config: PipelineReportingConfig | None, root_dir: str | None
) -> WorkflowCallbacks:
"""Create a logger for the given pipeline config."""
config = config or PipelineFileReportingConfig(base_dir="logs")
match config.type:
case ReportingType.file:
config = cast("PipelineFileReportingConfig", config)
return FileWorkflowCallbacks(
str(Path(root_dir or "") / (config.base_dir or ""))
)
case ReportingType.console:
return ConsoleWorkflowCallbacks()
case ReportingType.blob:
config = cast("PipelineBlobReportingConfig", config)
return BlobWorkflowCallbacks(
config.connection_string,
config.container_name,
base_dir=config.base_dir,
storage_account_blob_url=config.storage_account_blob_url,
)
case _:
msg = f"Unknown reporting type: {config.type}"
raise ValueError(msg)
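A brief sketch of driving this factory; the import path is an assumption, since the module's filename is not shown in this hunk, and the "./ragtest" root is a placeholder.

# Assumed import path for the factory defined above.
from graphrag.callbacks.factory import create_pipeline_reporter
from graphrag.index.config.reporting import PipelineFileReportingConfig

# With no config, the factory falls back to file reporting under "<root>/logs".
default_callbacks = create_pipeline_reporter(config=None, root_dir="./ragtest")

# An explicit file config writes workflow callbacks under "<root>/reports" instead.
file_callbacks = create_pipeline_reporter(
    config=PipelineFileReportingConfig(base_dir="reports"),
    root_dir="./ragtest",
)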


@ -0,0 +1,55 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""ParquetExporter module."""
import logging
import traceback
import pandas as pd
from pyarrow.lib import ArrowInvalid, ArrowTypeError
from graphrag.index.typing import ErrorHandlerFn
from graphrag.storage.pipeline_storage import PipelineStorage
log = logging.getLogger(__name__)
class ParquetExporter:
"""ParquetExporter class.
A class that exports dataframes to a storage destination in .parquet file format.
"""
_storage: PipelineStorage
_on_error: ErrorHandlerFn
def __init__(
self,
storage: PipelineStorage,
on_error: ErrorHandlerFn,
):
"""Create a new Parquet Table TableExporter."""
self._storage = storage
self._on_error = on_error
async def export(self, name: str, data: pd.DataFrame) -> None:
"""Export dataframe to storage."""
filename = f"{name}.parquet"
log.info("exporting parquet table %s", filename)
try:
await self._storage.set(filename, data.to_parquet())
except ArrowTypeError as e:
log.exception("Error while exporting parquet table")
self._on_error(
e,
traceback.format_exc(),
None,
)
except ArrowInvalid as e:
log.exception("Error while exporting parquet table")
self._on_error(
e,
traceback.format_exc(),
None,
)
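A usage sketch for the exporter; the import path for ParquetExporter is an assumption (the module's filename is not shown in this hunk), and the error handler simply mirrors the (exception, stack, details) call shape used by export().

import asyncio
from pathlib import Path

import pandas as pd

from graphrag.index.exporter import ParquetExporter  # assumed module path
from graphrag.storage.file_pipeline_storage import FilePipelineStorage


def on_error(error: BaseException, stack: str | None, details: dict | None) -> None:
    """Log export failures; matches the ErrorHandlerFn call shape above."""
    print(f"export failed: {error}")


async def main() -> None:
    Path("output").mkdir(exist_ok=True)  # ensure the destination exists
    exporter = ParquetExporter(FilePipelineStorage(root_dir="output"), on_error)
    df = pd.DataFrame({"id": [1, 2], "title": ["a", "b"]})
    await exporter.export("create_final_entities", df)  # writes create_final_entities.parquet


asyncio.run(main())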


@ -0,0 +1,43 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All the steps to create the base entity graph."""
from typing import Any
import pandas as pd
from graphrag.index.operations.cluster_graph import cluster_graph
from graphrag.index.operations.create_graph import create_graph
from graphrag.index.operations.snapshot import snapshot
from graphrag.storage.pipeline_storage import PipelineStorage
async def compute_communities(
base_relationship_edges: pd.DataFrame,
storage: PipelineStorage,
clustering_strategy: dict[str, Any],
snapshot_transient_enabled: bool = False,
) -> pd.DataFrame:
"""All the steps to create the base entity graph."""
graph = create_graph(base_relationship_edges)
communities = cluster_graph(
graph,
strategy=clustering_strategy,
)
base_communities = pd.DataFrame(
communities, columns=pd.Index(["level", "community", "parent", "title"])
).explode("title")
base_communities["community"] = base_communities["community"].astype(int)
if snapshot_transient_enabled:
await snapshot(
base_communities,
name="base_communities",
storage=storage,
formats=["parquet"],
)
return base_communities


@ -0,0 +1,147 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All the steps to create the base entity graph."""
from typing import Any
from uuid import uuid4
import networkx as nx
import pandas as pd
from datashaper import (
AsyncType,
VerbCallbacks,
)
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.index.operations.create_graph import create_graph
from graphrag.index.operations.extract_entities import extract_entities
from graphrag.index.operations.snapshot import snapshot
from graphrag.index.operations.snapshot_graphml import snapshot_graphml
from graphrag.index.operations.summarize_descriptions import (
summarize_descriptions,
)
from graphrag.storage.pipeline_storage import PipelineStorage
async def extract_graph(
text_units: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
storage: PipelineStorage,
extraction_strategy: dict[str, Any] | None = None,
extraction_num_threads: int = 4,
extraction_async_mode: AsyncType = AsyncType.AsyncIO,
entity_types: list[str] | None = None,
summarization_strategy: dict[str, Any] | None = None,
summarization_num_threads: int = 4,
snapshot_graphml_enabled: bool = False,
snapshot_transient_enabled: bool = False,
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""All the steps to create the base entity graph."""
# this returns a graph for each text unit, to be merged later
entity_dfs, relationship_dfs = await extract_entities(
text_units,
callbacks,
cache,
text_column="text",
id_column="id",
strategy=extraction_strategy,
async_mode=extraction_async_mode,
entity_types=entity_types,
num_threads=extraction_num_threads,
)
merged_entities = _merge_entities(entity_dfs)
merged_relationships = _merge_relationships(relationship_dfs)
entity_summaries, relationship_summaries = await summarize_descriptions(
merged_entities,
merged_relationships,
callbacks,
cache,
strategy=summarization_strategy,
num_threads=summarization_num_threads,
)
base_relationship_edges = _prep_edges(merged_relationships, relationship_summaries)
graph = create_graph(base_relationship_edges)
base_entity_nodes = _prep_nodes(merged_entities, entity_summaries, graph)
if snapshot_graphml_enabled:
# todo: extract graphs at each level, and add in meta like descriptions
await snapshot_graphml(
graph,
name="graph",
storage=storage,
)
if snapshot_transient_enabled:
await snapshot(
base_entity_nodes,
name="base_entity_nodes",
storage=storage,
formats=["parquet"],
)
await snapshot(
base_relationship_edges,
name="base_relationship_edges",
storage=storage,
formats=["parquet"],
)
return (base_entity_nodes, base_relationship_edges)
def _merge_entities(entity_dfs) -> pd.DataFrame:
all_entities = pd.concat(entity_dfs, ignore_index=True)
return (
all_entities.groupby(["name", "type"], sort=False)
.agg({"description": list, "source_id": list})
.reset_index()
)
def _merge_relationships(relationship_dfs) -> pd.DataFrame:
all_relationships = pd.concat(relationship_dfs, ignore_index=False)
return (
all_relationships.groupby(["source", "target"], sort=False)
.agg({"description": list, "source_id": list, "weight": "sum"})
.reset_index()
)
def _prep_nodes(entities, summaries, graph) -> pd.DataFrame:
degrees_df = _compute_degree(graph)
entities.drop(columns=["description"], inplace=True)
nodes = (
entities.merge(summaries, on="name", how="left")
.merge(degrees_df, on="name")
.drop_duplicates(subset="name")
.rename(columns={"name": "title", "source_id": "text_unit_ids"})
)
nodes = nodes.loc[nodes["title"].notna()].reset_index()
nodes["human_readable_id"] = nodes.index
nodes["id"] = nodes["human_readable_id"].apply(lambda _x: str(uuid4()))
return nodes
def _prep_edges(relationships, summaries) -> pd.DataFrame:
edges = (
relationships.drop(columns=["description"])
.drop_duplicates(subset=["source", "target"])
.merge(summaries, on=["source", "target"], how="left")
.rename(columns={"source_id": "text_unit_ids"})
)
edges["human_readable_id"] = edges.index
edges["id"] = edges["human_readable_id"].apply(lambda _x: str(uuid4()))
return edges
def _compute_degree(graph: nx.Graph) -> pd.DataFrame:
return pd.DataFrame([
{"name": node, "degree": int(degree)}
for node, degree in graph.degree # type: ignore
])
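The _merge_entities step above collapses duplicate mentions of an entity across text units before summarization; a self-contained pandas illustration of that grouping, with made-up data, looks like this:

import pandas as pd

# Two per-text-unit extraction results that mention the same entity (toy data).
entity_dfs = [
    pd.DataFrame({
        "name": ["SCROOGE"], "type": ["PERSON"],
        "description": ["A miserly man"], "source_id": ["tu-1"],
    }),
    pd.DataFrame({
        "name": ["SCROOGE"], "type": ["PERSON"],
        "description": ["Partner of Marley"], "source_id": ["tu-2"],
    }),
]

all_entities = pd.concat(entity_dfs, ignore_index=True)
merged = (
    all_entities.groupby(["name", "type"], sort=False)
    .agg({"description": list, "source_id": list})
    .reset_index()
)
print(merged)  # one row, with descriptions and source_ids collected into lists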


@ -0,0 +1,80 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing create_input method definition."""
import logging
from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import cast
import pandas as pd
from graphrag.config.enums import InputType
from graphrag.config.models.input_config import InputConfig
from graphrag.index.config.input import PipelineInputConfig
from graphrag.index.input.csv import input_type as csv
from graphrag.index.input.csv import load as load_csv
from graphrag.index.input.text import input_type as text
from graphrag.index.input.text import load as load_text
from graphrag.logger.base import ProgressLogger
from graphrag.logger.null_progress import NullProgressLogger
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
log = logging.getLogger(__name__)
loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = {
text: load_text,
csv: load_csv,
}
async def create_input(
config: PipelineInputConfig | InputConfig,
progress_reporter: ProgressLogger | None = None,
root_dir: str | None = None,
) -> pd.DataFrame:
"""Instantiate input data for a pipeline."""
root_dir = root_dir or ""
log.info("loading input from root_dir=%s", config.base_dir)
progress_reporter = progress_reporter or NullProgressLogger()
match config.type:
case InputType.blob:
log.info("using blob storage input")
if config.container_name is None:
msg = "Container name required for blob storage"
raise ValueError(msg)
if (
config.connection_string is None
and config.storage_account_blob_url is None
):
msg = "Connection string or storage account blob url required for blob storage"
raise ValueError(msg)
storage = BlobPipelineStorage(
connection_string=config.connection_string,
storage_account_blob_url=config.storage_account_blob_url,
container_name=config.container_name,
path_prefix=config.base_dir,
)
case InputType.file:
log.info("using file storage for input")
storage = FilePipelineStorage(
root_dir=str(Path(root_dir) / (config.base_dir or ""))
)
case _:
log.info("using file storage for input")
storage = FilePipelineStorage(
root_dir=str(Path(root_dir) / (config.base_dir or ""))
)
if config.file_type in loaders:
progress = progress_reporter.child(
f"Loading Input ({config.file_type})", transient=False
)
loader = loaders[config.file_type]
results = await loader(config, progress, storage)
return cast("pd.DataFrame", results)
msg = f"Unknown input type {config.file_type}"
raise ValueError(msg)


@ -0,0 +1,49 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A mock LLM that returns the given responses."""
from dataclasses import dataclass
from typing import Any, cast
from fnllm import ChatLLM, LLMInput, LLMOutput
from fnllm.types.generics import THistoryEntry, TJsonModel, TModelParameters
from pydantic import BaseModel
from typing_extensions import Unpack
@dataclass
class ContentResponse:
"""A mock content-only response."""
content: str
class MockChatLLM(ChatLLM):
"""A mock LLM that returns the given responses."""
def __init__(self, responses: list[str | BaseModel], json: bool = False):
self.responses = responses
self.response_index = 0
async def __call__(
self,
prompt: str,
**kwargs: Unpack[LLMInput[TJsonModel, THistoryEntry, TModelParameters]],
) -> LLMOutput[Any, TJsonModel, THistoryEntry]:
"""Return the next response in the list."""
response = self.responses[self.response_index % len(self.responses)]
self.response_index += 1
parsed_json = response if isinstance(response, BaseModel) else None
response = (
response.model_dump_json() if isinstance(response, BaseModel) else response
)
return LLMOutput(
output=ContentResponse(content=response),
parsed_json=cast("TJsonModel", parsed_json),
)
def child(self, name):
"""Return a new mock LLM."""
raise NotImplementedError
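A short sketch of how this mock might be used in a test; the import path is illustrative only, since the module's location is not shown in this hunk.

import asyncio

from pydantic import BaseModel

from tests.unit.mock_chat_llm import MockChatLLM  # hypothetical import path


class Summary(BaseModel):
    text: str


async def main() -> None:
    llm = MockChatLLM(responses=["first answer", Summary(text="structured answer")])

    plain = await llm("any prompt")
    print(plain.output.content)       # "first answer"

    structured = await llm("another prompt")
    print(structured.output.content)  # JSON dump of the Summary model
    print(structured.parsed_json)     # the Summary instance itself


asyncio.run(main())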


@ -0,0 +1,12 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing create_graph definition."""
import networkx as nx
import pandas as pd
def create_graph(edges_df: pd.DataFrame) -> nx.Graph:
"""Create a networkx graph from nodes and edges dataframes."""
return nx.from_pandas_edgelist(edges_df)
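A tiny usage sketch: with the default nx.from_pandas_edgelist call, only the source and target columns are used, so the result is an undirected graph keyed by those names (the rows below are made up).

import pandas as pd

from graphrag.index.operations.create_graph import create_graph

# Edge list in the shape produced upstream: one row per relationship.
edges = pd.DataFrame({
    "source": ["SCROOGE", "SCROOGE", "MARLEY"],
    "target": ["MARLEY", "FRED", "FRED"],
    "weight": [2.0, 1.0, 1.0],
})

graph = create_graph(edges)
print(graph.number_of_nodes(), graph.number_of_edges())  # 3 3
print(sorted(graph.degree))  # [('FRED', 2), ('MARLEY', 2), ('SCROOGE', 2)]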


@ -0,0 +1,74 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing build_steps method definition."""
from typing import Any, cast
import pandas as pd
from datashaper import (
Table,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result
from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
from graphrag.index.flows.compute_communities import compute_communities
from graphrag.storage.pipeline_storage import PipelineStorage
workflow_name = "compute_communities"
def build_steps(
config: PipelineWorkflowConfig,
) -> list[PipelineWorkflowStep]:
"""
Create the base communities from the graph edges.
## Dependencies
* `workflow:extract_graph`
"""
clustering_config = config.get(
"cluster_graph",
{"strategy": {"type": "leiden"}},
)
clustering_strategy = clustering_config.get("strategy")
snapshot_transient = config.get("snapshot_transient", False) or False
return [
{
"verb": workflow_name,
"args": {
"clustering_strategy": clustering_strategy,
"snapshot_transient_enabled": snapshot_transient,
},
"input": ({"source": "workflow:extract_graph"}),
},
]
@verb(
name=workflow_name,
treats_input_tables_as_immutable=True,
)
async def workflow(
storage: PipelineStorage,
runtime_storage: PipelineStorage,
clustering_strategy: dict[str, Any],
snapshot_transient_enabled: bool = False,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to create the base entity graph."""
base_relationship_edges = await runtime_storage.get("base_relationship_edges")
base_communities = await compute_communities(
base_relationship_edges,
storage,
clustering_strategy=clustering_strategy,
snapshot_transient_enabled=snapshot_transient_enabled,
)
await runtime_storage.set("base_communities", base_communities)
return create_verb_result(cast("Table", pd.DataFrame()))


@ -0,0 +1,107 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing build_steps method definition."""
from typing import Any, cast
import pandas as pd
from datashaper import (
AsyncType,
Table,
VerbCallbacks,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
from graphrag.index.flows.extract_graph import (
extract_graph,
)
from graphrag.storage.pipeline_storage import PipelineStorage
workflow_name = "extract_graph"
def build_steps(
config: PipelineWorkflowConfig,
) -> list[PipelineWorkflowStep]:
"""
Create the base table for the entity graph.
## Dependencies
* `workflow:create_base_text_units`
"""
entity_extraction_config = config.get("entity_extract", {})
async_mode = entity_extraction_config.get("async_mode", AsyncType.AsyncIO)
extraction_strategy = entity_extraction_config.get("strategy")
extraction_num_threads = entity_extraction_config.get("num_threads", 4)
entity_types = entity_extraction_config.get("entity_types")
summarize_descriptions_config = config.get("summarize_descriptions", {})
summarization_strategy = summarize_descriptions_config.get("strategy")
summarization_num_threads = summarize_descriptions_config.get("num_threads", 4)
snapshot_graphml = config.get("snapshot_graphml", False) or False
snapshot_transient = config.get("snapshot_transient", False) or False
return [
{
"verb": workflow_name,
"args": {
"extraction_strategy": extraction_strategy,
"extraction_num_threads": extraction_num_threads,
"extraction_async_mode": async_mode,
"entity_types": entity_types,
"summarization_strategy": summarization_strategy,
"summarization_num_threads": summarization_num_threads,
"snapshot_graphml_enabled": snapshot_graphml,
"snapshot_transient_enabled": snapshot_transient,
},
"input": ({"source": "workflow:create_base_text_units"}),
},
]
@verb(
name=workflow_name,
treats_input_tables_as_immutable=True,
)
async def workflow(
callbacks: VerbCallbacks,
cache: PipelineCache,
storage: PipelineStorage,
runtime_storage: PipelineStorage,
extraction_strategy: dict[str, Any] | None,
extraction_num_threads: int = 4,
extraction_async_mode: AsyncType = AsyncType.AsyncIO,
entity_types: list[str] | None = None,
summarization_strategy: dict[str, Any] | None = None,
summarization_num_threads: int = 4,
snapshot_graphml_enabled: bool = False,
snapshot_transient_enabled: bool = False,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to create the base entity graph."""
text_units = await runtime_storage.get("base_text_units")
base_entity_nodes, base_relationship_edges = await extract_graph(
text_units,
callbacks,
cache,
storage,
extraction_strategy=extraction_strategy,
extraction_num_threads=extraction_num_threads,
extraction_async_mode=extraction_async_mode,
entity_types=entity_types,
summarization_strategy=summarization_strategy,
summarization_num_threads=summarization_num_threads,
snapshot_graphml_enabled=snapshot_graphml_enabled,
snapshot_transient_enabled=snapshot_transient_enabled,
)
await runtime_storage.set("base_entity_nodes", base_entity_nodes)
await runtime_storage.set("base_relationship_edges", base_relationship_edges)
return create_verb_result(cast("Table", pd.DataFrame()))


@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Logger utilities and implementations."""

69
graphrag/logger/base.py Normal file

@ -0,0 +1,69 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Base classes for logging and progress reporting."""
from abc import ABC, abstractmethod
from typing import Any
from datashaper.progress.types import Progress
class StatusLogger(ABC):
"""Provides a way to log status updates from the pipeline."""
@abstractmethod
def error(self, message: str, details: dict[str, Any] | None = None):
"""Log an error."""
@abstractmethod
def warning(self, message: str, details: dict[str, Any] | None = None):
"""Log a warning."""
@abstractmethod
def log(self, message: str, details: dict[str, Any] | None = None):
"""Report a log."""
class ProgressLogger(ABC):
"""
Abstract base class for progress loggers.
This is used to report workflow processing progress via mechanisms like progress-bars.
"""
@abstractmethod
def __call__(self, update: Progress):
"""Update progress."""
@abstractmethod
def dispose(self):
"""Dispose of the progress logger."""
@abstractmethod
def child(self, prefix: str, transient=True) -> "ProgressLogger":
"""Create a child progress bar."""
@abstractmethod
def force_refresh(self) -> None:
"""Force a refresh."""
@abstractmethod
def stop(self) -> None:
"""Stop the progress logger."""
@abstractmethod
def error(self, message: str) -> None:
"""Log an error."""
@abstractmethod
def warning(self, message: str) -> None:
"""Log a warning."""
@abstractmethod
def info(self, message: str) -> None:
"""Log information."""
@abstractmethod
def success(self, message: str) -> None:
"""Log success."""


@ -0,0 +1,28 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Console Log."""
from typing import Any
from graphrag.logger.base import StatusLogger
class ConsoleReporter(StatusLogger):
"""A logger that writes to a console."""
def error(self, message: str, details: dict[str, Any] | None = None):
"""Log an error."""
print(message, details) # noqa T201
def warning(self, message: str, details: dict[str, Any] | None = None):
"""Log a warning."""
_print_warning(message)
def log(self, message: str, details: dict[str, Any] | None = None):
"""Log a log."""
print(message, details) # noqa T201
def _print_warning(skk):
print(f"\033[93m {skk}\033[00m") # noqa T201


@ -0,0 +1,43 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Factory functions for creating loggers."""
from typing import ClassVar
from graphrag.logger.base import ProgressLogger
from graphrag.logger.null_progress import NullProgressLogger
from graphrag.logger.print_progress import PrintProgressLogger
from graphrag.logger.rich_progress import RichProgressLogger
from graphrag.logger.types import LoggerType
class LoggerFactory:
"""A factory class for loggers."""
logger_types: ClassVar[dict[str, type]] = {}
@classmethod
def register(cls, logger_type: str, logger: type):
"""Register a custom logger implementation."""
cls.logger_types[logger_type] = logger
@classmethod
def create_logger(
cls, logger_type: LoggerType | str, kwargs: dict | None = None
) -> ProgressLogger:
"""Create a logger based on the provided type."""
if kwargs is None:
kwargs = {}
match logger_type:
case LoggerType.RICH:
return RichProgressLogger("GraphRAG Indexer ")
case LoggerType.PRINT:
return PrintProgressLogger("GraphRAG Indexer ")
case LoggerType.NONE:
return NullProgressLogger()
case _:
if logger_type in cls.logger_types:
return cls.logger_types[logger_type](**kwargs)
# default to null logger if no other logger is found
return NullProgressLogger()
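A minimal sketch of driving the factory, assuming it lives at graphrag.logger.factory (the filename is not shown in this hunk); the "quiet" registration and QuietLogger class are hypothetical.

from graphrag.logger.factory import LoggerFactory  # assumed module path
from graphrag.logger.null_progress import NullProgressLogger
from graphrag.logger.types import LoggerType

# Built-in logger types are matched by the factory.
progress = LoggerFactory.create_logger(LoggerType.PRINT)
progress.info("starting indexing run")


class QuietLogger(NullProgressLogger):
    """Hypothetical custom logger; NullProgressLogger already satisfies the interface."""

    def error(self, message: str) -> None:
        print(f"[quiet] {message}")


LoggerFactory.register("quiet", QuietLogger)
quiet = LoggerFactory.create_logger("quiet")
quiet.error("something went wrong")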


@ -0,0 +1,38 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Null Progress Reporter."""
from graphrag.logger.base import Progress, ProgressLogger
class NullProgressLogger(ProgressLogger):
"""A progress logger that does nothing."""
def __call__(self, update: Progress) -> None:
"""Update progress."""
def dispose(self) -> None:
"""Dispose of the progress logger."""
def child(self, prefix: str, transient: bool = True) -> ProgressLogger:
"""Create a child progress bar."""
return self
def force_refresh(self) -> None:
"""Force a refresh."""
def stop(self) -> None:
"""Stop the progress logger."""
def error(self, message: str) -> None:
"""Log an error."""
def warning(self, message: str) -> None:
"""Log a warning."""
def info(self, message: str) -> None:
"""Log information."""
def success(self, message: str) -> None:
"""Log success."""


@ -0,0 +1,50 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Print Progress Logger."""
from graphrag.logger.base import Progress, ProgressLogger
class PrintProgressLogger(ProgressLogger):
"""A progress logger that prints progress to stdout."""
prefix: str
def __init__(self, prefix: str):
"""Create a new progress logger."""
self.prefix = prefix
print(f"\n{self.prefix}", end="") # noqa T201
def __call__(self, update: Progress) -> None:
"""Update progress."""
print(".", end="") # noqa T201
def dispose(self) -> None:
"""Dispose of the progress logger."""
def child(self, prefix: str, transient: bool = True) -> ProgressLogger:
"""Create a child progress bar."""
return PrintProgressLogger(prefix)
def stop(self) -> None:
"""Stop the progress logger."""
def force_refresh(self) -> None:
"""Force a refresh."""
def error(self, message: str) -> None:
"""Log an error."""
print(f"\n{self.prefix}ERROR: {message}") # noqa T201
def warning(self, message: str) -> None:
"""Log a warning."""
print(f"\n{self.prefix}WARNING: {message}") # noqa T201
def info(self, message: str) -> None:
"""Log information."""
print(f"\n{self.prefix}INFO: {message}") # noqa T201
def success(self, message: str) -> None:
"""Log success."""
print(f"\n{self.prefix}SUCCESS: {message}") # noqa T201


@ -0,0 +1,165 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Rich-based progress logger for CLI use."""
# Print iterations progress
import asyncio
from datashaper import Progress as DSProgress
from rich.console import Console, Group
from rich.live import Live
from rich.progress import Progress, TaskID, TimeElapsedColumn
from rich.spinner import Spinner
from rich.tree import Tree
from graphrag.logger.base import ProgressLogger
# https://stackoverflow.com/a/34325723
class RichProgressLogger(ProgressLogger):
"""A rich-based progress logger for CLI use."""
_console: Console
_group: Group
_tree: Tree
_live: Live
_task: TaskID | None = None
_prefix: str
_transient: bool
_disposing: bool = False
_progressbar: Progress
_last_refresh: float = 0
def dispose(self) -> None:
"""Dispose of the progress logger."""
self._disposing = True
self._live.stop()
@property
def console(self) -> Console:
"""Get the console."""
return self._console
@property
def group(self) -> Group:
"""Get the group."""
return self._group
@property
def tree(self) -> Tree:
"""Get the tree."""
return self._tree
@property
def live(self) -> Live:
"""Get the live."""
return self._live
def __init__(
self,
prefix: str,
parent: "RichProgressLogger | None" = None,
transient: bool = True,
) -> None:
"""Create a new rich-based progress logger."""
self._prefix = prefix
if parent is None:
console = Console()
group = Group(Spinner("dots", prefix), fit=True)
tree = Tree(group)
live = Live(
tree, console=console, refresh_per_second=1, vertical_overflow="crop"
)
live.start()
self._console = console
self._group = group
self._tree = tree
self._live = live
self._transient = False
else:
self._console = parent.console
self._group = parent.group
progress_columns = [*Progress.get_default_columns(), TimeElapsedColumn()]
self._progressbar = Progress(
*progress_columns, console=self._console, transient=transient
)
tree = Tree(prefix)
tree.add(self._progressbar)
tree.hide_root = True
if parent is not None:
parent_tree = parent.tree
parent_tree.hide_root = False
parent_tree.add(tree)
self._tree = tree
self._live = parent.live
self._transient = transient
self.refresh()
def refresh(self) -> None:
"""Perform a debounced refresh."""
now = asyncio.get_event_loop().time()
duration = now - self._last_refresh
if duration > 0.1:
self._last_refresh = now
self.force_refresh()
def force_refresh(self) -> None:
"""Force a refresh."""
self.live.refresh()
def stop(self) -> None:
"""Stop the progress logger."""
self._live.stop()
def child(self, prefix: str, transient: bool = True) -> ProgressLogger:
"""Create a child progress bar."""
return RichProgressLogger(parent=self, prefix=prefix, transient=transient)
def error(self, message: str) -> None:
"""Log an error."""
self._console.print(f"❌ [red]{message}[/red]")
def warning(self, message: str) -> None:
"""Log a warning."""
self._console.print(f"⚠️ [yellow]{message}[/yellow]")
def success(self, message: str) -> None:
"""Log success."""
self._console.print(f"🚀 [green]{message}[/green]")
def info(self, message: str) -> None:
"""Log information."""
self._console.print(message)
def __call__(self, progress_update: DSProgress) -> None:
"""Update progress."""
if self._disposing:
return
progressbar = self._progressbar
if self._task is None:
self._task = progressbar.add_task(self._prefix)
progress_description = ""
if progress_update.description is not None:
progress_description = f" - {progress_update.description}"
completed = progress_update.completed_items or progress_update.percent
total = progress_update.total_items or 1
progressbar.update(
self._task,
completed=completed,
total=total,
description=f"{self._prefix}{progress_description}",
)
if completed == total and self._transient:
progressbar.update(self._task, visible=False)
self.refresh()

22
graphrag/logger/types.py Normal file

@ -0,0 +1,22 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Logging types.
This module defines the types of loggers that can be used.
"""
from enum import Enum
# Note: Code in this module was not included in the factory module because it negatively impacts the CLI experience.
class LoggerType(str, Enum):
"""The type of logger to use."""
RICH = "rich"
PRINT = "print"
NONE = "none"
def __str__(self):
"""Return a string representation of the enum value."""
return self.value

193
graphrag/query/factory.py Normal file

@ -0,0 +1,193 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Query Factory methods to support CLI."""
import tiktoken
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.model.community import Community
from graphrag.model.community_report import CommunityReport
from graphrag.model.covariate import Covariate
from graphrag.model.entity import Entity
from graphrag.model.relationship import Relationship
from graphrag.model.text_unit import TextUnit
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.llm.get_client import get_llm, get_text_embedder
from graphrag.query.structured_search.drift_search.drift_context import (
DRIFTSearchContextBuilder,
)
from graphrag.query.structured_search.drift_search.search import DRIFTSearch
from graphrag.query.structured_search.global_search.community_context import (
GlobalCommunityContext,
)
from graphrag.query.structured_search.global_search.search import GlobalSearch
from graphrag.query.structured_search.local_search.mixed_context import (
LocalSearchMixedContext,
)
from graphrag.query.structured_search.local_search.search import LocalSearch
from graphrag.vector_stores.base import BaseVectorStore
def get_local_search_engine(
config: GraphRagConfig,
reports: list[CommunityReport],
text_units: list[TextUnit],
entities: list[Entity],
relationships: list[Relationship],
covariates: dict[str, list[Covariate]],
response_type: str,
description_embedding_store: BaseVectorStore,
system_prompt: str | None = None,
) -> LocalSearch:
"""Create a local search engine based on data + configuration."""
llm = get_llm(config)
text_embedder = get_text_embedder(config)
token_encoder = tiktoken.get_encoding(config.encoding_model)
ls_config = config.local_search
return LocalSearch(
llm=llm,
system_prompt=system_prompt,
context_builder=LocalSearchMixedContext(
community_reports=reports,
text_units=text_units,
entities=entities,
relationships=relationships,
covariates=covariates,
entity_text_embeddings=description_embedding_store,
embedding_vectorstore_key=EntityVectorStoreKey.ID, # if the vectorstore uses entity title as ids, set this to EntityVectorStoreKey.TITLE
text_embedder=text_embedder,
token_encoder=token_encoder,
),
token_encoder=token_encoder,
llm_params={
"max_tokens": ls_config.llm_max_tokens, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 1000=1500)
"temperature": ls_config.temperature,
"top_p": ls_config.top_p,
"n": ls_config.n,
},
context_builder_params={
"text_unit_prop": ls_config.text_unit_prop,
"community_prop": ls_config.community_prop,
"conversation_history_max_turns": ls_config.conversation_history_max_turns,
"conversation_history_user_turns_only": True,
"top_k_mapped_entities": ls_config.top_k_entities,
"top_k_relationships": ls_config.top_k_relationships,
"include_entity_rank": True,
"include_relationship_weight": True,
"include_community_rank": False,
"return_candidate_context": False,
"embedding_vectorstore_key": EntityVectorStoreKey.ID, # set this to EntityVectorStoreKey.TITLE if the vectorstore uses entity title as ids
"max_tokens": ls_config.max_tokens, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000)
},
response_type=response_type,
)
def get_global_search_engine(
config: GraphRagConfig,
reports: list[CommunityReport],
entities: list[Entity],
communities: list[Community],
response_type: str,
dynamic_community_selection: bool = False,
map_system_prompt: str | None = None,
reduce_system_prompt: str | None = None,
general_knowledge_inclusion_prompt: str | None = None,
) -> GlobalSearch:
"""Create a global search engine based on data + configuration."""
token_encoder = tiktoken.get_encoding(config.encoding_model)
gs_config = config.global_search
dynamic_community_selection_kwargs = {}
if dynamic_community_selection:
# TODO: Allow a separate llm definition for Global Search only, to leverage -mini models
dynamic_community_selection_kwargs.update({
"llm": get_llm(config),
"token_encoder": tiktoken.encoding_for_model(config.llm.model),
"keep_parent": gs_config.dynamic_search_keep_parent,
"num_repeats": gs_config.dynamic_search_num_repeats,
"use_summary": gs_config.dynamic_search_use_summary,
"concurrent_coroutines": gs_config.dynamic_search_concurrent_coroutines,
"threshold": gs_config.dynamic_search_threshold,
"max_level": gs_config.dynamic_search_max_level,
})
return GlobalSearch(
llm=get_llm(config),
map_system_prompt=map_system_prompt,
reduce_system_prompt=reduce_system_prompt,
general_knowledge_inclusion_prompt=general_knowledge_inclusion_prompt,
context_builder=GlobalCommunityContext(
community_reports=reports,
communities=communities,
entities=entities,
token_encoder=token_encoder,
dynamic_community_selection=dynamic_community_selection,
dynamic_community_selection_kwargs=dynamic_community_selection_kwargs,
),
token_encoder=token_encoder,
max_data_tokens=gs_config.data_max_tokens,
map_llm_params={
"max_tokens": gs_config.map_max_tokens,
"temperature": gs_config.temperature,
"top_p": gs_config.top_p,
"n": gs_config.n,
},
reduce_llm_params={
"max_tokens": gs_config.reduce_max_tokens,
"temperature": gs_config.temperature,
"top_p": gs_config.top_p,
"n": gs_config.n,
},
allow_general_knowledge=False,
json_mode=False,
context_builder_params={
"use_community_summary": False,
"shuffle_data": True,
"include_community_rank": True,
"min_community_rank": 0,
"community_rank_name": "rank",
"include_community_weight": True,
"community_weight_name": "occurrence weight",
"normalize_community_weight": True,
"max_tokens": gs_config.max_tokens,
"context_name": "Reports",
},
concurrent_coroutines=gs_config.concurrency,
response_type=response_type,
)
def get_drift_search_engine(
config: GraphRagConfig,
reports: list[CommunityReport],
text_units: list[TextUnit],
entities: list[Entity],
relationships: list[Relationship],
description_embedding_store: BaseVectorStore,
local_system_prompt: str | None = None,
) -> DRIFTSearch:
"""Create a local search engine based on data + configuration."""
llm = get_llm(config)
text_embedder = get_text_embedder(config)
token_encoder = tiktoken.get_encoding(config.encoding_model)
return DRIFTSearch(
llm=llm,
context_builder=DRIFTSearchContextBuilder(
chat_llm=llm,
text_embedder=text_embedder,
entities=entities,
relationships=relationships,
reports=reports,
entity_text_embeddings=description_embedding_store,
text_units=text_units,
local_system_prompt=local_system_prompt,
config=config.drift_search,
),
token_encoder=token_encoder,
)
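A minimal usage sketch for the local search factory above, assuming a GraphRagConfig and pre-built model lists have already been loaded elsewhere; the load_* helpers, the embedding store wiring, and the sample question are hypothetical placeholders rather than graphrag APIs.

import asyncio


async def demo() -> None:
    config = load_graphrag_config("settings.yaml")  # hypothetical helper returning a GraphRagConfig
    data = load_index_outputs("output/")  # hypothetical helper returning parsed model lists
    store = load_description_embedding_store(config)  # hypothetical helper returning a BaseVectorStore

    engine = get_local_search_engine(
        config=config,
        reports=data.reports,
        text_units=data.text_units,
        entities=data.entities,
        relationships=data.relationships,
        covariates=data.covariates,
        response_type="multiple paragraphs",
        description_embedding_store=store,
    )
    # assuming LocalSearch exposes an async search method returning an object with a `response` attribute
    result = await engine.asearch("What are the main themes in the dataset?")
    print(result.response)


asyncio.run(demo())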

View File

@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The storage package root."""

View File

@ -0,0 +1,375 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Azure Blob Storage implementation of PipelineStorage."""
import logging
import re
from collections.abc import Iterator
from pathlib import Path
from typing import Any
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient
from datashaper import Progress
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage
log = logging.getLogger(__name__)
class BlobPipelineStorage(PipelineStorage):
"""The Blob-Storage implementation."""
_connection_string: str | None
_container_name: str
_path_prefix: str
_encoding: str
_storage_account_blob_url: str | None
def __init__(
self,
connection_string: str | None,
container_name: str,
encoding: str = "utf-8",
path_prefix: str | None = None,
storage_account_blob_url: str | None = None,
):
"""Create a new BlobStorage instance."""
if connection_string:
self._blob_service_client = BlobServiceClient.from_connection_string(
connection_string
)
else:
if storage_account_blob_url is None:
msg = "Either connection_string or storage_account_blob_url must be provided."
raise ValueError(msg)
self._blob_service_client = BlobServiceClient(
account_url=storage_account_blob_url,
credential=DefaultAzureCredential(),
)
self._encoding = encoding
self._container_name = container_name
self._connection_string = connection_string
self._path_prefix = path_prefix or ""
self._storage_account_blob_url = storage_account_blob_url
self._storage_account_name = (
storage_account_blob_url.split("//")[1].split(".")[0]
if storage_account_blob_url
else None
)
log.info(
"creating blob storage at container=%s, path=%s",
self._container_name,
self._path_prefix,
)
self.create_container()
def create_container(self) -> None:
"""Create the container if it does not exist."""
if not self.container_exists():
container_name = self._container_name
container_names = [
container.name
for container in self._blob_service_client.list_containers()
]
if container_name not in container_names:
self._blob_service_client.create_container(container_name)
def delete_container(self) -> None:
"""Delete the container."""
if self.container_exists():
self._blob_service_client.delete_container(self._container_name)
def container_exists(self) -> bool:
"""Check if the container exists."""
container_name = self._container_name
container_names = [
container.name for container in self._blob_service_client.list_containers()
]
return container_name in container_names
def find(
self,
file_pattern: re.Pattern[str],
base_dir: str | None = None,
progress: ProgressLogger | None = None,
file_filter: dict[str, Any] | None = None,
max_count=-1,
) -> Iterator[tuple[str, dict[str, Any]]]:
"""Find blobs in a container using a file pattern, as well as a custom filter function.
Params:
base_dir: The name of the base container.
file_pattern: The file pattern to use.
file_filter: A dictionary of key-value pairs to filter the blobs.
max_count: The maximum number of blobs to return. If -1, all blobs are returned.
Returns
-------
An iterator of blob names and their corresponding regex matches.
"""
base_dir = base_dir or ""
log.info(
"search container %s for files matching %s",
self._container_name,
file_pattern.pattern,
)
def blobname(blob_name: str) -> str:
if blob_name.startswith(self._path_prefix):
blob_name = blob_name.replace(self._path_prefix, "", 1)
if blob_name.startswith("/"):
blob_name = blob_name[1:]
return blob_name
def item_filter(item: dict[str, Any]) -> bool:
if file_filter is None:
return True
return all(re.match(value, item[key]) for key, value in file_filter.items())
try:
container_client = self._blob_service_client.get_container_client(
self._container_name
)
all_blobs = list(container_client.list_blobs())
num_loaded = 0
num_total = len(all_blobs)
num_filtered = 0
for blob in all_blobs:
match = file_pattern.match(blob.name)
if match and blob.name.startswith(base_dir):
group = match.groupdict()
if item_filter(group):
yield (blobname(blob.name), group)
num_loaded += 1
if max_count > 0 and num_loaded >= max_count:
break
else:
num_filtered += 1
else:
num_filtered += 1
if progress is not None:
progress(
_create_progress_status(num_loaded, num_filtered, num_total)
)
except Exception:
log.exception(
"Error finding blobs: base_dir=%s, file_pattern=%s, file_filter=%s",
base_dir,
file_pattern,
file_filter,
)
raise
async def get(
self, key: str, as_bytes: bool | None = False, encoding: str | None = None
) -> Any:
"""Get a value from the cache."""
try:
key = self._keyname(key)
container_client = self._blob_service_client.get_container_client(
self._container_name
)
blob_client = container_client.get_blob_client(key)
blob_data = blob_client.download_blob().readall()
if not as_bytes:
coding = encoding or self._encoding
blob_data = blob_data.decode(coding)
except Exception:
log.exception("Error getting key %s", key)
return None
else:
return blob_data
async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
"""Set a value in the cache."""
try:
key = self._keyname(key)
container_client = self._blob_service_client.get_container_client(
self._container_name
)
blob_client = container_client.get_blob_client(key)
if isinstance(value, bytes):
blob_client.upload_blob(value, overwrite=True)
else:
coding = encoding or self._encoding
blob_client.upload_blob(value.encode(coding), overwrite=True)
except Exception:
log.exception("Error setting key %s: %s", key)
def set_df_json(self, key: str, dataframe: Any) -> None:
"""Set a json dataframe."""
if self._connection_string is None and self._storage_account_name:
dataframe.to_json(
self._abfs_url(key),
storage_options={
"account_name": self._storage_account_name,
"credential": DefaultAzureCredential(),
},
orient="records",
lines=True,
force_ascii=False,
)
else:
dataframe.to_json(
self._abfs_url(key),
storage_options={"connection_string": self._connection_string},
orient="records",
lines=True,
force_ascii=False,
)
def set_df_parquet(self, key: str, dataframe: Any) -> None:
"""Set a parquet dataframe."""
if self._connection_string is None and self._storage_account_name:
dataframe.to_parquet(
self._abfs_url(key),
storage_options={
"account_name": self._storage_account_name,
"credential": DefaultAzureCredential(),
},
)
else:
dataframe.to_parquet(
self._abfs_url(key),
storage_options={"connection_string": self._connection_string},
)
async def has(self, key: str) -> bool:
"""Check if a key exists in the cache."""
key = self._keyname(key)
container_client = self._blob_service_client.get_container_client(
self._container_name
)
blob_client = container_client.get_blob_client(key)
return blob_client.exists()
async def delete(self, key: str) -> None:
"""Delete a key from the cache."""
key = self._keyname(key)
container_client = self._blob_service_client.get_container_client(
self._container_name
)
blob_client = container_client.get_blob_client(key)
blob_client.delete_blob()
async def clear(self) -> None:
"""Clear the cache."""
def child(self, name: str | None) -> "PipelineStorage":
"""Create a child storage instance."""
if name is None:
return self
path = str(Path(self._path_prefix) / name)
return BlobPipelineStorage(
self._connection_string,
self._container_name,
self._encoding,
path,
self._storage_account_blob_url,
)
def keys(self) -> list[str]:
"""Return the keys in the storage."""
msg = "Blob storage does yet not support listing keys."
raise NotImplementedError(msg)
def _keyname(self, key: str) -> str:
"""Get the key name."""
return str(Path(self._path_prefix) / key)
def _abfs_url(self, key: str) -> str:
"""Get the ABFS URL."""
path = str(Path(self._container_name) / self._path_prefix / key)
return f"abfs://{path}"
def create_blob_storage(
connection_string: str | None,
storage_account_blob_url: str | None,
container_name: str,
base_dir: str | None,
) -> PipelineStorage:
"""Create a blob based storage."""
log.info("Creating blob storage at %s", container_name)
if container_name is None:
msg = "No container name provided for blob storage."
raise ValueError(msg)
if connection_string is None and storage_account_blob_url is None:
msg = "No storage account blob url provided for blob storage."
raise ValueError(msg)
return BlobPipelineStorage(
connection_string=connection_string,
container_name=container_name,
path_prefix=base_dir,
storage_account_blob_url=storage_account_blob_url,
)
def validate_blob_container_name(container_name: str):
"""
Check if the provided blob container name is valid based on Azure rules.
- A blob container name must be between 3 and 63 characters in length.
- Start with a letter or number
- All letters used in blob container names must be lowercase.
- Contain only letters, numbers, or the hyphen.
- Consecutive hyphens are not permitted.
- Cannot end with a hyphen.
Args:
-----
container_name (str)
The blob container name to be validated.
Returns
-------
bool: True if the name is valid; otherwise a ValueError is raised describing the violation.
"""
# Check the length of the name
if len(container_name) < 3 or len(container_name) > 63:
raise ValueError(
f"Container name must be between 3 and 63 characters in length. Name provided was {len(container_name)} characters long."
)
# Check if the name starts with a letter or number
if not container_name[0].isalnum():
raise ValueError(
f"Container name must start with a letter or number. Starting character was {container_name[0]}."
)
# Check for valid characters (letters, numbers, hyphen) and lowercase letters
if not re.match(r"^[a-z0-9-]+$", container_name):
raise ValueError(
f"Container name must only contain:\n- lowercase letters\n- numbers\n- or hyphens\nName provided was {container_name}."
)
# Check for consecutive hyphens
if "--" in container_name:
raise ValueError(
f"Container name cannot contain consecutive hyphens. Name provided was {container_name}."
)
# Check for hyphens at the end of the name
if container_name[-1] == "-":
raise ValueError(
f"Container name cannot end with a hyphen. Name provided was {container_name}."
)
return True
def _create_progress_status(
num_loaded: int, num_filtered: int, num_total: int
) -> Progress:
return Progress(
total_items=num_total,
completed_items=num_loaded + num_filtered,
description=f"{num_loaded} files loaded ({num_filtered} filtered)",
)
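A minimal usage sketch for BlobPipelineStorage, assuming a valid Azure Storage connection string is available in the AZURE_STORAGE_CONNECTION_STRING environment variable and that the account is reachable; the container and key names are illustrative.

import asyncio
import os
import re

storage = create_blob_storage(
    connection_string=os.environ["AZURE_STORAGE_CONNECTION_STRING"],
    storage_account_blob_url=None,
    container_name="graphrag-demo",
    base_dir="output",
)


async def demo() -> None:
    await storage.set("stats.json", '{"documents": 3}')
    print(await storage.get("stats.json"))
    # find() is synchronous and yields (blob_name, regex_groups) pairs
    for name, groups in storage.find(re.compile(r"(?P<stem>.+)\.json$")):
        print(name, groups)


asyncio.run(demo())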

View File

@ -0,0 +1,308 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Azure CosmosDB Storage implementation of PipelineStorage."""
import json
import logging
import re
from collections.abc import Iterator
from io import BytesIO, StringIO
from typing import Any
import pandas as pd
from azure.cosmos import CosmosClient
from azure.cosmos.partition_key import PartitionKey
from azure.identity import DefaultAzureCredential
from datashaper import Progress
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage
log = logging.getLogger(__name__)
class CosmosDBPipelineStorage(PipelineStorage):
"""The CosmosDB-Storage Implementation."""
_cosmosdb_account_url: str | None
_connection_string: str | None
_database_name: str
_current_container: str | None
_encoding: str
def __init__(
self,
cosmosdb_account_url: str | None,
connection_string: str | None,
database_name: str,
encoding: str | None = None,
current_container: str | None = None,
):
"""Initialize the CosmosDB-Storage."""
if connection_string:
self._cosmos_client = CosmosClient.from_connection_string(
connection_string
)
else:
if cosmosdb_account_url is None:
msg = "Either connection_string or cosmosdb_accoun_url must be provided."
raise ValueError(msg)
self._cosmos_client = CosmosClient(
url=cosmosdb_account_url,
credential=DefaultAzureCredential(),
)
self._encoding = encoding or "utf-8"
self._database_name = database_name
self._connection_string = connection_string
self._cosmosdb_account_url = cosmosdb_account_url
self._current_container = current_container or None
self._cosmosdb_account_name = (
cosmosdb_account_url.split("//")[1].split(".")[0]
if cosmosdb_account_url
else None
)
self._database_client = self._cosmos_client.get_database_client(
self._database_name
)
log.info(
"creating cosmosdb storage with account: %s and database: %s",
self._cosmosdb_account_name,
self._database_name,
)
self.create_database()
if self._current_container is not None:
self.create_container()
def create_database(self) -> None:
"""Create the database if it doesn't exist."""
database_name = self._database_name
self._cosmos_client.create_database_if_not_exists(id=database_name)
def delete_database(self) -> None:
"""Delete the database if it exists."""
if self.database_exists():
self._cosmos_client.delete_database(self._database_name)
def database_exists(self) -> bool:
"""Check if the database exists."""
database_name = self._database_name
database_names = [
database["id"] for database in self._cosmos_client.list_databases()
]
return database_name in database_names
def find(
self,
file_pattern: re.Pattern[str],
base_dir: str | None = None,
progress: ProgressLogger | None = None,
file_filter: dict[str, Any] | None = None,
max_count=-1,
) -> Iterator[tuple[str, dict[str, Any]]]:
"""Find documents in a Cosmos DB container using a file pattern and custom filter function.
Params:
base_dir: The name of the base directory (not used in Cosmos DB context).
file_pattern: The file pattern to use.
file_filter: A dictionary of key-value pairs to filter the documents.
max_count: The maximum number of documents to return. If -1, all documents are returned.
Returns
-------
An iterator of document IDs and their corresponding regex matches.
"""
base_dir = base_dir or ""
log.info(
"search container %s for documents matching %s",
self._current_container,
file_pattern.pattern,
)
def item_filter(item: dict[str, Any]) -> bool:
if file_filter is None:
return True
return all(re.match(value, item.get(key, "")) for key, value in file_filter.items())
try:
database_client = self._database_client
container_client = database_client.get_container_client(str(self._current_container))
query = "SELECT * FROM c WHERE CONTAINS(c.id, @pattern)"
parameters: list[dict[str, Any]] = [{"name": "@pattern", "value": file_pattern.pattern}]
if file_filter:
for key, value in file_filter.items():
query += f" AND c.{key} = @{key}"
parameters.append({"name": f"@{key}", "value": value})
items = list(container_client.query_items(query=query, parameters=parameters, enable_cross_partition_query=True))
num_loaded = 0
num_total = len(items)
num_filtered = 0
for item in items:
match = file_pattern.match(item["id"])
if match:
group = match.groupdict()
if item_filter(group):
yield (item["id"], group)
num_loaded += 1
if max_count > 0 and num_loaded >= max_count:
break
else:
num_filtered += 1
else:
num_filtered += 1
if progress is not None:
progress(
_create_progress_status(num_loaded, num_filtered, num_total)
)
except Exception:
log.exception("An error occurred while searching for documents in Cosmos DB.")
async def get(
self, key: str, as_bytes: bool | None = None, encoding: str | None = None
) -> Any:
"""Get a file in the database for the given key."""
try:
database_client = self._database_client
if self._current_container is not None:
container_client = database_client.get_container_client(
self._current_container
)
item = container_client.read_item(item=key, partition_key=key)
item_body = item.get("body")
item_json_str = json.dumps(item_body)
if as_bytes:
item_df = pd.read_json(
StringIO(item_json_str),
orient="records",
lines=False
)
return item_df.to_parquet()
return item_json_str
except Exception:
log.exception("Error reading item %s", key)
return None
else:
return None
async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
"""Set a file in the database for the given key."""
try:
database_client = self._database_client
if self._current_container is not None:
container_client = database_client.get_container_client(
self._current_container
)
if isinstance(value, bytes):
value_df = pd.read_parquet(BytesIO(value))
value_json = value_df.to_json(
orient="records",
lines=False,
force_ascii=False
)
cosmos_db_item = {
"id": key,
"body": json.loads(value_json)
}
else:
cosmos_db_item = {
"id": key,
"body": json.loads(value)
}
container_client.upsert_item(body=cosmos_db_item)
except Exception:
log.exception("Error writing item %s", key)
async def has(self, key: str) -> bool:
"""Check if the given file exists in the cosmosdb storage."""
database_client = self._database_client
if self._current_container is not None:
container_client = database_client.get_container_client(
self._current_container
)
item_names = [
item["id"]
for item in container_client.read_all_items()
]
return key in item_names
return False
async def delete(self, key: str) -> None:
"""Delete the given file from the cosmosdb storage."""
database_client = self._database_client
if self._current_container is not None:
container_client = database_client.get_container_client(
self._current_container
)
container_client.delete_item(item=key, partition_key=key)
async def clear(self) -> None:
"""Clear the cosmosdb storage."""
def keys(self) -> list[str]:
"""Return the keys in the storage."""
msg = "CosmosDB storage does yet not support listing keys."
raise NotImplementedError(msg)
def child(self, name: str | None) -> "PipelineStorage":
"""Create a child storage instance."""
return self
def create_container(self) -> None:
"""Create a container for the current container name if it doesn't exist."""
database_client = self._database_client
if self._current_container is not None:
partition_key = PartitionKey(path="/id", kind="Hash")
database_client.create_container_if_not_exists(
id=self._current_container,
partition_key=partition_key,
)
def delete_container(self) -> None:
"""Delete the container with the current container name if it exists."""
database_client = self._database_client
if self.container_exists() and self._current_container is not None:
database_client.delete_container(self._current_container)
def container_exists(self) -> bool:
"""Check if the container with the current container name exists."""
database_client = self._database_client
container_names = [
container["id"]
for container in database_client.list_containers()
]
return self._current_container in container_names
def create_cosmosdb_storage(
cosmosdb_account_url: str | None,
connection_string: str | None,
base_dir: str,
container_name: str | None,
) -> PipelineStorage:
"""Create a CosmosDB storage instance."""
log.info("Creating cosmosdb storage")
if base_dir is None:
msg = "No base_dir provided for database name"
raise ValueError(msg)
if connection_string is None and cosmosdb_account_url is None:
msg = "No cosmosdb account url provided"
raise ValueError(msg)
return CosmosDBPipelineStorage(
cosmosdb_account_url=cosmosdb_account_url,
connection_string=connection_string,
database_name=base_dir,
current_container=container_name,
)
def _create_progress_status(
num_loaded: int, num_filtered: int, num_total: int
) -> Progress:
return Progress(
total_items=num_total,
completed_items=num_loaded + num_filtered,
description=f"{num_loaded} files loaded ({num_filtered} filtered)",
)
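A minimal usage sketch for the CosmosDB storage, assuming a reachable Cosmos DB account and DefaultAzureCredential access; the account URL, database, and container names are illustrative placeholders.

import asyncio

storage = create_cosmosdb_storage(
    cosmosdb_account_url="https://<account>.documents.azure.com:443/",
    connection_string=None,
    base_dir="graphrag",  # used as the database name
    container_name="output",
)


async def demo() -> None:
    # non-bytes values must be JSON strings; parquet bytes are converted to/from JSON item bodies
    await storage.set("create_final_entities.json", '[{"id": "1", "title": "COMPANY_A"}]')
    print(await storage.has("create_final_entities.json"))
    print(await storage.get("create_final_entities.json"))


asyncio.run(demo())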

View File

@ -0,0 +1,48 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Factory functions for creating storage."""
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from graphrag.config.enums import StorageType
from graphrag.storage.blob_pipeline_storage import create_blob_storage
from graphrag.storage.file_pipeline_storage import create_file_storage
from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
if TYPE_CHECKING:
from graphrag.storage.pipeline_storage import PipelineStorage
class StorageFactory:
"""A factory class for storage implementations.
Includes a method for users to register a custom storage implementation.
"""
storage_types: ClassVar[dict[str, type]] = {}
@classmethod
def register(cls, storage_type: str, storage: type):
"""Register a custom storage implementation."""
cls.storage_types[storage_type] = storage
@classmethod
def create_storage(
cls, storage_type: StorageType | str, kwargs: dict
) -> PipelineStorage:
"""Create or get a storage object from the provided type."""
match storage_type:
case StorageType.blob:
return create_blob_storage(**kwargs)
case StorageType.file:
return create_file_storage(**kwargs)
case StorageType.memory:
return MemoryPipelineStorage()
case _:
if storage_type in cls.storage_types:
return cls.storage_types[storage_type](**kwargs)
msg = f"Unknown storage type: {storage_type}"
raise ValueError(msg)
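A short sketch of how the factory above can be used, both for the built-in backends and for registering a custom one; MyS3Storage and the "s3" type name are hypothetical placeholders.

from graphrag.config.enums import StorageType

# built-in backends route through the match statement
memory_store = StorageFactory.create_storage(StorageType.memory, kwargs={})
file_store = StorageFactory.create_storage(StorageType.file, kwargs={"base_dir": "output"})

# a custom backend registers a class that fully implements PipelineStorage
# (MyS3Storage is hypothetical and must subclass graphrag.storage.pipeline_storage.PipelineStorage)
# StorageFactory.register("s3", MyS3Storage)
# s3_store = StorageFactory.create_storage("s3", kwargs={"bucket": "my-bucket"})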

View File

@ -0,0 +1,169 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'FileStorage' and 'FilePipelineStorage' models."""
import logging
import os
import re
import shutil
from collections.abc import Iterator
from pathlib import Path
from typing import Any, cast
import aiofiles
from aiofiles.os import remove
from aiofiles.ospath import exists
from datashaper import Progress
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage
log = logging.getLogger(__name__)
class FilePipelineStorage(PipelineStorage):
"""File storage class definition."""
_root_dir: str
_encoding: str
def __init__(self, root_dir: str = "", encoding: str = "utf-8"):
"""Init method definition."""
self._root_dir = root_dir
self._encoding = encoding
Path(self._root_dir).mkdir(parents=True, exist_ok=True)
def find(
self,
file_pattern: re.Pattern[str],
base_dir: str | None = None,
progress: ProgressLogger | None = None,
file_filter: dict[str, Any] | None = None,
max_count=-1,
) -> Iterator[tuple[str, dict[str, Any]]]:
"""Find files in the storage using a file pattern, as well as a custom filter function."""
def item_filter(item: dict[str, Any]) -> bool:
if file_filter is None:
return True
return all(re.match(value, item[key]) for key, value in file_filter.items())
search_path = Path(self._root_dir) / (base_dir or "")
log.info("search %s for files matching %s", search_path, file_pattern.pattern)
all_files = list(search_path.rglob("**/*"))
num_loaded = 0
num_total = len(all_files)
num_filtered = 0
for file in all_files:
match = file_pattern.match(f"{file}")
if match:
group = match.groupdict()
if item_filter(group):
filename = f"{file}".replace(self._root_dir, "")
if filename.startswith(os.sep):
filename = filename[1:]
yield (filename, group)
num_loaded += 1
if max_count > 0 and num_loaded >= max_count:
break
else:
num_filtered += 1
else:
num_filtered += 1
if progress is not None:
progress(_create_progress_status(num_loaded, num_filtered, num_total))
async def get(
self, key: str, as_bytes: bool | None = False, encoding: str | None = None
) -> Any:
"""Get method definition."""
file_path = join_path(self._root_dir, key)
if await self.has(key):
return await self._read_file(file_path, as_bytes, encoding)
if await exists(key):
# Look up the key directly, as it is presumably a new file loaded from inputs
# and not yet written to storage
return await self._read_file(key, as_bytes, encoding)
return None
async def _read_file(
self,
path: str | Path,
as_bytes: bool | None = False,
encoding: str | None = None,
) -> Any:
"""Read the contents of a file."""
read_type = "rb" if as_bytes else "r"
encoding = None if as_bytes else (encoding or self._encoding)
async with aiofiles.open(
path,
cast("Any", read_type),
encoding=encoding,
) as f:
return await f.read()
async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
"""Set method definition."""
is_bytes = isinstance(value, bytes)
write_type = "wb" if is_bytes else "w"
encoding = None if is_bytes else encoding or self._encoding
async with aiofiles.open(
join_path(self._root_dir, key),
cast("Any", write_type),
encoding=encoding,
) as f:
await f.write(value)
async def has(self, key: str) -> bool:
"""Has method definition."""
return await exists(join_path(self._root_dir, key))
async def delete(self, key: str) -> None:
"""Delete method definition."""
if await self.has(key):
await remove(join_path(self._root_dir, key))
async def clear(self) -> None:
"""Clear method definition."""
for file in Path(self._root_dir).glob("*"):
if file.is_dir():
shutil.rmtree(file)
else:
file.unlink()
def child(self, name: str | None) -> "PipelineStorage":
"""Create a child storage instance."""
if name is None:
return self
return FilePipelineStorage(str(Path(self._root_dir) / Path(name)))
def keys(self) -> list[str]:
"""Return the keys in the storage."""
return [item.name for item in Path(self._root_dir).iterdir() if item.is_file()]
def join_path(file_path: str, file_name: str) -> Path:
"""Join a path and a file. Independent of the OS."""
return Path(file_path) / Path(file_name).parent / Path(file_name).name
def create_file_storage(**kwargs: Any) -> PipelineStorage:
"""Create a file based storage."""
base_dir = kwargs["base_dir"]
log.info("Creating file storage at %s", base_dir)
return FilePipelineStorage(root_dir=base_dir)
def _create_progress_status(
num_loaded: int, num_filtered: int, num_total: int
) -> Progress:
return Progress(
total_items=num_total,
completed_items=num_loaded + num_filtered,
description=f"{num_loaded} files loaded ({num_filtered} filtered)",
)
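A minimal usage sketch for the file-backed storage, writing to a local output directory; the directory and key names are illustrative.

import asyncio
import re

storage = create_file_storage(base_dir="output")


async def demo() -> None:
    await storage.set("stats.json", '{"documents": 3}')
    print(await storage.get("stats.json"))
    print(storage.keys())  # top-level files only
    # find() matches file paths against a compiled pattern and yields regex groups
    for name, groups in storage.find(re.compile(r".*\.json$")):
        print(name, groups)


asyncio.run(demo())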

View File

@ -0,0 +1,78 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'InMemoryStorage' model."""
from typing import TYPE_CHECKING, Any
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
if TYPE_CHECKING:
from graphrag.storage.pipeline_storage import PipelineStorage
class MemoryPipelineStorage(FilePipelineStorage):
"""In memory storage class definition."""
_storage: dict[str, Any]
def __init__(self):
"""Init method definition."""
super().__init__()
self._storage = {}
async def get(
self, key: str, as_bytes: bool | None = None, encoding: str | None = None
) -> Any:
"""Get the value for the given key.
Args:
- key - The key to get the value for.
- as_bytes - Whether or not to return the value as bytes.
Returns
-------
- output - The value for the given key.
"""
return self._storage.get(key)
async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
"""Set the value for the given key.
Args:
- key - The key to set the value for.
- value - The value to set.
"""
self._storage[key] = value
async def has(self, key: str) -> bool:
"""Return True if the given key exists in the storage.
Args:
- key - The key to check for.
Returns
-------
- output - True if the key exists in the storage, False otherwise.
"""
return key in self._storage
async def delete(self, key: str) -> None:
"""Delete the given key from the storage.
Args:
- key - The key to delete.
"""
del self._storage[key]
async def clear(self) -> None:
"""Clear the storage."""
self._storage.clear()
def child(self, name: str | None) -> "PipelineStorage":
"""Create a child storage instance."""
return MemoryPipelineStorage()
def keys(self) -> list[str]:
"""Return the keys in the storage."""
return list(self._storage.keys())
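A minimal usage sketch for the in-memory backend, which is mainly useful in tests and notebooks; the key and value here are illustrative.

import asyncio


async def demo() -> None:
    storage = MemoryPipelineStorage()
    await storage.set("base_text_units", "<any value, e.g. a dataframe or bytes>")
    print(await storage.has("base_text_units"))  # True
    print(storage.keys())  # ['base_text_units']
    await storage.clear()


asyncio.run(demo())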

View File

@ -0,0 +1,82 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'PipelineStorage' model."""
import re
from abc import ABCMeta, abstractmethod
from collections.abc import Iterator
from typing import Any
from graphrag.logger.base import ProgressLogger
class PipelineStorage(metaclass=ABCMeta):
"""Provide a storage interface for the pipeline. This is where the pipeline will store its output data."""
@abstractmethod
def find(
self,
file_pattern: re.Pattern[str],
base_dir: str | None = None,
progress: ProgressLogger | None = None,
file_filter: dict[str, Any] | None = None,
max_count=-1,
) -> Iterator[tuple[str, dict[str, Any]]]:
"""Find files in the storage using a file pattern, as well as a custom filter function."""
@abstractmethod
async def get(
self, key: str, as_bytes: bool | None = None, encoding: str | None = None
) -> Any:
"""Get the value for the given key.
Args:
- key - The key to get the value for.
- as_bytes - Whether or not to return the value as bytes.
Returns
-------
- output - The value for the given key.
"""
@abstractmethod
async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
"""Set the value for the given key.
Args:
- key - The key to set the value for.
- value - The value to set.
"""
@abstractmethod
async def has(self, key: str) -> bool:
"""Return True if the given key exists in the storage.
Args:
- key - The key to check for.
Returns
-------
- output - True if the key exists in the storage, False otherwise.
"""
@abstractmethod
async def delete(self, key: str) -> None:
"""Delete the given key from the storage.
Args:
- key - The key to delete.
"""
@abstractmethod
async def clear(self) -> None:
"""Clear the storage."""
@abstractmethod
def child(self, name: str | None) -> "PipelineStorage":
"""Create a child storage instance."""
@abstractmethod
def keys(self) -> list[str]:
"""List all keys in the storage."""

View File

@ -0,0 +1,125 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from typing import Any, cast
import pandas as pd
from datashaper import (
Table,
VerbInput,
VerbResult,
create_verb_result,
)
from graphrag.index.config.pipeline import PipelineWorkflowReference
from graphrag.index.run import run_pipeline
from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.storage.pipeline_storage import PipelineStorage
async def mock_verb(
input: VerbInput, storage: PipelineStorage, **_kwargs
) -> VerbResult:
source = cast("pd.DataFrame", input.get_input())
output = source[["id"]]
await storage.set("mock_write", source[["id"]])
return create_verb_result(
cast(
"Table",
output,
)
)
async def mock_no_return_verb(
input: VerbInput, storage: PipelineStorage, **_kwargs
) -> VerbResult:
source = cast("pd.DataFrame", input.get_input())
# write some outputs to storage independent of the return
await storage.set("empty_write", source[["name"]])
return create_verb_result(
cast(
"Table",
pd.DataFrame(),
)
)
async def test_normal_result_exports_parquet():
mock_verbs: Any = {"mock_verb": mock_verb}
mock_workflows: Any = {
"mock_workflow": lambda _x: [
{
"verb": "mock_verb",
"args": {
"column": "test",
},
}
]
}
workflows = [
PipelineWorkflowReference(
name="mock_workflow",
config=None,
)
]
dataset = pd.DataFrame({"id": [1, 2, 3], "name": ["a", "b", "c"]})
storage = MemoryPipelineStorage()
pipeline_result = [
gen
async for gen in run_pipeline(
workflows,
dataset,
storage=storage,
additional_workflows=mock_workflows,
additional_verbs=mock_verbs,
)
]
assert len(pipeline_result) == 1
assert storage.keys() == ["stats.json", "mock_write", "mock_workflow.parquet"], (
"Mock workflow output should be written to storage by the exporter when there is a non-empty data frame"
)
async def test_empty_result_does_not_export_parquet():
mock_verbs: Any = {"mock_no_return_verb": mock_no_return_verb}
mock_workflows: Any = {
"mock_workflow": lambda _x: [
{
"verb": "mock_no_return_verb",
"args": {
"column": "test",
},
}
]
}
workflows = [
PipelineWorkflowReference(
name="mock_workflow",
config=None,
)
]
dataset = pd.DataFrame({"id": [1, 2, 3], "name": ["a", "b", "c"]})
storage = MemoryPipelineStorage()
pipeline_result = [
gen
async for gen in run_pipeline(
workflows,
dataset,
storage=storage,
additional_workflows=mock_workflows,
additional_verbs=mock_verbs,
)
]
assert len(pipeline_result) == 1
assert storage.keys() == [
"stats.json",
"empty_write",
], "Mock workflow output should not be written to storage by the exporter"

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,52 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.index.flows.compute_communities import (
compute_communities,
)
from graphrag.index.run.utils import create_run_context
from graphrag.index.workflows.v1.compute_communities import (
workflow_name,
)
from .util import (
compare_outputs,
get_config_for_workflow,
load_test_table,
)
async def test_compute_communities():
edges = load_test_table("base_relationship_edges")
expected = load_test_table("base_communities")
context = create_run_context(None, None, None)
config = get_config_for_workflow(workflow_name)
clustering_strategy = config["cluster_graph"]["strategy"]
actual = await compute_communities(
edges, storage=context.storage, clustering_strategy=clustering_strategy
)
columns = list(expected.columns.values)
compare_outputs(actual, expected, columns)
assert len(actual.columns) == len(expected.columns)
async def test_compute_communities_with_snapshots():
edges = load_test_table("base_relationship_edges")
context = create_run_context(None, None, None)
config = get_config_for_workflow(workflow_name)
clustering_strategy = config["cluster_graph"]["strategy"]
await compute_communities(
edges,
storage=context.storage,
clustering_strategy=clustering_strategy,
snapshot_transient_enabled=True,
)
assert context.storage.keys() == [
"base_communities.parquet",
], "Community snapshot keys differ"

View File

@ -0,0 +1,159 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
import pytest
from graphrag.config.enums import LLMType
from graphrag.index.run.utils import create_run_context
from graphrag.index.workflows.v1.extract_graph import (
build_steps,
workflow_name,
)
from .util import (
get_config_for_workflow,
get_workflow_output,
load_input_tables,
load_test_table,
)
MOCK_LLM_ENTITY_RESPONSES = [
"""
("entity"<|>COMPANY_A<|>COMPANY<|>Company_A is a test company)
##
("entity"<|>COMPANY_B<|>COMPANY<|>Company_B owns Company_A and also shares an address with Company_A)
##
("entity"<|>PERSON_C<|>PERSON<|>Person_C is director of Company_A)
##
("relationship"<|>COMPANY_A<|>COMPANY_B<|>Company_A and Company_B are related because Company_A is 100% owned by Company_B and the two companies also share the same address)<|>2)
##
("relationship"<|>COMPANY_A<|>PERSON_C<|>Company_A and Person_C are related because Person_C is director of Company_A<|>1))
""".strip()
]
MOCK_LLM_ENTITY_CONFIG = {
"type": LLMType.StaticResponse,
"responses": MOCK_LLM_ENTITY_RESPONSES,
}
MOCK_LLM_SUMMARIZATION_RESPONSES = [
"""
This is a MOCK response for the LLM. It is summarized!
""".strip()
]
MOCK_LLM_SUMMARIZATION_CONFIG = {
"type": LLMType.StaticResponse,
"responses": MOCK_LLM_SUMMARIZATION_RESPONSES,
}
async def test_extract_graph():
input_tables = load_input_tables([
"workflow:create_base_text_units",
])
nodes_expected = load_test_table("base_entity_nodes")
edges_expected = load_test_table("base_relationship_edges")
context = create_run_context(None, None, None)
await context.runtime_storage.set(
"base_text_units", input_tables["workflow:create_base_text_units"]
)
config = get_config_for_workflow(workflow_name)
config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
config["summarize_descriptions"]["strategy"]["llm"] = MOCK_LLM_SUMMARIZATION_CONFIG
steps = build_steps(config)
await get_workflow_output(
input_tables,
{
"steps": steps,
},
context=context,
)
# graph construction creates transient tables for nodes, edges, and communities
nodes_actual = await context.runtime_storage.get("base_entity_nodes")
edges_actual = await context.runtime_storage.get("base_relationship_edges")
assert len(nodes_actual.columns) == len(nodes_expected.columns), (
"Nodes dataframe columns differ"
)
assert len(edges_actual.columns) == len(edges_expected.columns), (
"Edges dataframe columns differ"
)
# TODO: with the combined verb we can't force summarization
# this is because the mock responses always result in a single description, which is returned verbatim rather than summarized
# we need to update the mocking to provide somewhat unique graphs so a true merge happens
# the assertion should grab a node and ensure the description matches the mock description, not the original as we are doing below
assert nodes_actual["description"].values[0] == "Company_A is a test company"
assert len(context.storage.keys()) == 0, "Storage should be empty"
async def test_extract_graph_with_snapshots():
input_tables = load_input_tables([
"workflow:create_base_text_units",
])
context = create_run_context(None, None, None)
await context.runtime_storage.set(
"base_text_units", input_tables["workflow:create_base_text_units"]
)
config = get_config_for_workflow(workflow_name)
config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
config["summarize_descriptions"]["strategy"]["llm"] = MOCK_LLM_SUMMARIZATION_CONFIG
config["snapshot_graphml"] = True
config["snapshot_transient"] = True
config["embed_graph_enabled"] = True # need this on in order to see the snapshot
steps = build_steps(config)
await get_workflow_output(
input_tables,
{
"steps": steps,
},
context=context,
)
assert context.storage.keys() == [
"graph.graphml",
"base_entity_nodes.parquet",
"base_relationship_edges.parquet",
], "Graph snapshot keys differ"
async def test_extract_graph_missing_llm_throws():
input_tables = load_input_tables([
"workflow:create_base_text_units",
])
context = create_run_context(None, None, None)
await context.runtime_storage.set(
"base_text_units", input_tables["workflow:create_base_text_units"]
)
config = get_config_for_workflow(workflow_name)
config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
del config["summarize_descriptions"]["strategy"]["llm"]
steps = build_steps(config)
with pytest.raises(ValueError):  # noqa: PT011
await get_workflow_output(
input_tables,
{
"steps": steps,
},
context=context,
)