graphrag/graphrag/cache/factory.py
Josh Bradley 823342188d
Cleanup factory methods (#1482)
* cleanup factory methods to have similar design pattern across codebase

* add semversioner file

* cleanup logging factory

* update developer guide

* add comment

* typo fix

* cleanup reporter terminology

* rename reporter to logger

* fix comments

* update comment

* instantiate factory classes correctly and update index api callback parameter

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
2024-12-10 16:11:11 -06:00

58 lines
2.0 KiB
Python

# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing create_cache method definition."""
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from graphrag.config.enums import CacheType
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
if TYPE_CHECKING:
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.cache.json_pipeline_cache import JsonPipelineCache
from graphrag.cache.memory_pipeline_cache import InMemoryCache
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
class CacheFactory:
    """Factory that builds pipeline cache instances from a cache type.

    Built-in cache types are resolved first; any other type is looked up in
    ``cache_types``, which callers can extend through ``register``.
    """

    # Registry of custom cache implementations, keyed by cache type name.
    cache_types: ClassVar[dict[str, type]] = {}

    @classmethod
    def register(cls, cache_type: str, cache: type):
        """Store ``cache`` in the registry under the key ``cache_type``.

        An existing entry with the same key is overwritten.
        """
        cls.cache_types[cache_type] = cache

    @classmethod
    def create_cache(
        cls, cache_type: CacheType | str | None, root_dir: str, kwargs: dict
    ) -> PipelineCache:
        """Build the cache that corresponds to ``cache_type``.

        Args:
            cache_type: A built-in ``CacheType``, the key of a registered
                custom implementation, or a falsy value for "no cache".
            root_dir: Root directory handed to the file storage backend
                (only used for the file cache).
            kwargs: Extra settings. The file cache reads ``kwargs["base_dir"]``;
                the blob cache and custom caches receive it unpacked as
                keyword arguments.

        Returns:
            The constructed cache instance.

        Raises:
            ValueError: If ``cache_type`` is neither built-in nor registered.
        """
        # A falsy type and the explicit "none" type both yield the no-op cache.
        if not cache_type or cache_type == CacheType.none:
            return NoopPipelineCache()

        if cache_type == CacheType.memory:
            return InMemoryCache()

        if cache_type == CacheType.file:
            file_storage = FilePipelineStorage(root_dir=root_dir)
            return JsonPipelineCache(file_storage.child(kwargs["base_dir"]))

        if cache_type == CacheType.blob:
            blob_storage = BlobPipelineStorage(**kwargs)
            return JsonPipelineCache(blob_storage)

        # Not a built-in type: fall back to the custom registry.
        if cache_type not in cls.cache_types:
            msg = f"Unknown cache type: {cache_type}"
            raise ValueError(msg)

        custom_cache_cls = cls.cache_types[cache_type]
        return custom_cache_cls(**kwargs)