mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 09:07:20 +08:00
Some checks failed
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Waiting to run
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Waiting to run
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Waiting to run
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Waiting to run
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Waiting to run
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Waiting to run
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
* Fix pipeline recursion * Remove base_dir from storage.find * Remove max_count from storage.find * Remove prefix on storage integ test * Add base_dir in creation_date test * Wrap base_dir in Path * Use constants for input/update directories
154 lines
5.5 KiB
Python
154 lines
5.5 KiB
Python
# Copyright (c) 2024 Microsoft Corporation.
|
|
# Licensed under the MIT License
|
|
"""StorageFactory Tests.
|
|
|
|
These tests exercise the StorageFactory class and the creation of each natively supported storage type.
|
|
"""
|
|
|
|
import sys
|
|
|
|
import pytest
|
|
from graphrag.config.enums import StorageType
|
|
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
|
|
from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
|
|
from graphrag.storage.factory import StorageFactory
|
|
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
|
|
from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
|
|
from graphrag.storage.pipeline_storage import PipelineStorage
|
|
|
|
# Blob connection string used by test_create_blob_storage. It targets a local
# endpoint (http://127.0.0.1:10000) with the "devstoreaccount1" account.
# NOTE(review): named "well-known" - presumably the storage emulator's publicly
# documented default key rather than a real secret; confirm against emulator docs.
# cspell:disable-next-line well-known-key
|
|
WELL_KNOWN_BLOB_STORAGE_KEY = "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;"
|
|
# Cosmos DB connection string used by test_create_cosmosdb_storage. It targets a
# local endpoint (https://127.0.0.1:8081/).
# NOTE(review): presumably the Cosmos DB emulator's documented default key; confirm.
# cspell:disable-next-line well-known-key
|
|
WELL_KNOWN_COSMOS_CONNECTION_STRING = "AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="
|
|
|
|
|
|
@pytest.mark.skip(reason="Blob storage emulator is not available in this environment")
def test_create_blob_storage():
    """Create blob storage through the factory and check the concrete type.

    Always skipped: the skip marker states that no blob emulator is available.
    """
    blob_settings = dict(
        type="blob",
        connection_string=WELL_KNOWN_BLOB_STORAGE_KEY,
        base_dir="testbasedir",
        container_name="testcontainer",
    )
    factory = StorageFactory()
    created = factory.create(StorageType.blob.value, blob_settings)
    assert isinstance(created, BlobPipelineStorage)
|
|
|
|
|
|
@pytest.mark.skipif(
    not sys.platform.startswith("win"),
    reason="cosmosdb emulator is only available on windows runners at this time",
)
def test_create_cosmosdb_storage():
    """Create CosmosDB storage through the factory and check the concrete type.

    Runs only when ``sys.platform`` starts with "win"; skipped everywhere else.
    """
    cosmos_settings = dict(
        type="cosmosdb",
        connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
        base_dir="testdatabase",
        container_name="testcontainer",
    )
    factory = StorageFactory()
    created = factory.create(StorageType.cosmosdb.value, cosmos_settings)
    assert isinstance(created, CosmosDBPipelineStorage)
|
|
|
|
|
|
def test_create_file():
    """Create file storage through the factory and check the concrete type.

    The base directory is a throwaway temporary directory instead of the
    hard-coded POSIX path "/tmp/teststorage". The test therefore behaves the
    same on every platform, and anything written under the directory is
    removed when the ``with`` block exits.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as scratch_dir:
        kwargs = {"type": "file", "base_dir": scratch_dir}
        storage = StorageFactory().create(StorageType.file.value, kwargs)
        assert isinstance(storage, FilePipelineStorage)
|
|
|
|
|
|
def test_create_memory_storage():
    """Create in-memory storage through the factory and check the concrete type."""
    # MemoryPipelineStorage doesn't accept any constructor parameters
    no_settings = {}
    factory = StorageFactory()
    created = factory.create(StorageType.memory.value, no_settings)
    assert isinstance(created, MemoryPipelineStorage)
|
|
|
|
|
|
def test_register_and_create_custom_storage():
    """Test registering and creating a custom storage type."""
    from unittest.mock import MagicMock

    # Stand-in instance carrying a marker attribute; the attribute is set on the
    # mock only and need not exist on PipelineStorage.
    fake_instance = MagicMock()
    fake_instance.initialized = True

    # Class-like mock constrained to the PipelineStorage spec; calling it hands
    # back the stand-in instance above.
    fake_storage_class = MagicMock(spec=PipelineStorage, return_value=fake_instance)

    def build_custom(**kwargs):
        """Forward the factory's keyword arguments to the mocked class."""
        return fake_storage_class(**kwargs)

    StorageFactory().register("custom", build_custom)
    created = StorageFactory().create("custom", {})

    # The registered callable was invoked and its result returned unchanged.
    assert fake_storage_class.called
    assert created is fake_instance
    assert created.initialized is True  # type: ignore # Attribute only exists on our mock

    # The new name is visible through the factory's membership check.
    assert "custom" in StorageFactory()
|
|
|
|
|
|
def test_get_storage_types():
    """Check that every built-in storage type is registered with the factory."""
    built_in_types = (
        StorageType.file,
        StorageType.memory,
        StorageType.blob,
        StorageType.cosmosdb,
    )
    for storage_type in built_in_types:
        assert storage_type.value in StorageFactory()
|
|
|
|
|
|
def test_create_unknown_storage():
    """Creating an unregistered storage type must raise the expected ValueError."""
    # Regex handed to pytest.raises; the trailing dot is escaped to match literally.
    expected_message = "Strategy 'unknown' is not registered\\."
    with pytest.raises(ValueError, match=expected_message):
        StorageFactory().create("unknown")
|
|
|
|
|
|
def test_register_class_directly_works():
    """Test that registering a class directly works (StorageFactory allows this)."""
    import re
    from collections.abc import Iterator
    from typing import Any

    from graphrag.storage.pipeline_storage import PipelineStorage

    class CustomStorage(PipelineStorage):
        """Do-nothing PipelineStorage subclass used as the class to register."""

        def __init__(self, **kwargs):
            """Accept and ignore any keyword arguments."""

        def find(
            self,
            file_pattern: re.Pattern[str],
        ) -> Iterator[str]:
            """Return an empty iterator regardless of the pattern."""
            return iter([])

        async def get(
            self, key: str, as_bytes: bool | None = None, encoding: str | None = None
        ) -> Any:
            """Return None for every key."""
            return None

        async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
            """Discard the value; nothing is stored."""

        async def delete(self, key: str) -> None:
            """Do nothing."""

        async def has(self, key: str) -> bool:
            """Report every key as absent."""
            return False

        async def clear(self) -> None:
            """Do nothing."""

        def child(self, name: str | None) -> "PipelineStorage":
            """Return this same instance for any child name."""
            return self

        def keys(self) -> list[str]:
            """Return an empty key list."""
            return []

        async def get_creation_date(self, key: str) -> str:
            """Return a fixed timestamp string for every key."""
            return "2024-01-01 00:00:00 +0000"

    # StorageFactory allows registering classes directly (no TypeError)
    registered_name = "custom_class"
    StorageFactory().register(registered_name, CustomStorage)

    # The name must now be visible through the factory's membership check.
    assert registered_name in StorageFactory()

    # Creating by that name must yield an instance of the registered class.
    created = StorageFactory().create(registered_name)
    assert isinstance(created, CustomStorage)
|