mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 00:57:23 +08:00
Add GraphRAG Cache package.
This commit is contained in:
parent
4404668aa8
commit
71f9c09f3f
1
packages/graphrag-cache/.python-version
Normal file
1
packages/graphrag-cache/.python-version
Normal file
@ -0,0 +1 @@
|
||||
3.12
|
||||
97
packages/graphrag-cache/README.md
Normal file
97
packages/graphrag-cache/README.md
Normal file
@ -0,0 +1,97 @@
|
||||
# GraphRAG Cache
|
||||
|
||||
## Basic
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from graphrag_storage import StorageConfig, create_storage, StorageType
|
||||
from graphrag_cache import CacheConfig, create_cache, CacheType
|
||||
|
||||
async def run():
|
||||
# Json cache requires a storage implementation.
|
||||
storage = create_storage(
|
||||
StorageConfig(
|
||||
type=StorageType.File
|
||||
base_dir="output"
|
||||
)
|
||||
)
|
||||
|
||||
cache = create_cache(
|
||||
CacheConfig(
|
||||
type=CacheType.Json
|
||||
),
|
||||
storage=storage
|
||||
)
|
||||
|
||||
await cache.set("my_key", {"some": "object to cache"})
|
||||
print(await cache.get("my_key"))
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(run())
|
||||
```
|
||||
|
||||
## Custom Cache
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from graphrag_storage import Storage
|
||||
from graphrag_cache import Cache, CacheConfig, create_cache, register_cache
|
||||
|
||||
class MyCache(Cache):
|
||||
def __init__(self, storage: Storage, some_setting: str, optional_setting: str = "default setting", **kwargs: Any):
|
||||
# Validate settings and initialize
|
||||
...
|
||||
|
||||
#Implement rest of interface
|
||||
...
|
||||
|
||||
register_cache("MyCache", MyCache)
|
||||
|
||||
async def run():
|
||||
cache = create_cache(
|
||||
CacheConfig(
|
||||
type="MyCache"
|
||||
some_setting="My Setting"
|
||||
)
|
||||
# if your cache relies on a storage implementation you can pass that here
|
||||
# storage=some_storage
|
||||
)
|
||||
# Or use the factory directly to instantiate with a dict instead of using
|
||||
# CacheConfig + create_factory
|
||||
# from graphrag_cache.cache_factory import cache_factory
|
||||
# cache = cache_factory.create(strategy="MyCache", init_args={"storage": storage_implementation, "some_setting": "My Setting"})
|
||||
|
||||
await cache.set("my_key", {"some": "object to cache"})
|
||||
print(await cache.get("my_key"))
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(run())
|
||||
```
|
||||
|
||||
### Details
|
||||
|
||||
By default, the `create_cache` comes with the following cache providers registered that correspond to the entries in the `CacheType` enum.
|
||||
|
||||
- `JsonCache`
|
||||
- `MemoryCache`
|
||||
- `NoopCache`
|
||||
|
||||
The preregistration happens dynamically, e.g., `JsonCache` is only imported and registered if you request a `JsonCache` with `create_cache(CacheType.Json, ...)`. There is no need to manually import and register builtin cache providers when using `create_cache`.
|
||||
|
||||
If you want a clean factory with no preregistered cache providers then directly import `cache_factory` and bypass using `create_cache`. The downside is that `cache_factory.create` uses a dict for init args instead of the strongly typed `CacheConfig` used with `create_cache`.
|
||||
|
||||
```python
|
||||
from graphrag_cache.cache_factory import cache_factory
|
||||
from graphrag_cache.json_cache import JsonCache
|
||||
|
||||
# cache_factory has no preregistered providers so you must register any
|
||||
# providers you plan on using.
|
||||
# May also register a custom implementation, see above for example.
|
||||
cache_factory.register("my_cache_impl", JsonCache)
|
||||
|
||||
cache = cache_factory.create(strategy="my_cache_impl", init_args={"some_setting": "..."})
|
||||
|
||||
...
|
||||
|
||||
```
|
||||
17
packages/graphrag-cache/graphrag_cache/__init__.py
Normal file
17
packages/graphrag-cache/graphrag_cache/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""The GraphRAG Cache package."""
|
||||
|
||||
from graphrag_cache.cache import Cache
|
||||
from graphrag_cache.cache_config import CacheConfig
|
||||
from graphrag_cache.cache_factory import create_cache, register_cache
|
||||
from graphrag_cache.cache_type import CacheType
|
||||
|
||||
__all__ = [
|
||||
"Cache",
|
||||
"CacheConfig",
|
||||
"CacheType",
|
||||
"create_cache",
|
||||
"register_cache",
|
||||
]
|
||||
@ -1,17 +1,21 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing 'PipelineCache' model."""
|
||||
"""Abstract base class for cache."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class PipelineCache(metaclass=ABCMeta):
|
||||
class Cache(ABC):
|
||||
"""Provide a cache interface for the pipeline."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Create a cache instance."""
|
||||
|
||||
@abstractmethod
|
||||
async def get(self, key: str) -> Any:
|
||||
"""Get the value for the given key.
|
||||
@ -59,7 +63,7 @@ class PipelineCache(metaclass=ABCMeta):
|
||||
"""Clear the cache."""
|
||||
|
||||
@abstractmethod
|
||||
def child(self, name: str) -> PipelineCache:
|
||||
def child(self, name: str) -> Cache:
|
||||
"""Create a child cache with the given name.
|
||||
|
||||
Args:
|
||||
30
packages/graphrag-cache/graphrag_cache/cache_config.py
Normal file
30
packages/graphrag-cache/graphrag_cache/cache_config.py
Normal file
@ -0,0 +1,30 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Cache configuration model."""
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from graphrag_cache.cache_type import CacheType
|
||||
|
||||
|
||||
class CacheConfig(BaseModel):
|
||||
"""The configuration section for cache."""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
"""Allow extra fields to support custom cache implementations."""
|
||||
|
||||
type: str = Field(
|
||||
description="The cache type to use. Builtin types include 'Json', 'Memory', and 'Noop'.",
|
||||
default=CacheType.Json,
|
||||
)
|
||||
|
||||
encoding: str | None = Field(
|
||||
description="The encoding to use for file-based caching.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
name: str | None = Field(
|
||||
description="The name to use for the cache instance.",
|
||||
default=None,
|
||||
)
|
||||
82
packages/graphrag-cache/graphrag_cache/cache_factory.py
Normal file
82
packages/graphrag-cache/graphrag_cache/cache_factory.py
Normal file
@ -0,0 +1,82 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
|
||||
"""Cache factory implementation."""
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from graphrag_common.factory import Factory, ServiceScope
|
||||
from graphrag_storage import Storage
|
||||
|
||||
from graphrag_cache.cache import Cache
|
||||
from graphrag_cache.cache_config import CacheConfig
|
||||
from graphrag_cache.cache_type import CacheType
|
||||
|
||||
|
||||
class CacheFactory(Factory[Cache]):
|
||||
"""A factory class for cache implementations."""
|
||||
|
||||
|
||||
cache_factory = CacheFactory()
|
||||
|
||||
|
||||
def register_cache(
|
||||
cache_type: str,
|
||||
cache_initializer: Callable[..., Cache],
|
||||
scope: ServiceScope = "transient",
|
||||
) -> None:
|
||||
"""Register a custom storage implementation.
|
||||
|
||||
Args
|
||||
----
|
||||
- storage_type: str
|
||||
The storage id to register.
|
||||
- storage_initializer: Callable[..., Storage]
|
||||
The storage initializer to register.
|
||||
"""
|
||||
cache_factory.register(cache_type, cache_initializer, scope)
|
||||
|
||||
|
||||
def create_cache(config: CacheConfig, storage: Storage | None = None) -> Cache:
|
||||
"""Create a cache implementation based on the given configuration.
|
||||
|
||||
Args
|
||||
----
|
||||
- config: CacheConfig
|
||||
The cache configuration to use.
|
||||
- storage: Storage | None
|
||||
The storage implementation to use for file-based caches such as 'Json'.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Cache
|
||||
The created cache implementation.
|
||||
"""
|
||||
config_model = config.model_dump()
|
||||
cache_strategy = config.type
|
||||
|
||||
if cache_strategy not in cache_factory:
|
||||
match cache_strategy:
|
||||
case "json":
|
||||
from graphrag_cache.json_cache import JsonCache
|
||||
|
||||
register_cache(CacheType.Json, JsonCache)
|
||||
|
||||
case "memory":
|
||||
from graphrag_cache.memory_cache import MemoryCache
|
||||
|
||||
register_cache(CacheType.Memory, MemoryCache)
|
||||
|
||||
case "noop":
|
||||
from graphrag_cache.noop_cache import NoopCache
|
||||
|
||||
register_cache(CacheType.Noop, NoopCache)
|
||||
|
||||
case _:
|
||||
msg = f"CacheConfig.type '{cache_strategy}' is not registered in the CacheFactory. Registered types: {', '.join(cache_factory.keys())}."
|
||||
raise ValueError(msg)
|
||||
|
||||
return cache_factory.create(
|
||||
strategy=cache_strategy, init_args={"storage": storage, **config_model}
|
||||
)
|
||||
15
packages/graphrag-cache/graphrag_cache/cache_type.py
Normal file
15
packages/graphrag-cache/graphrag_cache/cache_type.py
Normal file
@ -0,0 +1,15 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
|
||||
"""Builtin cache implementation types."""
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class CacheType(StrEnum):
|
||||
"""Enum for cache types."""
|
||||
|
||||
Json = "json"
|
||||
Memory = "memory"
|
||||
Noop = "noop"
|
||||
@ -8,21 +8,21 @@ from typing import Any
|
||||
|
||||
from graphrag_storage import Storage
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag_cache.cache import Cache
|
||||
|
||||
|
||||
class JsonPipelineCache(PipelineCache):
|
||||
class JsonCache(Cache):
|
||||
"""File pipeline cache class definition."""
|
||||
|
||||
_storage: Storage
|
||||
_encoding: str
|
||||
|
||||
def __init__(self, storage: Storage, encoding="utf-8"):
|
||||
def __init__(self, storage: Storage, encoding="utf-8", **kwargs: Any) -> None:
|
||||
"""Init method definition."""
|
||||
self._storage = storage
|
||||
self._encoding = encoding
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
async def get(self, key: str) -> Any | None:
|
||||
"""Get method definition."""
|
||||
if await self.has(key):
|
||||
try:
|
||||
@ -61,6 +61,6 @@ class JsonPipelineCache(PipelineCache):
|
||||
"""Clear method definition."""
|
||||
await self._storage.clear()
|
||||
|
||||
def child(self, name: str) -> "JsonPipelineCache":
|
||||
def child(self, name: str) -> "Cache":
|
||||
"""Child method definition."""
|
||||
return JsonPipelineCache(self._storage.child(name), encoding=self._encoding)
|
||||
return JsonCache(self._storage.child(name), encoding=self._encoding)
|
||||
@ -1,20 +1,20 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing 'InMemoryCache' model."""
|
||||
"""MemoryCache implementation."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag_cache.cache import Cache
|
||||
|
||||
|
||||
class InMemoryCache(PipelineCache):
|
||||
class MemoryCache(Cache):
|
||||
"""In memory cache class definition."""
|
||||
|
||||
_cache: dict[str, Any]
|
||||
_name: str
|
||||
|
||||
def __init__(self, name: str | None = None):
|
||||
def __init__(self, name: str | None = None, **kwargs: Any) -> None:
|
||||
"""Init method definition."""
|
||||
self._cache = {}
|
||||
self._name = name or ""
|
||||
@ -69,9 +69,9 @@ class InMemoryCache(PipelineCache):
|
||||
"""Clear the storage."""
|
||||
self._cache.clear()
|
||||
|
||||
def child(self, name: str) -> PipelineCache:
|
||||
def child(self, name: str) -> "Cache":
|
||||
"""Create a sub cache with the given name."""
|
||||
return InMemoryCache(name)
|
||||
return MemoryCache(name)
|
||||
|
||||
def _create_cache_key(self, key: str) -> str:
|
||||
"""Create a cache key for the given key."""
|
||||
@ -1,15 +1,18 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Module containing the NoopPipelineCache implementation."""
|
||||
"""NoopCache implementation."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag_cache.cache import Cache
|
||||
|
||||
|
||||
class NoopPipelineCache(PipelineCache):
|
||||
"""A no-op implementation of the pipeline cache, usually useful for testing."""
|
||||
class NoopCache(Cache):
|
||||
"""A no-op implementation of Cache, usually useful for testing."""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Init method definition."""
|
||||
|
||||
async def get(self, key: str) -> Any:
|
||||
"""Get the value for the given key.
|
||||
@ -56,7 +59,7 @@ class NoopPipelineCache(PipelineCache):
|
||||
async def clear(self) -> None:
|
||||
"""Clear the cache."""
|
||||
|
||||
def child(self, name: str) -> PipelineCache:
|
||||
def child(self, name: str) -> "Cache":
|
||||
"""Create a child cache with the given name.
|
||||
|
||||
Args:
|
||||
0
packages/graphrag-cache/graphrag_cache/py.typed
Normal file
0
packages/graphrag-cache/graphrag_cache/py.typed
Normal file
43
packages/graphrag-cache/pyproject.toml
Normal file
43
packages/graphrag-cache/pyproject.toml
Normal file
@ -0,0 +1,43 @@
|
||||
[project]
|
||||
name = "graphrag-cache"
|
||||
version = "2.7.0"
|
||||
description = "GraphRAG cache package."
|
||||
authors = [
|
||||
{name = "Alonso Guevara Fernández", email = "alonsog@microsoft.com"},
|
||||
{name = "Andrés Morales Esquivel", email = "andresmor@microsoft.com"},
|
||||
{name = "Chris Trevino", email = "chtrevin@microsoft.com"},
|
||||
{name = "David Tittsworth", email = "datittsw@microsoft.com"},
|
||||
{name = "Dayenne de Souza", email = "ddesouza@microsoft.com"},
|
||||
{name = "Derek Worthen", email = "deworthe@microsoft.com"},
|
||||
{name = "Gaudy Blanco Meneses", email = "gaudyb@microsoft.com"},
|
||||
{name = "Ha Trinh", email = "trinhha@microsoft.com"},
|
||||
{name = "Jonathan Larson", email = "jolarso@microsoft.com"},
|
||||
{name = "Josh Bradley", email = "joshbradley@microsoft.com"},
|
||||
{name = "Kate Lytvynets", email = "kalytv@microsoft.com"},
|
||||
{name = "Kenny Zhang", email = "zhangken@microsoft.com"},
|
||||
{name = "Mónica Carvajal"},
|
||||
{name = "Nathan Evans", email = "naevans@microsoft.com"},
|
||||
{name = "Rodrigo Racanicci", email = "rracanicci@microsoft.com"},
|
||||
{name = "Sarah Smith", email = "smithsarah@microsoft.com"},
|
||||
]
|
||||
license = "MIT"
|
||||
readme = "README.md"
|
||||
license-files = ["LICENSE"]
|
||||
requires-python = ">=3.10,<3.13"
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
]
|
||||
dependencies = [
|
||||
"graphrag-common==2.7.0",
|
||||
"graphrag-storage==2.7.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Source = "https://github.com/microsoft/graphrag"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling>=1.27.0,<2.0.0"]
|
||||
build-backend = "hatchling.build"
|
||||
4
packages/graphrag/graphrag/cache/__init__.py
vendored
4
packages/graphrag/graphrag/cache/__init__.py
vendored
@ -1,4 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A package containing cache implementations."""
|
||||
68
packages/graphrag/graphrag/cache/factory.py
vendored
68
packages/graphrag/graphrag/cache/factory.py
vendored
@ -1,68 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Factory functions for creating a cache."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from graphrag_common.factory import Factory
|
||||
|
||||
from graphrag.cache.json_pipeline_cache import JsonPipelineCache
|
||||
from graphrag.cache.memory_pipeline_cache import InMemoryCache
|
||||
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.config.enums import CacheType
|
||||
|
||||
|
||||
class CacheFactory(Factory[PipelineCache]):
|
||||
"""A factory class for cache implementations.
|
||||
|
||||
Includes a method for users to register a custom cache implementation.
|
||||
|
||||
Configuration arguments are passed to each cache implementation as kwargs
|
||||
for individual enforcement of required/optional arguments.
|
||||
"""
|
||||
|
||||
|
||||
# --- register built-in cache implementations ---
|
||||
def create_file_cache(**kwargs) -> PipelineCache:
|
||||
"""Create a file-based cache implementation."""
|
||||
from graphrag_storage.file_storage import FileStorage
|
||||
|
||||
storage = FileStorage(**kwargs)
|
||||
return JsonPipelineCache(storage)
|
||||
|
||||
|
||||
def create_blob_cache(**kwargs) -> PipelineCache:
|
||||
"""Create a blob storage-based cache implementation."""
|
||||
from graphrag_storage.azure_blob_storage import AzureBlobStorage
|
||||
|
||||
storage = AzureBlobStorage(**kwargs)
|
||||
return JsonPipelineCache(storage)
|
||||
|
||||
|
||||
def create_cosmosdb_cache(**kwargs) -> PipelineCache:
|
||||
"""Create a CosmosDB-based cache implementation."""
|
||||
from graphrag_storage.azure_cosmos_storage import AzureCosmosStorage
|
||||
|
||||
storage = AzureCosmosStorage(**kwargs)
|
||||
return JsonPipelineCache(storage)
|
||||
|
||||
|
||||
def create_noop_cache(**_kwargs) -> PipelineCache:
|
||||
"""Create a no-op cache implementation."""
|
||||
return NoopPipelineCache()
|
||||
|
||||
|
||||
def create_memory_cache(**kwargs) -> PipelineCache:
|
||||
"""Create a memory cache implementation."""
|
||||
return InMemoryCache(**kwargs)
|
||||
|
||||
|
||||
# --- register built-in cache implementations ---
|
||||
cache_factory = CacheFactory()
|
||||
cache_factory.register(CacheType.none.value, create_noop_cache)
|
||||
cache_factory.register(CacheType.memory.value, create_memory_cache)
|
||||
cache_factory.register(CacheType.file.value, create_file_cache)
|
||||
cache_factory.register(CacheType.blob.value, create_blob_cache)
|
||||
cache_factory.register(CacheType.cosmosdb.value, create_cosmosdb_cache)
|
||||
@ -7,13 +7,13 @@ from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
|
||||
from graphrag_cache import CacheType
|
||||
from graphrag_storage import StorageType
|
||||
|
||||
from graphrag.config.embeddings import default_embeddings
|
||||
from graphrag.config.enums import (
|
||||
AsyncType,
|
||||
AuthType,
|
||||
CacheType,
|
||||
ChunkStrategyType,
|
||||
InputFileType,
|
||||
ModelType,
|
||||
@ -27,6 +27,7 @@ from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import
|
||||
|
||||
DEFAULT_INPUT_BASE_DIR = "input"
|
||||
DEFAULT_OUTPUT_BASE_DIR = "output"
|
||||
DEFAULT_CACHE_BASE_DIR = "cache"
|
||||
DEFAULT_UPDATE_OUTPUT_BASE_DIR = "update_output"
|
||||
DEFAULT_CHAT_MODEL_ID = "default_chat_model"
|
||||
DEFAULT_CHAT_MODEL_TYPE = ModelType.Chat
|
||||
@ -59,12 +60,9 @@ class BasicSearchDefaults:
|
||||
class CacheDefaults:
|
||||
"""Default values for cache."""
|
||||
|
||||
type: ClassVar[CacheType] = CacheType.file
|
||||
base_dir: str = "cache"
|
||||
connection_string: None = None
|
||||
container_name: None = None
|
||||
storage_account_blob_url: None = None
|
||||
cosmosdb_account_url: None = None
|
||||
type: CacheType = CacheType.Json
|
||||
encoding: str | None = None
|
||||
name: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -240,6 +238,13 @@ class StorageDefaults:
|
||||
azure_cosmosdb_account_url: None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStorageDefaults(StorageDefaults):
|
||||
"""Default values for cache storage."""
|
||||
|
||||
base_dir: str | None = DEFAULT_CACHE_BASE_DIR
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputStorageDefaults(StorageDefaults):
|
||||
"""Default values for input storage."""
|
||||
@ -396,6 +401,7 @@ class GraphRagConfigDefaults:
|
||||
default_factory=UpdateIndexOutputDefaults
|
||||
)
|
||||
cache: CacheDefaults = field(default_factory=CacheDefaults)
|
||||
cache_storage: CacheStorageDefaults = field(default_factory=CacheStorageDefaults)
|
||||
input: InputDefaults = field(default_factory=InputDefaults)
|
||||
embed_text: EmbedTextDefaults = field(default_factory=EmbedTextDefaults)
|
||||
chunks: ChunksDefaults = field(default_factory=ChunksDefaults)
|
||||
|
||||
@ -66,9 +66,12 @@ output:
|
||||
type: {graphrag_config_defaults.output.type} # or blob, cosmosdb
|
||||
base_dir: "{graphrag_config_defaults.output.base_dir}"
|
||||
|
||||
cache_storage:
|
||||
type: {graphrag_config_defaults.cache_storage.type} # [file, blob, cosmosdb]
|
||||
base_dir: "{graphrag_config_defaults.cache_storage.base_dir}"
|
||||
|
||||
cache:
|
||||
type: {graphrag_config_defaults.cache.type.value} # [file, blob, cosmosdb]
|
||||
base_dir: "{graphrag_config_defaults.cache.base_dir}"
|
||||
type: {graphrag_config_defaults.cache.type} # [json, memory, noop]
|
||||
|
||||
reporting:
|
||||
type: {graphrag_config_defaults.reporting.type.value} # [file, blob]
|
||||
|
||||
@ -1,38 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Parameterization settings for the default configuration."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphrag.config.defaults import graphrag_config_defaults
|
||||
from graphrag.config.enums import CacheType
|
||||
|
||||
|
||||
class CacheConfig(BaseModel):
|
||||
"""The default configuration section for Cache."""
|
||||
|
||||
type: CacheType | str = Field(
|
||||
description="The cache type to use.",
|
||||
default=graphrag_config_defaults.cache.type,
|
||||
)
|
||||
base_dir: str = Field(
|
||||
description="The base directory for the cache.",
|
||||
default=graphrag_config_defaults.cache.base_dir,
|
||||
)
|
||||
connection_string: str | None = Field(
|
||||
description="The cache connection string to use.",
|
||||
default=graphrag_config_defaults.cache.connection_string,
|
||||
)
|
||||
container_name: str | None = Field(
|
||||
description="The cache container name to use.",
|
||||
default=graphrag_config_defaults.cache.container_name,
|
||||
)
|
||||
storage_account_blob_url: str | None = Field(
|
||||
description="The storage account blob url to use.",
|
||||
default=graphrag_config_defaults.cache.storage_account_blob_url,
|
||||
)
|
||||
cosmosdb_account_url: str | None = Field(
|
||||
description="The cosmosdb account url to use.",
|
||||
default=graphrag_config_defaults.cache.cosmosdb_account_url,
|
||||
)
|
||||
@ -3,9 +3,11 @@
|
||||
|
||||
"""Parameterization settings for the default configuration."""
|
||||
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
|
||||
from devtools import pformat
|
||||
from graphrag_cache import CacheConfig
|
||||
from graphrag_storage import StorageConfig, StorageType
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
@ -13,7 +15,6 @@ import graphrag.config.defaults as defs
|
||||
from graphrag.config.defaults import graphrag_config_defaults
|
||||
from graphrag.config.enums import VectorStoreType
|
||||
from graphrag.config.models.basic_search_config import BasicSearchConfig
|
||||
from graphrag.config.models.cache_config import CacheConfig
|
||||
from graphrag.config.models.chunking_config import ChunkingConfig
|
||||
from graphrag.config.models.cluster_graph_config import ClusterGraphConfig
|
||||
from graphrag.config.models.community_reports_config import CommunityReportsConfig
|
||||
@ -164,8 +165,17 @@ class GraphRagConfig(BaseModel):
|
||||
Path(self.update_index_output.base_dir).resolve()
|
||||
)
|
||||
|
||||
cache_storage: StorageConfig | None = Field(
|
||||
description="The cache storage configuration.",
|
||||
default=StorageConfig(
|
||||
**asdict(graphrag_config_defaults.cache_storage),
|
||||
),
|
||||
)
|
||||
"""The cache storage configuration."""
|
||||
|
||||
cache: CacheConfig = Field(
|
||||
description="The cache configuration.", default=CacheConfig()
|
||||
description="The cache configuration.",
|
||||
default=CacheConfig(**asdict(graphrag_config_defaults.cache)),
|
||||
)
|
||||
"""The cache configuration."""
|
||||
|
||||
|
||||
@ -7,8 +7,8 @@ from itertools import combinations
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from graphrag_cache import Cache
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.config.enums import AsyncType
|
||||
from graphrag.index.operations.build_noun_graph.np_extractors.base import (
|
||||
BaseNounPhraseExtractor,
|
||||
@ -24,7 +24,7 @@ async def build_noun_graph(
|
||||
normalize_edge_weights: bool,
|
||||
num_threads: int,
|
||||
async_mode: AsyncType,
|
||||
cache: PipelineCache,
|
||||
cache: Cache,
|
||||
) -> tuple[pd.DataFrame, pd.DataFrame]:
|
||||
"""Build a noun graph from text units."""
|
||||
text_units = text_unit_df.loc[:, ["id", "text"]]
|
||||
@ -44,7 +44,7 @@ async def _extract_nodes(
|
||||
text_analyzer: BaseNounPhraseExtractor,
|
||||
num_threads: int,
|
||||
async_mode: AsyncType,
|
||||
cache: PipelineCache,
|
||||
cache: Cache,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Extract initial nodes and edges from text units.
|
||||
|
||||
@ -7,7 +7,8 @@ from collections.abc import Awaitable, Callable, Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag_cache import Cache
|
||||
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
|
||||
|
||||
@ -42,7 +43,7 @@ CovariateExtractStrategy = Callable[
|
||||
list[str],
|
||||
dict[str, str],
|
||||
WorkflowCallbacks,
|
||||
PipelineCache,
|
||||
Cache,
|
||||
dict[str, Any],
|
||||
],
|
||||
Awaitable[CovariateExtractionResult],
|
||||
|
||||
@ -12,6 +12,7 @@ from dataclasses import asdict
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
from graphrag_cache import create_cache
|
||||
from graphrag_storage import Storage, create_storage
|
||||
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
@ -20,7 +21,6 @@ from graphrag.index.run.utils import create_run_context
|
||||
from graphrag.index.typing.context import PipelineRunContext
|
||||
from graphrag.index.typing.pipeline import Pipeline
|
||||
from graphrag.index.typing.pipeline_run_result import PipelineRunResult
|
||||
from graphrag.utils.api import create_cache_from_config
|
||||
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -37,7 +37,10 @@ async def run_pipeline(
|
||||
"""Run all workflows using a simplified pipeline."""
|
||||
input_storage = create_storage(config.input.storage)
|
||||
output_storage = create_storage(config.output)
|
||||
cache = create_cache_from_config(config.cache)
|
||||
cache_storage: Storage | None = None
|
||||
if config.cache_storage:
|
||||
cache_storage = create_storage(config.cache_storage)
|
||||
cache = create_cache(config.cache, storage=cache_storage)
|
||||
|
||||
# load existing state in case any workflows are stateful
|
||||
state_json = await output_storage.get("context.json")
|
||||
|
||||
@ -3,11 +3,11 @@
|
||||
|
||||
"""Utility functions for the GraphRAG run module."""
|
||||
|
||||
from graphrag_cache import Cache
|
||||
from graphrag_cache.memory_cache import MemoryCache
|
||||
from graphrag_storage import Storage, create_storage
|
||||
from graphrag_storage.memory_storage import MemoryStorage
|
||||
|
||||
from graphrag.cache.memory_pipeline_cache import InMemoryCache
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks_manager import WorkflowCallbacksManager
|
||||
@ -21,7 +21,7 @@ def create_run_context(
|
||||
input_storage: Storage | None = None,
|
||||
output_storage: Storage | None = None,
|
||||
previous_storage: Storage | None = None,
|
||||
cache: PipelineCache | None = None,
|
||||
cache: Cache | None = None,
|
||||
callbacks: WorkflowCallbacks | None = None,
|
||||
stats: PipelineRunStats | None = None,
|
||||
state: PipelineState | None = None,
|
||||
@ -31,7 +31,7 @@ def create_run_context(
|
||||
input_storage=input_storage or MemoryStorage(),
|
||||
output_storage=output_storage or MemoryStorage(),
|
||||
previous_storage=previous_storage or MemoryStorage(),
|
||||
cache=cache or InMemoryCache(),
|
||||
cache=cache or MemoryCache(),
|
||||
callbacks=callbacks or NoopWorkflowCallbacks(),
|
||||
stats=stats or PipelineRunStats(),
|
||||
state=state or {},
|
||||
|
||||
@ -6,7 +6,7 @@
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag_cache import Cache
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.index.typing.state import PipelineState
|
||||
from graphrag.index.typing.stats import PipelineRunStats
|
||||
@ -24,7 +24,7 @@ class PipelineRunContext:
|
||||
"Long-term storage for pipeline verbs to use. Items written here will be written to the storage provider."
|
||||
previous_storage: Storage
|
||||
"Storage for previous pipeline run when running in update mode."
|
||||
cache: PipelineCache
|
||||
cache: Cache
|
||||
"Cache instance for reading previous LLM responses."
|
||||
callbacks: WorkflowCallbacks
|
||||
"Callbacks to be called during the pipeline run."
|
||||
|
||||
@ -6,8 +6,8 @@
|
||||
import logging
|
||||
|
||||
import pandas as pd
|
||||
from graphrag_cache import Cache
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.config.enums import AsyncType
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.operations.build_noun_graph.build_noun_graph import build_noun_graph
|
||||
@ -59,7 +59,7 @@ async def run_workflow(
|
||||
|
||||
async def extract_graph_nlp(
|
||||
text_units: pd.DataFrame,
|
||||
cache: PipelineCache,
|
||||
cache: Cache,
|
||||
text_analyzer: BaseNounPhraseExtractor,
|
||||
normalize_edge_weights: bool,
|
||||
num_threads: int,
|
||||
|
||||
@ -6,9 +6,9 @@
|
||||
import logging
|
||||
|
||||
import pandas as pd
|
||||
from graphrag_cache import Cache
|
||||
from graphrag_storage import Storage
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.run.utils import get_update_storages
|
||||
@ -59,7 +59,7 @@ async def _update_entities_and_relationships(
|
||||
delta_storage: Storage,
|
||||
output_storage: Storage,
|
||||
config: GraphRagConfig,
|
||||
cache: PipelineCache,
|
||||
cache: Cache,
|
||||
callbacks: WorkflowCallbacks,
|
||||
) -> tuple[pd.DataFrame, pd.DataFrame, dict]:
|
||||
"""Update Final Entities and Relationships output."""
|
||||
|
||||
@ -38,7 +38,8 @@ from graphrag.language_model.providers.litellm.types import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag_cache import Cache
|
||||
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
from graphrag.language_model.response.base import ModelResponse as MR # noqa: N817
|
||||
|
||||
@ -113,7 +114,7 @@ def _create_base_completions(
|
||||
|
||||
def _create_completions(
|
||||
model_config: "LanguageModelConfig",
|
||||
cache: "PipelineCache | None",
|
||||
cache: "Cache | None",
|
||||
cache_key_prefix: str,
|
||||
) -> tuple[FixedModelCompletion, AFixedModelCompletion]:
|
||||
"""Wrap the base litellm completion function with the model configuration and additional features.
|
||||
@ -203,7 +204,7 @@ class LitellmChatModel:
|
||||
self,
|
||||
name: str,
|
||||
config: "LanguageModelConfig",
|
||||
cache: "PipelineCache | None" = None,
|
||||
cache: "Cache | None" = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self.name = name
|
||||
|
||||
@ -33,7 +33,8 @@ from graphrag.language_model.providers.litellm.types import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag_cache import Cache
|
||||
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
|
||||
litellm.suppress_debug_info = True
|
||||
@ -99,7 +100,7 @@ def _create_base_embeddings(
|
||||
|
||||
def _create_embeddings(
|
||||
model_config: "LanguageModelConfig",
|
||||
cache: "PipelineCache | None",
|
||||
cache: "Cache | None",
|
||||
cache_key_prefix: str,
|
||||
) -> tuple[FixedModelEmbedding, AFixedModelEmbedding]:
|
||||
"""Wrap the base litellm embedding function with the model configuration and additional features.
|
||||
@ -167,7 +168,7 @@ class LitellmEmbeddingModel:
|
||||
self,
|
||||
name: str,
|
||||
config: "LanguageModelConfig",
|
||||
cache: "PipelineCache | None" = None,
|
||||
cache: "Cache | None" = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self.name = name
|
||||
|
||||
@ -15,7 +15,8 @@ from graphrag.language_model.providers.litellm.types import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag_cache import Cache
|
||||
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
|
||||
|
||||
@ -24,7 +25,7 @@ def with_cache(
|
||||
sync_fn: LitellmRequestFunc,
|
||||
async_fn: AsyncLitellmRequestFunc,
|
||||
model_config: "LanguageModelConfig",
|
||||
cache: "PipelineCache",
|
||||
cache: "Cache",
|
||||
request_type: Literal["chat", "embedding"],
|
||||
cache_key_prefix: str,
|
||||
) -> tuple[LitellmRequestFunc, AsyncLitellmRequestFunc]:
|
||||
|
||||
@ -8,9 +8,9 @@ from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from graphrag_cache.noop_cache import NoopCache
|
||||
from graphrag_storage import create_storage
|
||||
|
||||
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.input.factory import InputReaderFactory
|
||||
@ -60,7 +60,7 @@ async def load_docs_in_chunks(
|
||||
model_type=embeddings_llm_settings.type,
|
||||
config=embeddings_llm_settings,
|
||||
callbacks=NoopWorkflowCallbacks(),
|
||||
cache=NoopPipelineCache(),
|
||||
cache=NoopCache(),
|
||||
)
|
||||
tokenizer = get_tokenizer(embeddings_llm_settings)
|
||||
input_storage = create_storage(config.input.storage)
|
||||
|
||||
@ -6,10 +6,7 @@
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from graphrag.cache.factory import CacheFactory
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.config.embeddings import create_index_name
|
||||
from graphrag.config.models.cache_config import CacheConfig
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
from graphrag.vector_stores.base import (
|
||||
BaseVectorStore,
|
||||
@ -98,15 +95,6 @@ def load_search_prompt(prompt_config: str | None) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def create_cache_from_config(cache: CacheConfig) -> PipelineCache:
|
||||
"""Create a cache object from the config."""
|
||||
cache_config = cache.model_dump()
|
||||
return CacheFactory().create(
|
||||
strategy=cache_config["type"],
|
||||
init_args=cache_config,
|
||||
)
|
||||
|
||||
|
||||
def truncate(text: str, max_length: int) -> str:
|
||||
"""Truncate a string to a maximum length."""
|
||||
if len(text) <= max_length:
|
||||
|
||||
@ -40,6 +40,7 @@ dependencies = [
|
||||
"azure-storage-blob>=12.24.0",
|
||||
"devtools>=0.12.2",
|
||||
"environs>=11.0.0",
|
||||
"graphrag-cache==2.7.0",
|
||||
"graphrag-common==2.7.0",
|
||||
"graphrag-storage==2.7.0",
|
||||
"graspologic-native>=1.2.5",
|
||||
|
||||
@ -55,6 +55,7 @@ members = ["packages/*"]
|
||||
[tool.uv.sources]
|
||||
graphrag-common = { workspace = true }
|
||||
graphrag-storage = { workspace = true }
|
||||
graphrag-cache = { workspace = true }
|
||||
|
||||
# Keep poethepoet for task management to minimize changes
|
||||
[tool.poe.tasks]
|
||||
@ -71,6 +72,7 @@ _semversioner_changelog = "semversioner changelog > CHANGELOG.md"
|
||||
_semversioner_update_graphrag_toml_version = "update-toml update --file packages/graphrag/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
|
||||
_semversioner_update_graphrag_common_toml_version = "update-toml update --file packages/graphrag-common/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
|
||||
_semversioner_update_graphrag_storage_toml_version = "update-toml update --file packages/graphrag-storage/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
|
||||
_semversioner_update_graphrag_cache_toml_version = "update-toml update --file packages/graphrag-cache/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
|
||||
_semversioner_update_workspace_dependency_versions = "python -m scripts.update_workspace_dependency_versions"
|
||||
semversioner_add = "semversioner add-change"
|
||||
coverage_report = 'coverage report --omit "**/tests/**" --show-missing'
|
||||
@ -106,6 +108,7 @@ sequence = [
|
||||
'_semversioner_update_graphrag_toml_version',
|
||||
'_semversioner_update_graphrag_common_toml_version',
|
||||
'_semversioner_update_graphrag_storage_toml_version',
|
||||
'_semversioner_update_graphrag_cache_toml_version',
|
||||
'_semversioner_update_workspace_dependency_versions',
|
||||
'_sync',
|
||||
]
|
||||
|
||||
141
tests/integration/cache/test_factory.py
vendored
141
tests/integration/cache/test_factory.py
vendored
@ -8,12 +8,12 @@ These tests will test the CacheFactory() class and the creation of each cache ty
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from graphrag.cache.factory import CacheFactory
|
||||
from graphrag.cache.json_pipeline_cache import JsonPipelineCache
|
||||
from graphrag.cache.memory_pipeline_cache import InMemoryCache
|
||||
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.config.enums import CacheType
|
||||
from graphrag_cache import Cache, CacheConfig, CacheType, create_cache, register_cache
|
||||
from graphrag_cache.cache_factory import cache_factory
|
||||
from graphrag_cache.json_cache import JsonCache
|
||||
from graphrag_cache.memory_cache import MemoryCache
|
||||
from graphrag_cache.noop_cache import NoopCache
|
||||
from graphrag_storage import StorageConfig, StorageType, create_storage
|
||||
|
||||
# cspell:disable-next-line well-known-key
|
||||
WELL_KNOWN_BLOB_STORAGE_KEY = "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;"
|
||||
@ -22,31 +22,55 @@ WELL_KNOWN_COSMOS_CONNECTION_STRING = "AccountEndpoint=https://127.0.0.1:8081/;A
|
||||
|
||||
|
||||
def test_create_noop_cache():
|
||||
cache = CacheFactory().create(strategy=CacheType.none.value)
|
||||
assert isinstance(cache, NoopPipelineCache)
|
||||
cache = create_cache(
|
||||
CacheConfig(
|
||||
type=CacheType.Noop,
|
||||
)
|
||||
)
|
||||
assert isinstance(cache, NoopCache)
|
||||
|
||||
|
||||
def test_create_memory_cache():
|
||||
cache = CacheFactory().create(strategy=CacheType.memory.value)
|
||||
assert isinstance(cache, InMemoryCache)
|
||||
cache = create_cache(
|
||||
CacheConfig(
|
||||
type=CacheType.Memory,
|
||||
)
|
||||
)
|
||||
assert isinstance(cache, MemoryCache)
|
||||
|
||||
|
||||
def test_create_file_cache():
|
||||
cache = CacheFactory().create(
|
||||
strategy=CacheType.file.value,
|
||||
init_args={"base_dir": "testcache"},
|
||||
storage = create_storage(
|
||||
StorageConfig(
|
||||
type=StorageType.Memory,
|
||||
)
|
||||
)
|
||||
assert isinstance(cache, JsonPipelineCache)
|
||||
cache = create_cache(
|
||||
CacheConfig(
|
||||
type=CacheType.Json,
|
||||
),
|
||||
storage=storage,
|
||||
)
|
||||
assert isinstance(cache, JsonCache)
|
||||
|
||||
|
||||
def test_create_blob_cache():
|
||||
init_args = {
|
||||
"connection_string": WELL_KNOWN_BLOB_STORAGE_KEY,
|
||||
"container_name": "testcontainer",
|
||||
"base_dir": "testcache",
|
||||
}
|
||||
cache = CacheFactory().create(strategy=CacheType.blob.value, init_args=init_args)
|
||||
assert isinstance(cache, JsonPipelineCache)
|
||||
storage = create_storage(
|
||||
StorageConfig(
|
||||
type=StorageType.AzureBlob,
|
||||
connection_string=WELL_KNOWN_BLOB_STORAGE_KEY,
|
||||
container_name="testcontainer",
|
||||
base_dir="testcache",
|
||||
)
|
||||
)
|
||||
cache = create_cache(
|
||||
CacheConfig(
|
||||
type=CacheType.Json,
|
||||
),
|
||||
storage=storage,
|
||||
)
|
||||
|
||||
assert isinstance(cache, JsonCache)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@ -54,15 +78,21 @@ def test_create_blob_cache():
|
||||
reason="cosmosdb emulator is only available on windows runners at this time",
|
||||
)
|
||||
def test_create_cosmosdb_cache():
|
||||
init_args = {
|
||||
"connection_string": WELL_KNOWN_COSMOS_CONNECTION_STRING,
|
||||
"database_name": "testdatabase",
|
||||
"container_name": "testcontainer",
|
||||
}
|
||||
cache = CacheFactory().create(
|
||||
strategy=CacheType.cosmosdb.value, init_args=init_args
|
||||
storage = create_storage(
|
||||
StorageConfig(
|
||||
type=StorageType.AzureCosmos,
|
||||
connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
|
||||
database_name="testdatabase",
|
||||
container_name="testcontainer",
|
||||
)
|
||||
)
|
||||
assert isinstance(cache, JsonPipelineCache)
|
||||
cache = create_cache(
|
||||
CacheConfig(
|
||||
type=CacheType.Json,
|
||||
),
|
||||
storage=storage,
|
||||
)
|
||||
assert isinstance(cache, JsonCache)
|
||||
|
||||
|
||||
def test_register_and_create_custom_cache():
|
||||
@ -70,17 +100,14 @@ def test_register_and_create_custom_cache():
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Create a mock that satisfies the PipelineCache interface
|
||||
custom_cache_class = MagicMock(spec=PipelineCache)
|
||||
custom_cache_class = MagicMock(spec=Cache)
|
||||
# Make the mock return a mock instance when instantiated
|
||||
instance = MagicMock()
|
||||
instance.initialized = True
|
||||
custom_cache_class.return_value = instance
|
||||
|
||||
CacheFactory().register(
|
||||
strategy="custom",
|
||||
initializer=lambda **kwargs: custom_cache_class(**kwargs),
|
||||
)
|
||||
cache = CacheFactory().create(strategy="custom")
|
||||
register_cache("custom", lambda **kwargs: custom_cache_class(**kwargs))
|
||||
cache = create_cache(CacheConfig(type="custom"))
|
||||
|
||||
assert custom_cache_class.called
|
||||
assert cache is instance
|
||||
@ -88,45 +115,21 @@ def test_register_and_create_custom_cache():
|
||||
assert cache.initialized is True # type: ignore # Attribute only exists on our mock
|
||||
|
||||
# Check if it's in the list of registered cache types
|
||||
assert "custom" in CacheFactory()
|
||||
assert "custom" in cache_factory
|
||||
|
||||
|
||||
def test_create_unknown_cache():
|
||||
with pytest.raises(ValueError, match="Strategy 'unknown' is not registered\\."):
|
||||
CacheFactory().create(strategy="unknown")
|
||||
|
||||
|
||||
def test_is_supported_type():
|
||||
# Test built-in types
|
||||
assert CacheType.none.value in CacheFactory()
|
||||
assert CacheType.memory.value in CacheFactory()
|
||||
assert CacheType.file.value in CacheFactory()
|
||||
assert CacheType.blob.value in CacheFactory()
|
||||
assert CacheType.cosmosdb.value in CacheFactory()
|
||||
|
||||
# Test unknown type
|
||||
assert "unknown" not in CacheFactory()
|
||||
|
||||
|
||||
def test_enum_and_string_compatibility():
|
||||
"""Test that both enum and string types work for cache creation."""
|
||||
# Test with enum
|
||||
cache_enum = CacheFactory().create(strategy=CacheType.memory)
|
||||
assert isinstance(cache_enum, InMemoryCache)
|
||||
|
||||
# Test with string
|
||||
cache_str = CacheFactory().create(strategy="memory")
|
||||
assert isinstance(cache_str, InMemoryCache)
|
||||
|
||||
# Both should create the same type
|
||||
assert type(cache_enum) is type(cache_str)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="CacheConfig\\.type 'unknown' is not registered in the CacheFactory\\.",
|
||||
):
|
||||
create_cache(CacheConfig(type="unknown"))
|
||||
|
||||
|
||||
def test_register_class_directly_works():
|
||||
"""Test that registering a class directly works (CacheFactory() allows this)."""
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
|
||||
class CustomCache(PipelineCache):
|
||||
class CustomCache(Cache):
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
@ -149,11 +152,11 @@ def test_register_class_directly_works():
|
||||
return self
|
||||
|
||||
# CacheFactory() allows registering classes directly (no TypeError)
|
||||
CacheFactory().register("custom_class", CustomCache)
|
||||
register_cache("custom_class", CustomCache)
|
||||
|
||||
# Verify it was registered
|
||||
assert "custom_class" in CacheFactory()
|
||||
assert "custom_class" in cache_factory
|
||||
|
||||
# Test creating an instance
|
||||
cache = CacheFactory().create(strategy="custom_class")
|
||||
cache = create_cache(CacheConfig(type="custom_class"))
|
||||
assert isinstance(cache, CustomCache)
|
||||
|
||||
@ -5,7 +5,6 @@ from dataclasses import asdict
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.config.models.basic_search_config import BasicSearchConfig
|
||||
from graphrag.config.models.cache_config import CacheConfig
|
||||
from graphrag.config.models.chunking_config import ChunkingConfig
|
||||
from graphrag.config.models.cluster_graph_config import ClusterGraphConfig
|
||||
from graphrag.config.models.community_reports_config import CommunityReportsConfig
|
||||
@ -29,6 +28,7 @@ from graphrag.config.models.summarize_descriptions_config import (
|
||||
SummarizeDescriptionsConfig,
|
||||
)
|
||||
from graphrag.config.models.vector_store_config import VectorStoreConfig
|
||||
from graphrag_cache import CacheConfig
|
||||
from graphrag_storage import StorageConfig
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -138,11 +138,8 @@ def assert_storage_config(actual: StorageConfig, expected: StorageConfig) -> Non
|
||||
|
||||
def assert_cache_configs(actual: CacheConfig, expected: CacheConfig) -> None:
|
||||
assert actual.type == expected.type
|
||||
assert actual.base_dir == expected.base_dir
|
||||
assert actual.connection_string == expected.connection_string
|
||||
assert actual.container_name == expected.container_name
|
||||
assert actual.storage_account_blob_url == expected.storage_account_blob_url
|
||||
assert actual.cosmosdb_account_url == expected.cosmosdb_account_url
|
||||
assert actual.encoding == expected.encoding
|
||||
assert actual.name == expected.name
|
||||
|
||||
|
||||
def assert_input_configs(actual: InputConfig, expected: InputConfig) -> None:
|
||||
|
||||
@ -4,17 +4,26 @@ import asyncio
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from graphrag.cache.json_pipeline_cache import JsonPipelineCache
|
||||
from graphrag_storage.file_storage import (
|
||||
FileStorage,
|
||||
)
|
||||
from graphrag_cache import CacheConfig, CacheType
|
||||
from graphrag_cache import create_cache as cc
|
||||
from graphrag_storage import StorageConfig, StorageType, create_storage
|
||||
|
||||
TEMP_DIR = "./.tmp"
|
||||
|
||||
|
||||
def create_cache():
|
||||
storage = FileStorage(base_dir=os.path.join(os.getcwd(), ".tmp"))
|
||||
return JsonPipelineCache(storage)
|
||||
storage = create_storage(
|
||||
StorageConfig(
|
||||
type=StorageType.File,
|
||||
base_dir=os.path.join(os.getcwd(), ".tmp"),
|
||||
)
|
||||
)
|
||||
return cc(
|
||||
CacheConfig(
|
||||
type=CacheType.Json,
|
||||
),
|
||||
storage=storage,
|
||||
)
|
||||
|
||||
|
||||
class TestFilePipelineCache(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
18
uv.lock
generated
18
uv.lock
generated
@ -10,6 +10,7 @@ resolution-markers = [
|
||||
[manifest]
|
||||
members = [
|
||||
"graphrag",
|
||||
"graphrag-cache",
|
||||
"graphrag-common",
|
||||
"graphrag-monorepo",
|
||||
"graphrag-storage",
|
||||
@ -1042,6 +1043,7 @@ dependencies = [
|
||||
{ name = "azure-storage-blob" },
|
||||
{ name = "devtools" },
|
||||
{ name = "environs" },
|
||||
{ name = "graphrag-cache" },
|
||||
{ name = "graphrag-common" },
|
||||
{ name = "graphrag-storage" },
|
||||
{ name = "graspologic-native" },
|
||||
@ -1074,6 +1076,7 @@ requires-dist = [
|
||||
{ name = "azure-storage-blob", specifier = ">=12.24.0" },
|
||||
{ name = "devtools", specifier = ">=0.12.2" },
|
||||
{ name = "environs", specifier = ">=11.0.0" },
|
||||
{ name = "graphrag-cache", editable = "packages/graphrag-cache" },
|
||||
{ name = "graphrag-common", editable = "packages/graphrag-common" },
|
||||
{ name = "graphrag-storage", editable = "packages/graphrag-storage" },
|
||||
{ name = "graspologic-native", specifier = ">=1.2.5" },
|
||||
@ -1095,6 +1098,21 @@ requires-dist = [
|
||||
{ name = "typing-extensions", specifier = ">=4.12.2" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "graphrag-cache"
|
||||
version = "2.7.0"
|
||||
source = { editable = "packages/graphrag-cache" }
|
||||
dependencies = [
|
||||
{ name = "graphrag-common" },
|
||||
{ name = "graphrag-storage" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "graphrag-common", editable = "packages/graphrag-common" },
|
||||
{ name = "graphrag-storage", editable = "packages/graphrag-storage" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "graphrag-common"
|
||||
version = "2.7.0"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user