Add GraphRAG Cache package.

This commit is contained in:
Derek Worthen 2025-10-22 10:19:10 -07:00
parent 4404668aa8
commit 71f9c09f3f
36 changed files with 489 additions and 262 deletions

View File

@ -0,0 +1 @@
3.12

View File

@ -0,0 +1,97 @@
# GraphRAG Cache
## Basic
```python
import asyncio
from graphrag_storage import StorageConfig, create_storage, StorageType
from graphrag_cache import CacheConfig, create_cache, CacheType
async def run():
# Json cache requires a storage implementation.
storage = create_storage(
StorageConfig(
type=StorageType.File,
base_dir="output"
)
)
cache = create_cache(
CacheConfig(
type=CacheType.Json
),
storage=storage
)
await cache.set("my_key", {"some": "object to cache"})
print(await cache.get("my_key"))
if __name__ == "__main__":
asyncio.run(run())
```
## Custom Cache
```python
import asyncio
from typing import Any
from graphrag_storage import Storage
from graphrag_cache import Cache, CacheConfig, create_cache, register_cache
class MyCache(Cache):
def __init__(self, storage: Storage, some_setting: str, optional_setting: str = "default setting", **kwargs: Any):
# Validate settings and initialize
...
# Implement the rest of the Cache interface
...
register_cache("MyCache", MyCache)
async def run():
cache = create_cache(
CacheConfig(
type="MyCache",
some_setting="My Setting"
)
# if your cache relies on a storage implementation you can pass that here
# storage=some_storage
)
# Or use the factory directly to instantiate with a dict instead of using
# CacheConfig + create_cache
# from graphrag_cache.cache_factory import cache_factory
# cache = cache_factory.create(strategy="MyCache", init_args={"storage": storage_implementation, "some_setting": "My Setting"})
await cache.set("my_key", {"some": "object to cache"})
print(await cache.get("my_key"))
if __name__ == "__main__":
asyncio.run(run())
```
### Details
By default, `create_cache` comes with the following cache providers registered, corresponding to the entries in the `CacheType` enum.
- `JsonCache`
- `MemoryCache`
- `NoopCache`
The preregistration happens dynamically (on first use), e.g., `JsonCache` is only imported and registered if you request a Json cache with `create_cache(CacheConfig(type=CacheType.Json), ...)`. There is no need to manually import and register builtin cache providers when using `create_cache`.
If you want a clean factory with no preregistered cache providers, import `cache_factory` directly and bypass `create_cache`. The downside is that `cache_factory.create` takes a dict of init args instead of the strongly typed `CacheConfig` used with `create_cache`.
```python
from graphrag_cache.cache_factory import cache_factory
from graphrag_cache.json_cache import JsonCache
# cache_factory has no preregistered providers so you must register any
# providers you plan on using.
# May also register a custom implementation, see above for example.
cache_factory.register("my_cache_impl", JsonCache)
cache = cache_factory.create(strategy="my_cache_impl", init_args={"some_setting": "..."})
...
```

View File

@ -0,0 +1,17 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The GraphRAG Cache package."""
from graphrag_cache.cache import Cache
from graphrag_cache.cache_config import CacheConfig
from graphrag_cache.cache_factory import create_cache, register_cache
from graphrag_cache.cache_type import CacheType
# Public API of the package: the names imported above and re-exported at the package root.
__all__ = [
    "Cache",
    "CacheConfig",
    "CacheType",
    "create_cache",
    "register_cache",
]

View File

@ -1,17 +1,21 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'PipelineCache' model."""
"""Abstract base class for cache."""
from __future__ import annotations
from abc import ABCMeta, abstractmethod
from abc import ABC, abstractmethod
from typing import Any
class PipelineCache(metaclass=ABCMeta):
class Cache(ABC):
"""Provide a cache interface for the pipeline."""
@abstractmethod
def __init__(self, **kwargs: Any) -> None:
"""Create a cache instance."""
@abstractmethod
async def get(self, key: str) -> Any:
"""Get the value for the given key.
@ -59,7 +63,7 @@ class PipelineCache(metaclass=ABCMeta):
"""Clear the cache."""
@abstractmethod
def child(self, name: str) -> PipelineCache:
def child(self, name: str) -> Cache:
"""Create a child cache with the given name.
Args:

View File

@ -0,0 +1,30 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Cache configuration model."""
from pydantic import BaseModel, ConfigDict, Field
from graphrag_cache.cache_type import CacheType
class CacheConfig(BaseModel):
    """Settings model that selects and configures a cache implementation.

    Fields that are not declared here are accepted and kept, so that custom
    cache implementations can receive their own settings through this model.
    """

    # extra="allow": keep undeclared fields instead of rejecting them.
    model_config = ConfigDict(extra="allow")

    type: str = Field(
        default=CacheType.Json,
        description="The cache type to use. Builtin types include 'Json', 'Memory', and 'Noop'.",
    )

    encoding: str | None = Field(
        default=None,
        description="The encoding to use for file-based caching.",
    )

    name: str | None = Field(
        default=None,
        description="The name to use for the cache instance.",
    )

View File

@ -0,0 +1,82 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Cache factory implementation."""
from collections.abc import Callable
from graphrag_common.factory import Factory, ServiceScope
from graphrag_storage import Storage
from graphrag_cache.cache import Cache
from graphrag_cache.cache_config import CacheConfig
from graphrag_cache.cache_type import CacheType
class CacheFactory(Factory[Cache]):
    """A factory class for cache implementations.

    Adds no behavior of its own; it only specializes the generic ``Factory``
    to ``Cache``.
    """


# Module-level factory instance used by register_cache and create_cache below.
cache_factory = CacheFactory()
def register_cache(
    cache_type: str,
    cache_initializer: Callable[..., Cache],
    scope: ServiceScope = "transient",
) -> None:
    """Register a custom cache implementation.

    Args
    ----
        - cache_type: str
            The cache id to register.
        - cache_initializer: Callable[..., Cache]
            The cache initializer to register.
        - scope: ServiceScope
            Passed through unchanged to the factory's ``register`` call.
            Defaults to "transient".
    """
    cache_factory.register(cache_type, cache_initializer, scope)
def create_cache(config: CacheConfig, storage: Storage | None = None) -> Cache:
"""Create a cache implementation based on the given configuration.
Args
----
- config: CacheConfig
The cache configuration to use.
- storage: Storage | None
The storage implementation to use for file-based caches such as 'Json'.
Returns
-------
Cache
The created cache implementation.
"""
config_model = config.model_dump()
cache_strategy = config.type
if cache_strategy not in cache_factory:
match cache_strategy:
case "json":
from graphrag_cache.json_cache import JsonCache
register_cache(CacheType.Json, JsonCache)
case "memory":
from graphrag_cache.memory_cache import MemoryCache
register_cache(CacheType.Memory, MemoryCache)
case "noop":
from graphrag_cache.noop_cache import NoopCache
register_cache(CacheType.Noop, NoopCache)
case _:
msg = f"CacheConfig.type '{cache_strategy}' is not registered in the CacheFactory. Registered types: {', '.join(cache_factory.keys())}."
raise ValueError(msg)
return cache_factory.create(
strategy=cache_strategy, init_args={"storage": storage, **config_model}
)

View File

@ -0,0 +1,15 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Builtin cache implementation types."""
from enum import StrEnum
class CacheType(StrEnum):
"""Enum for cache types."""
Json = "json"
Memory = "memory"
Noop = "noop"

View File

@ -8,21 +8,21 @@ from typing import Any
from graphrag_storage import Storage
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag_cache.cache import Cache
class JsonPipelineCache(PipelineCache):
class JsonCache(Cache):
"""File pipeline cache class definition."""
_storage: Storage
_encoding: str
def __init__(self, storage: Storage, encoding="utf-8"):
def __init__(self, storage: Storage, encoding="utf-8", **kwargs: Any) -> None:
"""Init method definition."""
self._storage = storage
self._encoding = encoding
async def get(self, key: str) -> str | None:
async def get(self, key: str) -> Any | None:
"""Get method definition."""
if await self.has(key):
try:
@ -61,6 +61,6 @@ class JsonPipelineCache(PipelineCache):
"""Clear method definition."""
await self._storage.clear()
def child(self, name: str) -> "JsonPipelineCache":
def child(self, name: str) -> "Cache":
"""Child method definition."""
return JsonPipelineCache(self._storage.child(name), encoding=self._encoding)
return JsonCache(self._storage.child(name), encoding=self._encoding)

View File

@ -1,20 +1,20 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'InMemoryCache' model."""
"""MemoryCache implementation."""
from typing import Any
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag_cache.cache import Cache
class InMemoryCache(PipelineCache):
class MemoryCache(Cache):
"""In memory cache class definition."""
_cache: dict[str, Any]
_name: str
def __init__(self, name: str | None = None):
def __init__(self, name: str | None = None, **kwargs: Any) -> None:
"""Init method definition."""
self._cache = {}
self._name = name or ""
@ -69,9 +69,9 @@ class InMemoryCache(PipelineCache):
"""Clear the storage."""
self._cache.clear()
def child(self, name: str) -> PipelineCache:
def child(self, name: str) -> "Cache":
"""Create a sub cache with the given name."""
return InMemoryCache(name)
return MemoryCache(name)
def _create_cache_key(self, key: str) -> str:
"""Create a cache key for the given key."""

View File

@ -1,15 +1,18 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Module containing the NoopPipelineCache implementation."""
"""NoopCache implementation."""
from typing import Any
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag_cache.cache import Cache
class NoopPipelineCache(PipelineCache):
"""A no-op implementation of the pipeline cache, usually useful for testing."""
class NoopCache(Cache):
"""A no-op implementation of Cache, usually useful for testing."""
def __init__(self, **kwargs: Any) -> None:
"""Init method definition."""
async def get(self, key: str) -> Any:
"""Get the value for the given key.
@ -56,7 +59,7 @@ class NoopPipelineCache(PipelineCache):
async def clear(self) -> None:
"""Clear the cache."""
def child(self, name: str) -> PipelineCache:
def child(self, name: str) -> "Cache":
"""Create a child cache with the given name.
Args:

View File

@ -0,0 +1,43 @@
[project]
name = "graphrag-cache"
version = "2.7.0"
description = "GraphRAG cache package."
authors = [
{name = "Alonso Guevara Fernández", email = "alonsog@microsoft.com"},
{name = "Andrés Morales Esquivel", email = "andresmor@microsoft.com"},
{name = "Chris Trevino", email = "chtrevin@microsoft.com"},
{name = "David Tittsworth", email = "datittsw@microsoft.com"},
{name = "Dayenne de Souza", email = "ddesouza@microsoft.com"},
{name = "Derek Worthen", email = "deworthe@microsoft.com"},
{name = "Gaudy Blanco Meneses", email = "gaudyb@microsoft.com"},
{name = "Ha Trinh", email = "trinhha@microsoft.com"},
{name = "Jonathan Larson", email = "jolarso@microsoft.com"},
{name = "Josh Bradley", email = "joshbradley@microsoft.com"},
{name = "Kate Lytvynets", email = "kalytv@microsoft.com"},
{name = "Kenny Zhang", email = "zhangken@microsoft.com"},
{name = "Mónica Carvajal"},
{name = "Nathan Evans", email = "naevans@microsoft.com"},
{name = "Rodrigo Racanicci", email = "rracanicci@microsoft.com"},
{name = "Sarah Smith", email = "smithsarah@microsoft.com"},
]
license = "MIT"
readme = "README.md"
license-files = ["LICENSE"]
requires-python = ">=3.10,<3.13"
classifiers = [
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
dependencies = [
"graphrag-common==2.7.0",
"graphrag-storage==2.7.0",
]
[project.urls]
Source = "https://github.com/microsoft/graphrag"
[build-system]
requires = ["hatchling>=1.27.0,<2.0.0"]
build-backend = "hatchling.build"

View File

@ -1,4 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A package containing cache implementations."""

View File

@ -1,68 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Factory functions for creating a cache."""
from __future__ import annotations
from graphrag_common.factory import Factory
from graphrag.cache.json_pipeline_cache import JsonPipelineCache
from graphrag.cache.memory_pipeline_cache import InMemoryCache
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.config.enums import CacheType
class CacheFactory(Factory[PipelineCache]):
    """A factory class for cache implementations.

    Includes a method for users to register a custom cache implementation.

    Configuration arguments are passed to each cache implementation as kwargs
    for individual enforcement of required/optional arguments.
    """
# --- register built-in cache implementations ---
def create_file_cache(**kwargs) -> PipelineCache:
    """Build a JSON pipeline cache on top of file storage.

    All keyword arguments are handed to ``FileStorage`` unchanged.
    """
    # Imported inside the function rather than at module import time.
    from graphrag_storage.file_storage import FileStorage

    return JsonPipelineCache(FileStorage(**kwargs))
def create_blob_cache(**kwargs) -> PipelineCache:
    """Build a JSON pipeline cache on top of Azure blob storage.

    All keyword arguments are handed to ``AzureBlobStorage`` unchanged.
    """
    # Imported inside the function rather than at module import time.
    from graphrag_storage.azure_blob_storage import AzureBlobStorage

    return JsonPipelineCache(AzureBlobStorage(**kwargs))
def create_cosmosdb_cache(**kwargs) -> PipelineCache:
    """Build a JSON pipeline cache on top of Azure CosmosDB storage.

    All keyword arguments are handed to ``AzureCosmosStorage`` unchanged.
    """
    # Imported inside the function rather than at module import time.
    from graphrag_storage.azure_cosmos_storage import AzureCosmosStorage

    return JsonPipelineCache(AzureCosmosStorage(**kwargs))
def create_noop_cache(**_kwargs) -> PipelineCache:
    """Build a no-op cache; any keyword arguments are accepted and ignored."""
    noop_cache = NoopPipelineCache()
    return noop_cache
def create_memory_cache(**kwargs) -> PipelineCache:
    """Build an in-memory cache, forwarding all keyword arguments to ``InMemoryCache``."""
    memory_cache = InMemoryCache(**kwargs)
    return memory_cache
# --- register built-in cache implementations ---
# Module-level factory instance; each builtin creator function is registered
# under the string value of its CacheType member.
cache_factory = CacheFactory()
cache_factory.register(CacheType.none.value, create_noop_cache)
cache_factory.register(CacheType.memory.value, create_memory_cache)
cache_factory.register(CacheType.file.value, create_file_cache)
cache_factory.register(CacheType.blob.value, create_blob_cache)
cache_factory.register(CacheType.cosmosdb.value, create_cosmosdb_cache)

View File

@ -7,13 +7,13 @@ from dataclasses import dataclass, field
from pathlib import Path
from typing import ClassVar
from graphrag_cache import CacheType
from graphrag_storage import StorageType
from graphrag.config.embeddings import default_embeddings
from graphrag.config.enums import (
AsyncType,
AuthType,
CacheType,
ChunkStrategyType,
InputFileType,
ModelType,
@ -27,6 +27,7 @@ from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import
DEFAULT_INPUT_BASE_DIR = "input"
DEFAULT_OUTPUT_BASE_DIR = "output"
DEFAULT_CACHE_BASE_DIR = "cache"
DEFAULT_UPDATE_OUTPUT_BASE_DIR = "update_output"
DEFAULT_CHAT_MODEL_ID = "default_chat_model"
DEFAULT_CHAT_MODEL_TYPE = ModelType.Chat
@ -59,12 +60,9 @@ class BasicSearchDefaults:
class CacheDefaults:
"""Default values for cache."""
type: ClassVar[CacheType] = CacheType.file
base_dir: str = "cache"
connection_string: None = None
container_name: None = None
storage_account_blob_url: None = None
cosmosdb_account_url: None = None
type: CacheType = CacheType.Json
encoding: str | None = None
name: str | None = None
@dataclass
@ -240,6 +238,13 @@ class StorageDefaults:
azure_cosmosdb_account_url: None = None
@dataclass
class CacheStorageDefaults(StorageDefaults):
    """Default values for cache storage.

    Inherits the fields of ``StorageDefaults`` and sets the default
    ``base_dir`` to ``DEFAULT_CACHE_BASE_DIR``.
    """

    base_dir: str | None = DEFAULT_CACHE_BASE_DIR
@dataclass
class InputStorageDefaults(StorageDefaults):
"""Default values for input storage."""
@ -396,6 +401,7 @@ class GraphRagConfigDefaults:
default_factory=UpdateIndexOutputDefaults
)
cache: CacheDefaults = field(default_factory=CacheDefaults)
cache_storage: CacheStorageDefaults = field(default_factory=CacheStorageDefaults)
input: InputDefaults = field(default_factory=InputDefaults)
embed_text: EmbedTextDefaults = field(default_factory=EmbedTextDefaults)
chunks: ChunksDefaults = field(default_factory=ChunksDefaults)

View File

@ -66,9 +66,12 @@ output:
type: {graphrag_config_defaults.output.type} # or blob, cosmosdb
base_dir: "{graphrag_config_defaults.output.base_dir}"
cache_storage:
type: {graphrag_config_defaults.cache_storage.type} # [file, blob, cosmosdb]
base_dir: "{graphrag_config_defaults.cache_storage.base_dir}"
cache:
type: {graphrag_config_defaults.cache.type.value} # [file, blob, cosmosdb]
base_dir: "{graphrag_config_defaults.cache.base_dir}"
type: {graphrag_config_defaults.cache.type} # [json, memory, noop]
reporting:
type: {graphrag_config_defaults.reporting.type.value} # [file, blob]

View File

@ -1,38 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Parameterization settings for the default configuration."""
from pydantic import BaseModel, Field
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import CacheType
class CacheConfig(BaseModel):
    """Cache section of the configuration; every default comes from ``graphrag_config_defaults.cache``."""

    type: CacheType | str = Field(
        default=graphrag_config_defaults.cache.type,
        description="The cache type to use.",
    )

    base_dir: str = Field(
        default=graphrag_config_defaults.cache.base_dir,
        description="The base directory for the cache.",
    )

    connection_string: str | None = Field(
        default=graphrag_config_defaults.cache.connection_string,
        description="The cache connection string to use.",
    )

    container_name: str | None = Field(
        default=graphrag_config_defaults.cache.container_name,
        description="The cache container name to use.",
    )

    storage_account_blob_url: str | None = Field(
        default=graphrag_config_defaults.cache.storage_account_blob_url,
        description="The storage account blob url to use.",
    )

    cosmosdb_account_url: str | None = Field(
        default=graphrag_config_defaults.cache.cosmosdb_account_url,
        description="The cosmosdb account url to use.",
    )

View File

@ -3,9 +3,11 @@
"""Parameterization settings for the default configuration."""
from dataclasses import asdict
from pathlib import Path
from devtools import pformat
from graphrag_cache import CacheConfig
from graphrag_storage import StorageConfig, StorageType
from pydantic import BaseModel, Field, model_validator
@ -13,7 +15,6 @@ import graphrag.config.defaults as defs
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import VectorStoreType
from graphrag.config.models.basic_search_config import BasicSearchConfig
from graphrag.config.models.cache_config import CacheConfig
from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.config.models.cluster_graph_config import ClusterGraphConfig
from graphrag.config.models.community_reports_config import CommunityReportsConfig
@ -164,8 +165,17 @@ class GraphRagConfig(BaseModel):
Path(self.update_index_output.base_dir).resolve()
)
cache_storage: StorageConfig | None = Field(
description="The cache storage configuration.",
default=StorageConfig(
**asdict(graphrag_config_defaults.cache_storage),
),
)
"""The cache storage configuration."""
cache: CacheConfig = Field(
description="The cache configuration.", default=CacheConfig()
description="The cache configuration.",
default=CacheConfig(**asdict(graphrag_config_defaults.cache)),
)
"""The cache configuration."""

View File

@ -7,8 +7,8 @@ from itertools import combinations
import numpy as np
import pandas as pd
from graphrag_cache import Cache
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.config.enums import AsyncType
from graphrag.index.operations.build_noun_graph.np_extractors.base import (
BaseNounPhraseExtractor,
@ -24,7 +24,7 @@ async def build_noun_graph(
normalize_edge_weights: bool,
num_threads: int,
async_mode: AsyncType,
cache: PipelineCache,
cache: Cache,
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""Build a noun graph from text units."""
text_units = text_unit_df.loc[:, ["id", "text"]]
@ -44,7 +44,7 @@ async def _extract_nodes(
text_analyzer: BaseNounPhraseExtractor,
num_threads: int,
async_mode: AsyncType,
cache: PipelineCache,
cache: Cache,
) -> pd.DataFrame:
"""
Extract initial nodes and edges from text units.

View File

@ -7,7 +7,8 @@ from collections.abc import Awaitable, Callable, Iterable
from dataclasses import dataclass
from typing import Any
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag_cache import Cache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@ -42,7 +43,7 @@ CovariateExtractStrategy = Callable[
list[str],
dict[str, str],
WorkflowCallbacks,
PipelineCache,
Cache,
dict[str, Any],
],
Awaitable[CovariateExtractionResult],

View File

@ -12,6 +12,7 @@ from dataclasses import asdict
from typing import Any
import pandas as pd
from graphrag_cache import create_cache
from graphrag_storage import Storage, create_storage
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@ -20,7 +21,6 @@ from graphrag.index.run.utils import create_run_context
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.pipeline import Pipeline
from graphrag.index.typing.pipeline_run_result import PipelineRunResult
from graphrag.utils.api import create_cache_from_config
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
logger = logging.getLogger(__name__)
@ -37,7 +37,10 @@ async def run_pipeline(
"""Run all workflows using a simplified pipeline."""
input_storage = create_storage(config.input.storage)
output_storage = create_storage(config.output)
cache = create_cache_from_config(config.cache)
cache_storage: Storage | None = None
if config.cache_storage:
cache_storage = create_storage(config.cache_storage)
cache = create_cache(config.cache, storage=cache_storage)
# load existing state in case any workflows are stateful
state_json = await output_storage.get("context.json")

View File

@ -3,11 +3,11 @@
"""Utility functions for the GraphRAG run module."""
from graphrag_cache import Cache
from graphrag_cache.memory_cache import MemoryCache
from graphrag_storage import Storage, create_storage
from graphrag_storage.memory_storage import MemoryStorage
from graphrag.cache.memory_pipeline_cache import InMemoryCache
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.callbacks.workflow_callbacks_manager import WorkflowCallbacksManager
@ -21,7 +21,7 @@ def create_run_context(
input_storage: Storage | None = None,
output_storage: Storage | None = None,
previous_storage: Storage | None = None,
cache: PipelineCache | None = None,
cache: Cache | None = None,
callbacks: WorkflowCallbacks | None = None,
stats: PipelineRunStats | None = None,
state: PipelineState | None = None,
@ -31,7 +31,7 @@ def create_run_context(
input_storage=input_storage or MemoryStorage(),
output_storage=output_storage or MemoryStorage(),
previous_storage=previous_storage or MemoryStorage(),
cache=cache or InMemoryCache(),
cache=cache or MemoryCache(),
callbacks=callbacks or NoopWorkflowCallbacks(),
stats=stats or PipelineRunStats(),
state=state or {},

View File

@ -6,7 +6,7 @@
from dataclasses import dataclass
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag_cache import Cache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.typing.state import PipelineState
from graphrag.index.typing.stats import PipelineRunStats
@ -24,7 +24,7 @@ class PipelineRunContext:
"Long-term storage for pipeline verbs to use. Items written here will be written to the storage provider."
previous_storage: Storage
"Storage for previous pipeline run when running in update mode."
cache: PipelineCache
cache: Cache
"Cache instance for reading previous LLM responses."
callbacks: WorkflowCallbacks
"Callbacks to be called during the pipeline run."

View File

@ -6,8 +6,8 @@
import logging
import pandas as pd
from graphrag_cache import Cache
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.config.enums import AsyncType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.operations.build_noun_graph.build_noun_graph import build_noun_graph
@ -59,7 +59,7 @@ async def run_workflow(
async def extract_graph_nlp(
text_units: pd.DataFrame,
cache: PipelineCache,
cache: Cache,
text_analyzer: BaseNounPhraseExtractor,
normalize_edge_weights: bool,
num_threads: int,

View File

@ -6,9 +6,9 @@
import logging
import pandas as pd
from graphrag_cache import Cache
from graphrag_storage import Storage
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.run.utils import get_update_storages
@ -59,7 +59,7 @@ async def _update_entities_and_relationships(
delta_storage: Storage,
output_storage: Storage,
config: GraphRagConfig,
cache: PipelineCache,
cache: Cache,
callbacks: WorkflowCallbacks,
) -> tuple[pd.DataFrame, pd.DataFrame, dict]:
"""Update Final Entities and Relationships output."""

View File

@ -38,7 +38,8 @@ from graphrag.language_model.providers.litellm.types import (
)
if TYPE_CHECKING:
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag_cache import Cache
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.language_model.response.base import ModelResponse as MR # noqa: N817
@ -113,7 +114,7 @@ def _create_base_completions(
def _create_completions(
model_config: "LanguageModelConfig",
cache: "PipelineCache | None",
cache: "Cache | None",
cache_key_prefix: str,
) -> tuple[FixedModelCompletion, AFixedModelCompletion]:
"""Wrap the base litellm completion function with the model configuration and additional features.
@ -203,7 +204,7 @@ class LitellmChatModel:
self,
name: str,
config: "LanguageModelConfig",
cache: "PipelineCache | None" = None,
cache: "Cache | None" = None,
**kwargs: Any,
):
self.name = name

View File

@ -33,7 +33,8 @@ from graphrag.language_model.providers.litellm.types import (
)
if TYPE_CHECKING:
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag_cache import Cache
from graphrag.config.models.language_model_config import LanguageModelConfig
litellm.suppress_debug_info = True
@ -99,7 +100,7 @@ def _create_base_embeddings(
def _create_embeddings(
model_config: "LanguageModelConfig",
cache: "PipelineCache | None",
cache: "Cache | None",
cache_key_prefix: str,
) -> tuple[FixedModelEmbedding, AFixedModelEmbedding]:
"""Wrap the base litellm embedding function with the model configuration and additional features.
@ -167,7 +168,7 @@ class LitellmEmbeddingModel:
self,
name: str,
config: "LanguageModelConfig",
cache: "PipelineCache | None" = None,
cache: "Cache | None" = None,
**kwargs: Any,
):
self.name = name

View File

@ -15,7 +15,8 @@ from graphrag.language_model.providers.litellm.types import (
)
if TYPE_CHECKING:
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag_cache import Cache
from graphrag.config.models.language_model_config import LanguageModelConfig
@ -24,7 +25,7 @@ def with_cache(
sync_fn: LitellmRequestFunc,
async_fn: AsyncLitellmRequestFunc,
model_config: "LanguageModelConfig",
cache: "PipelineCache",
cache: "Cache",
request_type: Literal["chat", "embedding"],
cache_key_prefix: str,
) -> tuple[LitellmRequestFunc, AsyncLitellmRequestFunc]:

View File

@ -8,9 +8,9 @@ from typing import Any
import numpy as np
import pandas as pd
from graphrag_cache.noop_cache import NoopCache
from graphrag_storage import create_storage
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.input.factory import InputReaderFactory
@ -60,7 +60,7 @@ async def load_docs_in_chunks(
model_type=embeddings_llm_settings.type,
config=embeddings_llm_settings,
callbacks=NoopWorkflowCallbacks(),
cache=NoopPipelineCache(),
cache=NoopCache(),
)
tokenizer = get_tokenizer(embeddings_llm_settings)
input_storage = create_storage(config.input.storage)

View File

@ -6,10 +6,7 @@
from pathlib import Path
from typing import Any
from graphrag.cache.factory import CacheFactory
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.config.embeddings import create_index_name
from graphrag.config.models.cache_config import CacheConfig
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.vector_stores.base import (
BaseVectorStore,
@ -98,15 +95,6 @@ def load_search_prompt(prompt_config: str | None) -> str | None:
return None
def create_cache_from_config(cache: CacheConfig) -> PipelineCache:
"""Create a cache object from the config."""
cache_config = cache.model_dump()
return CacheFactory().create(
strategy=cache_config["type"],
init_args=cache_config,
)
def truncate(text: str, max_length: int) -> str:
"""Truncate a string to a maximum length."""
if len(text) <= max_length:

View File

@ -40,6 +40,7 @@ dependencies = [
"azure-storage-blob>=12.24.0",
"devtools>=0.12.2",
"environs>=11.0.0",
"graphrag-cache==2.7.0",
"graphrag-common==2.7.0",
"graphrag-storage==2.7.0",
"graspologic-native>=1.2.5",

View File

@ -55,6 +55,7 @@ members = ["packages/*"]
[tool.uv.sources]
graphrag-common = { workspace = true }
graphrag-storage = { workspace = true }
graphrag-cache = { workspace = true }
# Keep poethepoet for task management to minimize changes
[tool.poe.tasks]
@ -71,6 +72,7 @@ _semversioner_changelog = "semversioner changelog > CHANGELOG.md"
_semversioner_update_graphrag_toml_version = "update-toml update --file packages/graphrag/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
_semversioner_update_graphrag_common_toml_version = "update-toml update --file packages/graphrag-common/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
_semversioner_update_graphrag_storage_toml_version = "update-toml update --file packages/graphrag-storage/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
_semversioner_update_graphrag_cache_toml_version = "update-toml update --file packages/graphrag-cache/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
_semversioner_update_workspace_dependency_versions = "python -m scripts.update_workspace_dependency_versions"
semversioner_add = "semversioner add-change"
coverage_report = 'coverage report --omit "**/tests/**" --show-missing'
@ -106,6 +108,7 @@ sequence = [
'_semversioner_update_graphrag_toml_version',
'_semversioner_update_graphrag_common_toml_version',
'_semversioner_update_graphrag_storage_toml_version',
'_semversioner_update_graphrag_cache_toml_version',
'_semversioner_update_workspace_dependency_versions',
'_sync',
]

View File

@ -8,12 +8,12 @@ These tests will test the CacheFactory() class and the creation of each cache ty
import sys
import pytest
from graphrag.cache.factory import CacheFactory
from graphrag.cache.json_pipeline_cache import JsonPipelineCache
from graphrag.cache.memory_pipeline_cache import InMemoryCache
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.config.enums import CacheType
from graphrag_cache import Cache, CacheConfig, CacheType, create_cache, register_cache
from graphrag_cache.cache_factory import cache_factory
from graphrag_cache.json_cache import JsonCache
from graphrag_cache.memory_cache import MemoryCache
from graphrag_cache.noop_cache import NoopCache
from graphrag_storage import StorageConfig, StorageType, create_storage
# cspell:disable-next-line well-known-key
WELL_KNOWN_BLOB_STORAGE_KEY = "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;"
@ -22,31 +22,55 @@ WELL_KNOWN_COSMOS_CONNECTION_STRING = "AccountEndpoint=https://127.0.0.1:8081/;A
def test_create_noop_cache():
cache = CacheFactory().create(strategy=CacheType.none.value)
assert isinstance(cache, NoopPipelineCache)
cache = create_cache(
CacheConfig(
type=CacheType.Noop,
)
)
assert isinstance(cache, NoopCache)
def test_create_memory_cache():
cache = CacheFactory().create(strategy=CacheType.memory.value)
assert isinstance(cache, InMemoryCache)
cache = create_cache(
CacheConfig(
type=CacheType.Memory,
)
)
assert isinstance(cache, MemoryCache)
def test_create_file_cache():
cache = CacheFactory().create(
strategy=CacheType.file.value,
init_args={"base_dir": "testcache"},
storage = create_storage(
StorageConfig(
type=StorageType.Memory,
)
)
assert isinstance(cache, JsonPipelineCache)
cache = create_cache(
CacheConfig(
type=CacheType.Json,
),
storage=storage,
)
assert isinstance(cache, JsonCache)
def test_create_blob_cache():
init_args = {
"connection_string": WELL_KNOWN_BLOB_STORAGE_KEY,
"container_name": "testcontainer",
"base_dir": "testcache",
}
cache = CacheFactory().create(strategy=CacheType.blob.value, init_args=init_args)
assert isinstance(cache, JsonPipelineCache)
storage = create_storage(
StorageConfig(
type=StorageType.AzureBlob,
connection_string=WELL_KNOWN_BLOB_STORAGE_KEY,
container_name="testcontainer",
base_dir="testcache",
)
)
cache = create_cache(
CacheConfig(
type=CacheType.Json,
),
storage=storage,
)
assert isinstance(cache, JsonCache)
@pytest.mark.skipif(
@ -54,15 +78,21 @@ def test_create_blob_cache():
reason="cosmosdb emulator is only available on windows runners at this time",
)
def test_create_cosmosdb_cache():
init_args = {
"connection_string": WELL_KNOWN_COSMOS_CONNECTION_STRING,
"database_name": "testdatabase",
"container_name": "testcontainer",
}
cache = CacheFactory().create(
strategy=CacheType.cosmosdb.value, init_args=init_args
storage = create_storage(
StorageConfig(
type=StorageType.AzureCosmos,
connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
database_name="testdatabase",
container_name="testcontainer",
)
)
assert isinstance(cache, JsonPipelineCache)
cache = create_cache(
CacheConfig(
type=CacheType.Json,
),
storage=storage,
)
assert isinstance(cache, JsonCache)
def test_register_and_create_custom_cache():
@ -70,17 +100,14 @@ def test_register_and_create_custom_cache():
from unittest.mock import MagicMock
# Create a mock that satisfies the Cache interface
custom_cache_class = MagicMock(spec=PipelineCache)
custom_cache_class = MagicMock(spec=Cache)
# Make the mock return a mock instance when instantiated
instance = MagicMock()
instance.initialized = True
custom_cache_class.return_value = instance
CacheFactory().register(
strategy="custom",
initializer=lambda **kwargs: custom_cache_class(**kwargs),
)
cache = CacheFactory().create(strategy="custom")
register_cache("custom", lambda **kwargs: custom_cache_class(**kwargs))
cache = create_cache(CacheConfig(type="custom"))
assert custom_cache_class.called
assert cache is instance
@ -88,45 +115,21 @@ def test_register_and_create_custom_cache():
assert cache.initialized is True # type: ignore # Attribute only exists on our mock
# Check if it's in the list of registered cache types
assert "custom" in CacheFactory()
assert "custom" in cache_factory
def test_create_unknown_cache():
with pytest.raises(ValueError, match="Strategy 'unknown' is not registered\\."):
CacheFactory().create(strategy="unknown")
def test_is_supported_type():
# Test built-in types
assert CacheType.none.value in CacheFactory()
assert CacheType.memory.value in CacheFactory()
assert CacheType.file.value in CacheFactory()
assert CacheType.blob.value in CacheFactory()
assert CacheType.cosmosdb.value in CacheFactory()
# Test unknown type
assert "unknown" not in CacheFactory()
def test_enum_and_string_compatibility():
"""Test that both enum and string types work for cache creation."""
# Test with enum
cache_enum = CacheFactory().create(strategy=CacheType.memory)
assert isinstance(cache_enum, InMemoryCache)
# Test with string
cache_str = CacheFactory().create(strategy="memory")
assert isinstance(cache_str, InMemoryCache)
# Both should create the same type
assert type(cache_enum) is type(cache_str)
with pytest.raises(
ValueError,
match="CacheConfig\\.type 'unknown' is not registered in the CacheFactory\\.",
):
create_cache(CacheConfig(type="unknown"))
def test_register_class_directly_works():
"""Test that registering a class directly works (CacheFactory() allows this)."""
from graphrag.cache.pipeline_cache import PipelineCache
class CustomCache(PipelineCache):
class CustomCache(Cache):
def __init__(self, **kwargs):
pass
@ -149,11 +152,11 @@ def test_register_class_directly_works():
return self
# register_cache allows registering classes directly (no TypeError)
CacheFactory().register("custom_class", CustomCache)
register_cache("custom_class", CustomCache)
# Verify it was registered
assert "custom_class" in CacheFactory()
assert "custom_class" in cache_factory
# Test creating an instance
cache = CacheFactory().create(strategy="custom_class")
cache = create_cache(CacheConfig(type="custom_class"))
assert isinstance(cache, CustomCache)

View File

@ -5,7 +5,6 @@ from dataclasses import asdict
import graphrag.config.defaults as defs
from graphrag.config.models.basic_search_config import BasicSearchConfig
from graphrag.config.models.cache_config import CacheConfig
from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.config.models.cluster_graph_config import ClusterGraphConfig
from graphrag.config.models.community_reports_config import CommunityReportsConfig
@ -29,6 +28,7 @@ from graphrag.config.models.summarize_descriptions_config import (
SummarizeDescriptionsConfig,
)
from graphrag.config.models.vector_store_config import VectorStoreConfig
from graphrag_cache import CacheConfig
from graphrag_storage import StorageConfig
from pydantic import BaseModel
@ -138,11 +138,8 @@ def assert_storage_config(actual: StorageConfig, expected: StorageConfig) -> Non
def assert_cache_configs(actual: CacheConfig, expected: CacheConfig) -> None:
assert actual.type == expected.type
assert actual.base_dir == expected.base_dir
assert actual.connection_string == expected.connection_string
assert actual.container_name == expected.container_name
assert actual.storage_account_blob_url == expected.storage_account_blob_url
assert actual.cosmosdb_account_url == expected.cosmosdb_account_url
assert actual.encoding == expected.encoding
assert actual.name == expected.name
def assert_input_configs(actual: InputConfig, expected: InputConfig) -> None:

View File

@ -4,17 +4,26 @@ import asyncio
import os
import unittest
from graphrag.cache.json_pipeline_cache import JsonPipelineCache
from graphrag_storage.file_storage import (
FileStorage,
)
from graphrag_cache import CacheConfig, CacheType
from graphrag_cache import create_cache as cc
from graphrag_storage import StorageConfig, StorageType, create_storage
TEMP_DIR = "./.tmp"
def create_cache():
storage = FileStorage(base_dir=os.path.join(os.getcwd(), ".tmp"))
return JsonPipelineCache(storage)
storage = create_storage(
StorageConfig(
type=StorageType.File,
base_dir=os.path.join(os.getcwd(), ".tmp"),
)
)
return cc(
CacheConfig(
type=CacheType.Json,
),
storage=storage,
)
class TestFilePipelineCache(unittest.IsolatedAsyncioTestCase):

18
uv.lock generated
View File

@ -10,6 +10,7 @@ resolution-markers = [
[manifest]
members = [
"graphrag",
"graphrag-cache",
"graphrag-common",
"graphrag-monorepo",
"graphrag-storage",
@ -1042,6 +1043,7 @@ dependencies = [
{ name = "azure-storage-blob" },
{ name = "devtools" },
{ name = "environs" },
{ name = "graphrag-cache" },
{ name = "graphrag-common" },
{ name = "graphrag-storage" },
{ name = "graspologic-native" },
@ -1074,6 +1076,7 @@ requires-dist = [
{ name = "azure-storage-blob", specifier = ">=12.24.0" },
{ name = "devtools", specifier = ">=0.12.2" },
{ name = "environs", specifier = ">=11.0.0" },
{ name = "graphrag-cache", editable = "packages/graphrag-cache" },
{ name = "graphrag-common", editable = "packages/graphrag-common" },
{ name = "graphrag-storage", editable = "packages/graphrag-storage" },
{ name = "graspologic-native", specifier = ">=1.2.5" },
@ -1095,6 +1098,21 @@ requires-dist = [
{ name = "typing-extensions", specifier = ">=4.12.2" },
]
[[package]]
name = "graphrag-cache"
version = "2.7.0"
source = { editable = "packages/graphrag-cache" }
dependencies = [
{ name = "graphrag-common" },
{ name = "graphrag-storage" },
]
[package.metadata]
requires-dist = [
{ name = "graphrag-common", editable = "packages/graphrag-common" },
{ name = "graphrag-storage", editable = "packages/graphrag-storage" },
]
[[package]]
name = "graphrag-common"
version = "2.7.0"