mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-13 16:47:20 +08:00
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
* progress on vector customization * fix for lancedb vectors * cosmosdb implementation * uv run poe format * clean test for vector store * semversioner update * test_factory.py integration test fixes * fixes for cosmosdb test * integration test fix for lancedb * uv fix for format * test fixes * fixes for tests * fix cosmosdb bug * print statement * test * test * fix cosmosdb bug * test validation * validation cosmosdb * validate cosmosdb * fix cosmosdb * fix small feedback from PR --------- Co-authored-by: Gaudy Blanco <gaudy-microsoft@MacBook-Pro-m4-Gaudy-For-Work.local>
169 lines
5.9 KiB
Python
169 lines
5.9 KiB
Python
# Copyright (c) 2024 Microsoft Corporation.
|
|
# Licensed under the MIT License
|
|
"""VectorStoreFactory Tests.
|
|
|
|
These tests exercise the VectorStoreFactory class: creation of each natively supported vector store type and registration of custom vector store types.
|
|
"""
|
|
|
|
import pytest
|
|
|
|
from graphrag.config.enums import VectorStoreType
|
|
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
|
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
|
|
from graphrag.vector_stores.base import BaseVectorStore
|
|
from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore
|
|
from graphrag.vector_stores.factory import VectorStoreFactory
|
|
from graphrag.vector_stores.lancedb import LanceDBVectorStore
|
|
|
|
|
|
def test_create_lancedb_vector_store(tmp_path):
    """Create a LanceDB store through the factory and check its type and index name.

    ``tmp_path`` is pytest's built-in per-test temporary directory. It is used
    instead of a hard-coded POSIX ``/tmp`` path, which is not portable to
    Windows runners and could be shared between concurrent test runs.
    """
    kwargs = {
        "db_uri": str(tmp_path / "lancedb"),
    }
    vector_store = VectorStoreFactory.create_vector_store(
        vector_store_type=VectorStoreType.LanceDB.value,
        vector_store_schema_config=VectorStoreSchemaConfig(
            index_name="test_collection"
        ),
        kwargs=kwargs,
    )
    assert isinstance(vector_store, LanceDBVectorStore)
    # The index name from the schema config should be exposed on the store.
    assert vector_store.index_name == "test_collection"
|
|
|
|
|
|
@pytest.mark.skip(reason="Azure AI Search requires credentials and setup")
def test_create_azure_ai_search_vector_store():
    """Create an Azure AI Search store through the factory and check the result.

    Skipped by default: the endpoint and API key below are placeholders, and
    the test is marked as requiring real credentials and setup.
    """
    kwargs = {
        "url": "https://test.search.windows.net",
        "api_key": "test_key",
    }
    vector_store = VectorStoreFactory.create_vector_store(
        vector_store_type=VectorStoreType.AzureAISearch.value,
        vector_store_schema_config=VectorStoreSchemaConfig(
            index_name="test_collection"
        ),
        kwargs=kwargs,
    )
    assert isinstance(vector_store, AzureAISearchVectorStore)
    # The index name from the schema config should be exposed on the store.
    assert vector_store.index_name == "test_collection"
|
|
|
|
|
|
@pytest.mark.skip(reason="CosmosDB requires credentials and setup")
def test_create_cosmosdb_vector_store():
    """Create a CosmosDB store through the factory and check the result.

    Skipped by default: the connection string below is a placeholder, and the
    test is marked as requiring real credentials and setup.
    """
    kwargs = {
        "connection_string": "AccountEndpoint=https://test.documents.azure.com:443/;AccountKey=test_key==",
        "database_name": "test_db",
    }

    vector_store = VectorStoreFactory.create_vector_store(
        vector_store_type=VectorStoreType.CosmosDB.value,
        vector_store_schema_config=VectorStoreSchemaConfig(
            index_name="test_collection"
        ),
        kwargs=kwargs,
    )

    assert isinstance(vector_store, CosmosDBVectorStore)
    # The index name from the schema config should be exposed on the store.
    assert vector_store.index_name == "test_collection"
|
|
|
|
|
|
def test_register_and_create_custom_vector_store():
    """Test registering and creating a custom vector store type.

    A mock "class" is registered under the name ``"custom"`` via a lambda
    creator. The test then checks that creating that type invokes the creator
    and returns the mock's instance, and that the type is listed and reported
    as supported.
    """
    from unittest.mock import MagicMock

    # Class-like mock restricted to the BaseVectorStore interface; calling it
    # (as the registered lambda does) returns `instance`.
    custom_vector_store_class = MagicMock(spec=BaseVectorStore)
    instance = MagicMock()
    instance.initialized = True
    custom_vector_store_class.return_value = instance

    # NOTE(review): this registration is never undone, so "custom" presumably
    # stays registered for the rest of the test session (the registry appears
    # to be class-level state) - consider cleaning it up.
    VectorStoreFactory.register(
        "custom", lambda **kwargs: custom_vector_store_class(**kwargs)
    )

    vector_store = VectorStoreFactory.create_vector_store(
        vector_store_type="custom",
        vector_store_schema_config=VectorStoreSchemaConfig(),
        kwargs={},
    )

    # Stricter than `.called` (which accepts any number >= 1), and on failure
    # the message reports how many times the mock was actually called.
    custom_vector_store_class.assert_called_once()
    assert vector_store is instance
    # Access the attribute we set on our mock
    assert vector_store.initialized is True  # type: ignore # Attribute only exists on our mock

    # Check if it's in the list of registered vector store types
    assert "custom" in VectorStoreFactory.get_vector_store_types()
    assert VectorStoreFactory.is_supported_type("custom")
|
|
|
|
|
|
def test_get_vector_store_types():
    """Every natively supported backend should appear in the factory's type list."""
    registered = VectorStoreFactory.get_vector_store_types()
    built_in_types = (
        VectorStoreType.LanceDB.value,
        VectorStoreType.AzureAISearch.value,
        VectorStoreType.CosmosDB.value,
    )
    for store_type in built_in_types:
        assert store_type in registered
|
|
|
|
|
|
def test_create_unknown_vector_store():
    """Requesting an unregistered store type must raise a descriptive ValueError."""
    default_schema = VectorStoreSchemaConfig()
    with pytest.raises(ValueError, match="Unknown vector store type: unknown"):
        VectorStoreFactory.create_vector_store(
            vector_store_type="unknown",
            vector_store_schema_config=default_schema,
            kwargs={},
        )
|
|
|
|
|
|
def test_is_supported_type():
    """Built-in backends report as supported; an unregistered name does not."""
    for built_in in (
        VectorStoreType.LanceDB,
        VectorStoreType.AzureAISearch,
        VectorStoreType.CosmosDB,
    ):
        assert VectorStoreFactory.is_supported_type(built_in.value)

    # A name nobody registered must be rejected.
    assert not VectorStoreFactory.is_supported_type("unknown")
|
|
|
|
|
|
def test_register_class_directly_works():
    """Test that registering a class directly works (VectorStoreFactory allows this).

    The factory accepts a creator lambda elsewhere in this module; here a
    ``BaseVectorStore`` subclass itself is registered as the creator. Creating
    the type must then produce an instance of that subclass.
    """
    # BaseVectorStore is already imported at module level; no local re-import needed.

    class CustomVectorStore(BaseVectorStore):
        """Minimal stub implementing each BaseVectorStore method with a no-op or empty result."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def connect(self, **kwargs):
            pass

        def load_documents(self, documents, overwrite=True):
            pass

        def similarity_search_by_vector(self, query_embedding, k=10, **kwargs):
            return []

        def similarity_search_by_text(self, text, text_embedder, k=10, **kwargs):
            return []

        def filter_by_id(self, include_ids):
            return {}

        def search_by_id(self, id):
            from graphrag.vector_stores.base import VectorStoreDocument

            return VectorStoreDocument(id=id, text="test", vector=None)

    # VectorStoreFactory allows registering classes directly (no TypeError)
    VectorStoreFactory.register("custom_class", CustomVectorStore)

    # Verify it was registered
    assert "custom_class" in VectorStoreFactory.get_vector_store_types()
    assert VectorStoreFactory.is_supported_type("custom_class")

    # Test creating an instance; the extra kwargs are forwarded to the class
    # constructor along with the schema config.
    vector_store = VectorStoreFactory.create_vector_store(
        vector_store_type="custom_class",
        vector_store_schema_config=VectorStoreSchemaConfig(),
        kwargs={"collection_name": "test"},
    )

    assert isinstance(vector_store, CustomVectorStore)
|