mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 00:57:23 +08:00
Vector Store Integration Tests (#1856)
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
* Add vector store id reference to embeddings config. * generated initial vector store pytests * cleaned up cosmosdb vector store test * fixed class name typo and debugged cosmosdb vector store test * reset emulator connection string * remove unnecessary comments * removed extra comments from azure ai search test * ruff * semversioner * fix cicd issues * bypass diskANN policy for test env * handle floating point imprecisions --------- Co-authored-by: Derek Worthen <worthend.derek@gmail.com>
This commit is contained in:
parent
ffd8db7104
commit
61769dd47e
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "add vector store integration tests"
|
||||
}
|
||||
@ -7,6 +7,7 @@ import json
|
||||
from typing import Any
|
||||
|
||||
from azure.cosmos import ContainerProxy, CosmosClient, DatabaseProxy
|
||||
from azure.cosmos.exceptions import CosmosHttpResponseError
|
||||
from azure.cosmos.partition_key import PartitionKey
|
||||
from azure.identity import DefaultAzureCredential
|
||||
|
||||
@ -19,7 +20,7 @@ from graphrag.vector_stores.base import (
|
||||
)
|
||||
|
||||
|
||||
class CosmosDBVectoreStore(BaseVectorStore):
|
||||
class CosmosDBVectorStore(BaseVectorStore):
|
||||
"""Azure CosmosDB vector storage implementation."""
|
||||
|
||||
_cosmos_client: CosmosClient
|
||||
@ -99,16 +100,32 @@ class CosmosDBVectoreStore(BaseVectorStore):
|
||||
"automatic": True,
|
||||
"includedPaths": [{"path": "/*"}],
|
||||
"excludedPaths": [{"path": "/_etag/?"}, {"path": "/vector/*"}],
|
||||
"vectorIndexes": [{"path": "/vector", "type": "diskANN"}],
|
||||
}
|
||||
|
||||
# Create the container and container client
|
||||
self._database_client.create_container_if_not_exists(
|
||||
id=self._container_name,
|
||||
partition_key=partition_key,
|
||||
indexing_policy=indexing_policy,
|
||||
vector_embedding_policy=vector_embedding_policy,
|
||||
)
|
||||
# Currently, the CosmosDB emulator does not support the diskANN policy.
|
||||
try:
|
||||
# First try with the standard diskANN policy
|
||||
indexing_policy["vectorIndexes"] = [{"path": "/vector", "type": "diskANN"}]
|
||||
|
||||
# Create the container and container client
|
||||
self._database_client.create_container_if_not_exists(
|
||||
id=self._container_name,
|
||||
partition_key=partition_key,
|
||||
indexing_policy=indexing_policy,
|
||||
vector_embedding_policy=vector_embedding_policy,
|
||||
)
|
||||
except CosmosHttpResponseError:
|
||||
# If diskANN fails (likely in emulator), retry without vector indexes
|
||||
indexing_policy.pop("vectorIndexes", None)
|
||||
|
||||
# Create the container with compatible indexing policy
|
||||
self._database_client.create_container_if_not_exists(
|
||||
id=self._container_name,
|
||||
partition_key=partition_key,
|
||||
indexing_policy=indexing_policy,
|
||||
vector_embedding_policy=vector_embedding_policy,
|
||||
)
|
||||
|
||||
self._container_client = self._database_client.get_container_client(
|
||||
self._container_name
|
||||
)
|
||||
@ -157,13 +174,46 @@ class CosmosDBVectoreStore(BaseVectorStore):
|
||||
msg = "Container client is not initialized."
|
||||
raise ValueError(msg)
|
||||
|
||||
query = f"SELECT TOP {k} c.id, c.text, c.vector, c.attributes, VectorDistance(c.vector, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.vector, @embedding)" # noqa: S608
|
||||
query_params = [{"name": "@embedding", "value": query_embedding}]
|
||||
items = self._container_client.query_items(
|
||||
query=query,
|
||||
parameters=query_params,
|
||||
enable_cross_partition_query=True,
|
||||
)
|
||||
try:
|
||||
query = f"SELECT TOP {k} c.id, c.text, c.vector, c.attributes, VectorDistance(c.vector, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.vector, @embedding)" # noqa: S608
|
||||
query_params = [{"name": "@embedding", "value": query_embedding}]
|
||||
items = list(
|
||||
self._container_client.query_items(
|
||||
query=query,
|
||||
parameters=query_params,
|
||||
enable_cross_partition_query=True,
|
||||
)
|
||||
)
|
||||
except (CosmosHttpResponseError, ValueError):
|
||||
# Currently, the CosmosDB emulator does not support the VectorDistance function.
|
||||
# For emulator or test environments - fetch all items and calculate distance locally
|
||||
query = "SELECT c.id, c.text, c.vector, c.attributes FROM c"
|
||||
items = list(
|
||||
self._container_client.query_items(
|
||||
query=query,
|
||||
enable_cross_partition_query=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Calculate cosine similarity locally (1 - cosine distance)
|
||||
from numpy import dot
|
||||
from numpy.linalg import norm
|
||||
|
||||
def cosine_similarity(a, b):
|
||||
if norm(a) * norm(b) == 0:
|
||||
return 0.0
|
||||
return dot(a, b) / (norm(a) * norm(b))
|
||||
|
||||
# Calculate scores for all items
|
||||
for item in items:
|
||||
item_vector = item.get("vector", [])
|
||||
similarity = cosine_similarity(query_embedding, item_vector)
|
||||
item["SimilarityScore"] = similarity
|
||||
|
||||
# Sort by similarity score (higher is better) and take top k
|
||||
items = sorted(
|
||||
items, key=lambda x: x.get("SimilarityScore", 0.0), reverse=True
|
||||
)[:k]
|
||||
|
||||
return [
|
||||
VectorStoreSearchResult(
|
||||
@ -214,3 +264,8 @@ class CosmosDBVectoreStore(BaseVectorStore):
|
||||
text=item.get("text", ""),
|
||||
attributes=(json.loads(item.get("attributes", "{}"))),
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear the vector store."""
|
||||
self._delete_container()
|
||||
self._delete_database()
|
||||
|
||||
@ -8,7 +8,7 @@ from typing import ClassVar
|
||||
|
||||
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
|
||||
from graphrag.vector_stores.base import BaseVectorStore
|
||||
from graphrag.vector_stores.cosmosdb import CosmosDBVectoreStore
|
||||
from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore
|
||||
from graphrag.vector_stores.lancedb import LanceDBVectorStore
|
||||
|
||||
|
||||
@ -44,7 +44,7 @@ class VectorStoreFactory:
|
||||
case VectorStoreType.AzureAISearch:
|
||||
return AzureAISearchVectorStore(**kwargs)
|
||||
case VectorStoreType.CosmosDB:
|
||||
return CosmosDBVectoreStore(**kwargs)
|
||||
return CosmosDBVectorStore(**kwargs)
|
||||
case _:
|
||||
if vector_store_type in cls.vector_store_types:
|
||||
return cls.vector_store_types[vector_store_type](**kwargs)
|
||||
|
||||
4
tests/integration/vector_stores/__init__.py
Normal file
4
tests/integration/vector_stores/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Integration tests for vector store implementations."""
|
||||
146
tests/integration/vector_stores/test_azure_ai_search.py
Normal file
146
tests/integration/vector_stores/test_azure_ai_search.py
Normal file
@ -0,0 +1,146 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Integration tests for Azure AI Search vector store implementation."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
|
||||
from graphrag.vector_stores.base import VectorStoreDocument
|
||||
|
||||
# Service endpoint and API key used by the tests below. Each value is read from
# the environment variable of the same name and falls back to a placeholder.
TEST_AZURE_AI_SEARCH_URL = os.getenv(
    "TEST_AZURE_AI_SEARCH_URL",
    "https://test-url.search.windows.net",
)
TEST_AZURE_AI_SEARCH_KEY = os.getenv("TEST_AZURE_AI_SEARCH_KEY", "test_api_key")
|
||||
|
||||
|
||||
class TestAzureAISearchVectorStore:
    """Tests for AzureAISearchVectorStore using mocked Azure SDK clients.

    The tests are plain synchronous functions: none of the calls they make is
    awaited, so they do not need an asyncio pytest plugin in order to run.
    """

    @pytest.fixture
    def mock_search_client(self):
        """Patch ``SearchClient`` in the vector store module and yield its mock instance."""
        with patch(
            "graphrag.vector_stores.azure_ai_search.SearchClient"
        ) as mock_client:
            yield mock_client.return_value

    @pytest.fixture
    def mock_index_client(self):
        """Patch ``SearchIndexClient`` in the vector store module and yield its mock instance."""
        with patch(
            "graphrag.vector_stores.azure_ai_search.SearchIndexClient"
        ) as mock_client:
            yield mock_client.return_value

    @pytest.fixture
    def vector_store(self, mock_search_client, mock_index_client):
        """Return a vector store wired to the mocked clients and connected with test settings."""
        vector_store = AzureAISearchVectorStore(collection_name="test_vectors")

        # Attach the mocks before connecting.
        vector_store.db_connection = mock_search_client
        vector_store.index_client = mock_index_client

        vector_store.connect(
            url=TEST_AZURE_AI_SEARCH_URL,
            api_key=TEST_AZURE_AI_SEARCH_KEY,
            vector_size=5,
        )
        return vector_store

    @pytest.fixture
    def sample_documents(self):
        """Return two five-dimensional sample documents."""
        return [
            VectorStoreDocument(
                id="doc1",
                text="This is document 1",
                vector=[0.1, 0.2, 0.3, 0.4, 0.5],
                attributes={"title": "Doc 1", "category": "test"},
            ),
            VectorStoreDocument(
                id="doc2",
                text="This is document 2",
                vector=[0.2, 0.3, 0.4, 0.5, 0.6],
                attributes={"title": "Doc 2", "category": "test"},
            ),
        ]

    def test_vector_store_operations(
        self, vector_store, sample_documents, mock_search_client, mock_index_client
    ):
        """Test basic vector store operations with Azure AI Search."""
        # Set up the mock responses the vector store will see.
        mock_index_client.list_index_names.return_value = []
        mock_index_client.create_or_update_index = MagicMock()
        mock_search_client.upload_documents = MagicMock()

        search_results = [
            {
                "id": "doc1",
                "text": "This is document 1",
                "vector": [0.1, 0.2, 0.3, 0.4, 0.5],
                "attributes": '{"title": "Doc 1", "category": "test"}',
                "@search.score": 0.9,
            },
            {
                "id": "doc2",
                "text": "This is document 2",
                "vector": [0.2, 0.3, 0.4, 0.5, 0.6],
                "attributes": '{"title": "Doc 2", "category": "test"}',
                "@search.score": 0.8,
            },
        ]
        mock_search_client.search.return_value = search_results

        mock_search_client.get_document.return_value = {
            "id": "doc1",
            "text": "This is document 1",
            "vector": [0.1, 0.2, 0.3, 0.4, 0.5],
            "attributes": '{"title": "Doc 1", "category": "test"}',
        }

        # Loading documents must create the index and upload the documents.
        vector_store.load_documents(sample_documents)
        assert mock_index_client.create_or_update_index.called
        assert mock_search_client.upload_documents.called

        filter_query = vector_store.filter_by_id(["doc1", "doc2"])
        assert filter_query == "search.in(id, 'doc1,doc2', ',')"

        vector_results = vector_store.similarity_search_by_vector(
            [0.1, 0.2, 0.3, 0.4, 0.5], k=2
        )
        assert len(vector_results) == 2
        assert vector_results[0].document.id == "doc1"
        assert vector_results[0].score == 0.9

        # Define a simple text embedder function for testing
        def mock_embedder(text: str) -> list[float]:
            return [0.1, 0.2, 0.3, 0.4, 0.5]

        text_results = vector_store.similarity_search_by_text(
            "test query", mock_embedder, k=2
        )
        assert len(text_results) == 2

        doc = vector_store.search_by_id("doc1")
        assert doc.id == "doc1"
        assert doc.text == "This is document 1"
        assert doc.attributes["title"] == "Doc 1"

    def test_empty_embedding(self, vector_store, mock_search_client):
        """Test similarity search by text with empty embedding."""

        # An embedder that returns None must yield no results and must not
        # trigger a search call on the client.
        def none_embedder(text: str) -> None:
            return None

        results = vector_store.similarity_search_by_text(
            "test query", none_embedder, k=1
        )
        assert not mock_search_client.search.called
        assert len(results) == 0
|
||||
104
tests/integration/vector_stores/test_cosmosdb.py
Normal file
104
tests/integration/vector_stores/test_cosmosdb.py
Normal file
@ -0,0 +1,104 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Integration tests for CosmosDB vector store implementation."""
|
||||
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from graphrag.vector_stores.base import VectorStoreDocument
|
||||
from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore
|
||||
|
||||
# Connection string for a CosmosDB emulator listening on 127.0.0.1:8081.
WELL_KNOWN_COSMOS_CONNECTION_STRING = (
    "AccountEndpoint=https://127.0.0.1:8081/;"
    # cspell:disable-next-line well-known-key
    "AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="
)

# The CosmosDB emulator is only available on Windows runners at this time, so
# the whole module is skipped on every other platform.
if not sys.platform.startswith("win"):
    pytest.skip(
        "encountered windows-only tests -- will skip for now",
        allow_module_level=True,
    )
|
||||
|
||||
|
||||
def test_vector_store_operations():
    """Load two documents into CosmosDB, then fetch by id and run similarity searches."""

    # Fixed-output embedder so the text search uses a known query vector.
    def mock_embedder(text: str) -> list[float]:
        return [0.1, 0.2, 0.3, 0.4, 0.5]

    store = CosmosDBVectorStore(collection_name="testvector")

    try:
        store.connect(
            connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
            database_name="test_db",
        )

        first_doc = VectorStoreDocument(
            id="doc1",
            text="This is document 1",
            vector=[0.1, 0.2, 0.3, 0.4, 0.5],
            attributes={"title": "Doc 1", "category": "test"},
        )
        second_doc = VectorStoreDocument(
            id="doc2",
            text="This is document 2",
            vector=[0.2, 0.3, 0.4, 0.5, 0.6],
            attributes={"title": "Doc 2", "category": "test"},
        )
        store.load_documents([first_doc, second_doc])

        store.filter_by_id(["doc1"])

        # The stored document must round-trip; vectors are compared with a
        # tolerance rather than for exact equality.
        fetched = store.search_by_id("doc1")
        assert fetched.id == "doc1"
        assert fetched.text == "This is document 1"
        assert fetched.vector is not None
        assert np.allclose(fetched.vector, [0.1, 0.2, 0.3, 0.4, 0.5])
        assert fetched.attributes["title"] == "Doc 1"

        by_vector = store.similarity_search_by_vector([0.1, 0.2, 0.3, 0.4, 0.5], k=2)
        assert len(by_vector) > 0

        by_text = store.similarity_search_by_text("test query", mock_embedder, k=2)
        assert len(by_text) > 0
    finally:
        # Always remove what the test created.
        store.clear()
|
||||
|
||||
|
||||
def test_clear():
    """Test that clear() removes the database backing the vector store.

    If the test fails before it reaches its own ``clear()`` call, the
    ``finally`` block clears the store so the database is not left behind in
    the emulator.
    """
    vector_store = CosmosDBVectorStore(
        collection_name="testclear",
    )
    # Connect outside the try block: if this raises, the store never finished
    # connecting and the cleanup below is not attempted.
    vector_store.connect(
        connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
        database_name="testclear",
    )

    needs_cleanup = True
    try:
        doc = VectorStoreDocument(
            id="test",
            text="Test document",
            vector=[0.1, 0.2, 0.3, 0.4, 0.5],
            attributes={"title": "Test Doc"},
        )

        vector_store.load_documents([doc])
        result = vector_store.search_by_id("test")
        assert result.id == "test"

        # Clear and verify the database is removed. The flag is lowered first
        # so a failing clear() is not retried in the finally block.
        needs_cleanup = False
        vector_store.clear()
        assert vector_store._database_exists() is False  # noqa: SLF001
    finally:
        if needs_cleanup:
            vector_store.clear()
|
||||
172
tests/integration/vector_stores/test_lancedb.py
Normal file
172
tests/integration/vector_stores/test_lancedb.py
Normal file
@ -0,0 +1,172 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Integration tests for LanceDB vector store implementation."""
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
|
||||
from graphrag.vector_stores.base import VectorStoreDocument
|
||||
from graphrag.vector_stores.lancedb import LanceDBVectorStore
|
||||
|
||||
|
||||
def test_vector_store_operations():
    """Run load, lookup, filter, similarity search and append against LanceDB."""

    # Fixed-output embedder so the text search uses a known query vector.
    def mock_embedder(text: str) -> list[float]:
        return [0.1, 0.2, 0.3, 0.4, 0.5]

    # The database lives in a throwaway directory removed in the finally block.
    db_dir = tempfile.mkdtemp()
    try:
        store = LanceDBVectorStore(collection_name="test_collection")
        store.connect(db_uri=db_dir)

        doc_one = VectorStoreDocument(
            id="1",
            text="This is document 1",
            vector=[0.1, 0.2, 0.3, 0.4, 0.5],
            attributes={"title": "Doc 1", "category": "test"},
        )
        doc_two = VectorStoreDocument(
            id="2",
            text="This is document 2",
            vector=[0.2, 0.3, 0.4, 0.5, 0.6],
            attributes={"title": "Doc 2", "category": "test"},
        )
        doc_three = VectorStoreDocument(
            id="3",
            text="This is document 3",
            vector=[0.3, 0.4, 0.5, 0.6, 0.7],
            attributes={"title": "Doc 3", "category": "test"},
        )

        # Initial load: only the first two documents.
        store.load_documents([doc_one, doc_two])

        table_names = store.db_connection.table_names()
        assert store.collection_name in table_names

        fetched = store.search_by_id("1")
        assert fetched.id == "1"
        assert fetched.text == "This is document 1"

        # Vectors are compared with a tolerance rather than for exact equality.
        assert fetched.vector is not None
        assert np.allclose(fetched.vector, [0.1, 0.2, 0.3, 0.4, 0.5])
        assert fetched.attributes["title"] == "Doc 1"

        id_filter = store.filter_by_id(["1"])
        assert id_filter == "id in ('1')"

        by_vector = store.similarity_search_by_vector([0.1, 0.2, 0.3, 0.4, 0.5], k=2)
        assert len(by_vector) in (1, 2)
        assert isinstance(by_vector[0].score, float)

        # Append mode: add the third document without overwriting.
        store.load_documents([doc_three], overwrite=False)
        appended = store.search_by_id("3")
        assert appended.id == "3"
        assert appended.text == "This is document 3"

        by_text = store.similarity_search_by_text("test query", mock_embedder, k=2)
        assert len(by_text) in (1, 2)
        assert isinstance(by_text[0].score, float)

        # An unknown id is expected to come back with that id and no text or vector.
        missing = store.search_by_id("nonexistent")
        assert missing.id == "nonexistent"
        assert missing.text is None
        assert missing.vector is None
    finally:
        shutil.rmtree(db_dir)
|
||||
|
||||
|
||||
def test_empty_collection():
    """Check that an emptied LanceDB collection still exists and accepts new documents."""
    # The database lives in a throwaway directory removed in the finally block.
    db_dir = tempfile.mkdtemp()
    try:
        store = LanceDBVectorStore(collection_name="empty_collection")
        store.connect(db_uri=db_dir)

        # Load one placeholder document, then delete it again so the
        # collection is left with no rows.
        placeholder = VectorStoreDocument(
            id="tmp",
            text="Temporary document to create schema",
            vector=[0.1, 0.2, 0.3, 0.4, 0.5],
            attributes={"title": "Tmp"},
        )
        store.load_documents([placeholder])
        table = store.db_connection.open_table(store.collection_name)
        table.delete("id = 'tmp'")

        # The collection itself must still be listed.
        table_names = store.db_connection.table_names()
        assert store.collection_name in table_names

        # Appending to the emptied collection must work.
        new_doc = VectorStoreDocument(
            id="1",
            text="This is document 1",
            vector=[0.1, 0.2, 0.3, 0.4, 0.5],
            attributes={"title": "Doc 1"},
        )
        store.load_documents([new_doc], overwrite=False)

        fetched = store.search_by_id("1")
        assert fetched.id == "1"
        assert fetched.text == "This is document 1"
    finally:
        # Remove the temporary database directory.
        shutil.rmtree(db_dir)
|
||||
|
||||
|
||||
def test_filter_search():
    """Check that a similarity search after filter_by_id only returns the filtered ids."""
    # The database lives in a throwaway directory removed in the finally block.
    db_dir = tempfile.mkdtemp()
    try:
        store = LanceDBVectorStore(collection_name="filter_collection")
        store.connect(db_uri=db_dir)

        # Two documents in the "animals" category and one in "vehicles".
        cats = VectorStoreDocument(
            id="1",
            text="Document about cats",
            vector=[0.1, 0.2, 0.3, 0.4, 0.5],
            attributes={"category": "animals"},
        )
        dogs = VectorStoreDocument(
            id="2",
            text="Document about dogs",
            vector=[0.2, 0.3, 0.4, 0.5, 0.6],
            attributes={"category": "animals"},
        )
        cars = VectorStoreDocument(
            id="3",
            text="Document about cars",
            vector=[0.3, 0.4, 0.5, 0.6, 0.7],
            attributes={"category": "vehicles"},
        )
        store.load_documents([cats, dogs, cars])

        # Restrict to the two animal documents, then ask for up to three hits.
        store.filter_by_id(["1", "2"])
        hits = store.similarity_search_by_vector([0.1, 0.2, 0.3, 0.4, 0.5], k=3)

        # At most the two filtered documents may come back.
        assert len(hits) <= 2
        returned_ids = [hit.document.id for hit in hits]
        assert "3" not in returned_ids
        assert set(returned_ids) <= {"1", "2"}
    finally:
        shutil.rmtree(db_dir)
|
||||
Loading…
Reference in New Issue
Block a user