Custom vector store schema implementation (#2062)
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled

* progress on vector customization

* fix for lancedb vectors

* cosmosdb implementation

* uv run poe format

* clean test for vector store

* semversioner update

* test_factory.py integration test fixes

* fixes for cosmosdb test

* integration test fix for lancedb

* uv fix for format

* test fixes

* fixes for tests

* fix cosmosdb bug

* print statement

* test

* test

* fix cosmosdb bug

* test validation

* validation cosmosdb

* validate cosmosdb

* fix cosmosdb

* fix small feedback from PR

---------

Co-authored-by: Gaudy Blanco <gaudy-microsoft@MacBook-Pro-m4-Gaudy-For-Work.local>
This commit is contained in:
gaudyb 2025-09-19 11:11:34 -06:00 committed by GitHub
parent 075cadd59a
commit 82cd3b7df2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 774 additions and 268 deletions

View File

@ -0,0 +1,4 @@
{
"type": "minor",
"description": "add customization to vector store"
}

55
.vscode/launch.json vendored
View File

@ -6,21 +6,24 @@
"name": "Indexer",
"type": "debugpy",
"request": "launch",
"module": "uv",
"module": "graphrag",
"args": [
"poe", "index",
"--root", "<path_to_ragtest_root_demo>"
"index",
"--root",
"<path_to_index_folder>"
],
"console": "integratedTerminal"
},
{
"name": "Query",
"type": "debugpy",
"request": "launch",
"module": "uv",
"module": "graphrag",
"args": [
"poe", "query",
"--root", "<path_to_ragtest_root_demo>",
"--method", "global",
"query",
"--root",
"<path_to_index_folder>",
"--method", "basic",
"--query", "What are the top themes in this story",
]
},
@ -34,6 +37,42 @@
"--config",
"<path_to_ragtest_root_demo>/settings.yaml",
]
}
},
{
"name": "Debug Integration Pytest",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"args": [
"./tests/integration/vector_stores",
"-k", "test_azure_ai_search"
],
"console": "integratedTerminal",
"justMyCode": false
},
{
"name": "Debug Verbs Pytest",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"args": [
"./tests/verbs",
"-k", "test_generate_text_embeddings"
],
"console": "integratedTerminal",
"justMyCode": false
},
{
"name": "Debug Smoke Pytest",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"args": [
"./tests/smoke",
"-k", "test_fixtures"
],
"console": "integratedTerminal",
"justMyCode": false
},
]
}

View File

@ -394,6 +394,7 @@ class VectorStoreDefaults:
api_key: None = None
audience: None = None
database_name: None = None
schema: None = None
@dataclass

View File

@ -29,14 +29,14 @@ default_embeddings: list[str] = [
]
def create_collection_name(
def create_index_name(
container_name: str, embedding_name: str, validate: bool = True
) -> str:
"""
Create a collection name for the embedding store.
Create a index name for the embedding store.
Within any given vector store, we can have multiple sets of embeddings organized into projects.
The `container` param is used for this partitioning, and is added as a prefix to the collection name for differentiation.
The `container` param is used for this partitioning, and is added as a prefix to the index name for differentiation.
The embedding name is fixed, with the available list defined in graphrag.index.config.embeddings

View File

@ -6,7 +6,9 @@
from pydantic import BaseModel, Field, model_validator
from graphrag.config.defaults import vector_store_defaults
from graphrag.config.embeddings import all_embeddings
from graphrag.config.enums import VectorStoreType
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
class VectorStoreConfig(BaseModel):
@ -85,9 +87,25 @@ class VectorStoreConfig(BaseModel):
default=vector_store_defaults.overwrite,
)
embeddings_schema: dict[str, VectorStoreSchemaConfig] = {}
def _validate_embeddings_schema(self) -> None:
"""Validate the embeddings schema."""
for name in self.embeddings_schema:
if name not in all_embeddings:
msg = f"vector_store.embeddings_schema contains an invalid embedding schema name: {name}. Please update your settings.yaml and select the correct embedding schema names."
raise ValueError(msg)
if self.type == VectorStoreType.CosmosDB:
for id_field in self.embeddings_schema:
if id_field != "id":
msg = "When using CosmosDB, the id_field in embeddings_schema must be 'id'. Please update your settings.yaml and set the id_field to 'id'."
raise ValueError(msg)
@model_validator(mode="after")
def _validate_model(self):
"""Validate the model."""
self._validate_db_uri()
self._validate_url()
self._validate_embeddings_schema()
return self

View File

@ -0,0 +1,66 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Parameterization settings for the default configuration."""
import re
from pydantic import BaseModel, Field, model_validator
DEFAULT_VECTOR_SIZE: int = 1536
VALID_IDENTIFIER_REGEX = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
def is_valid_field_name(field: str) -> bool:
"""Check if a field name is valid for CosmosDB."""
return bool(VALID_IDENTIFIER_REGEX.match(field))
class VectorStoreSchemaConfig(BaseModel):
"""The default configuration section for Vector Store Schema."""
id_field: str = Field(
description="The ID field to use.",
default="id",
)
vector_field: str = Field(
description="The vector field to use.",
default="vector",
)
text_field: str = Field(
description="The text field to use.",
default="text",
)
attributes_field: str = Field(
description="The attributes field to use.",
default="attributes",
)
vector_size: int = Field(
description="The vector size to use.",
default=DEFAULT_VECTOR_SIZE,
)
index_name: str | None = Field(description="The index name to use.", default=None)
def _validate_schema(self) -> None:
"""Validate the schema."""
for field in [
self.id_field,
self.vector_field,
self.text_field,
self.attributes_field,
]:
if not is_valid_field_name(field):
msg = f"Unsafe or invalid field name: {field}"
raise ValueError(msg)
@model_validator(mode="after")
def _validate_model(self):
"""Validate the model."""
self._validate_schema()
return self

View File

@ -12,7 +12,8 @@ import pandas as pd
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.embeddings import create_collection_name
from graphrag.config.embeddings import create_index_name
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy
from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument
from graphrag.vector_stores.factory import VectorStoreFactory
@ -49,9 +50,9 @@ async def embed_text(
vector_store_config = strategy.get("vector_store")
if vector_store_config:
collection_name = _get_collection_name(vector_store_config, embedding_name)
index_name = _get_index_name(vector_store_config, embedding_name)
vector_store: BaseVectorStore = _create_vector_store(
vector_store_config, collection_name
vector_store_config, index_name, embedding_name
)
vector_store_workflow_config = vector_store_config.get(
embedding_name, vector_store_config
@ -183,27 +184,46 @@ async def _text_embed_with_vector_store(
def _create_vector_store(
vector_store_config: dict, collection_name: str
vector_store_config: dict, index_name: str, embedding_name: str | None = None
) -> BaseVectorStore:
vector_store_type: str = str(vector_store_config.get("type"))
if collection_name:
vector_store_config.update({"collection_name": collection_name})
embeddings_schema: dict[str, VectorStoreSchemaConfig] = vector_store_config.get(
"embeddings_schema", {}
)
single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
if (
embeddings_schema is not None
and embedding_name is not None
and embedding_name in embeddings_schema
):
raw_config = embeddings_schema[embedding_name]
if isinstance(raw_config, dict):
single_embedding_config = VectorStoreSchemaConfig(**raw_config)
else:
single_embedding_config = raw_config
if single_embedding_config.index_name is None:
single_embedding_config.index_name = index_name
vector_store = VectorStoreFactory().create_vector_store(
vector_store_type, kwargs=vector_store_config
vector_store_schema_config=single_embedding_config,
vector_store_type=vector_store_type,
kwargs=vector_store_config,
)
vector_store.connect(**vector_store_config)
return vector_store
def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str:
def _get_index_name(vector_store_config: dict, embedding_name: str) -> str:
container_name = vector_store_config.get("container_name", "default")
collection_name = create_collection_name(container_name, embedding_name)
index_name = create_index_name(container_name, embedding_name)
msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {collection_name}"
msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {index_name}"
logger.info(msg)
return collection_name
return index_name
def load_strategy(strategy: TextEmbedStrategyType) -> TextEmbeddingStrategy:

View File

@ -8,9 +8,10 @@ from typing import Any
from graphrag.cache.factory import CacheFactory
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.config.embeddings import create_collection_name
from graphrag.config.embeddings import create_index_name
from graphrag.config.models.cache_config import CacheConfig
from graphrag.config.models.storage_config import StorageConfig
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.data_model.types import TextEmbedder
from graphrag.storage.factory import StorageFactory
from graphrag.storage.pipeline_storage import PipelineStorage
@ -103,12 +104,33 @@ def get_embedding_store(
index_names = []
for index, store in config_args.items():
vector_store_type = store["type"]
collection_name = create_collection_name(
index_name = create_index_name(
store.get("container_name", "default"), embedding_name
)
embeddings_schema: dict[str, VectorStoreSchemaConfig] = store.get(
"embeddings_schema", {}
)
single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
if (
embeddings_schema is not None
and embedding_name is not None
and embedding_name in embeddings_schema
):
raw_config = embeddings_schema[embedding_name]
if isinstance(raw_config, dict):
single_embedding_config = VectorStoreSchemaConfig(**raw_config)
else:
single_embedding_config = raw_config
if single_embedding_config.index_name is None:
single_embedding_config.index_name = index_name
embedding_store = VectorStoreFactory().create_vector_store(
vector_store_type=vector_store_type,
kwargs={**store, "collection_name": collection_name},
vector_store_schema_config=single_embedding_config,
kwargs={**store},
)
embedding_store.connect(**store)
# If there is only a single index, return the embedding store directly

View File

@ -24,9 +24,9 @@ from azure.search.documents.indexes.models import (
)
from azure.search.documents.models import VectorizedQuery
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.data_model.types import TextEmbedder
from graphrag.vector_stores.base import (
DEFAULT_VECTOR_SIZE,
BaseVectorStore,
VectorStoreDocument,
VectorStoreSearchResult,
@ -38,15 +38,18 @@ class AzureAISearchVectorStore(BaseVectorStore):
index_client: SearchIndexClient
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
def __init__(
self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any
) -> None:
super().__init__(
vector_store_schema_config=vector_store_schema_config, **kwargs
)
def connect(self, **kwargs: Any) -> Any:
"""Connect to AI search vector storage."""
url = kwargs["url"]
api_key = kwargs.get("api_key")
audience = kwargs.get("audience")
self.vector_size = kwargs.get("vector_size", DEFAULT_VECTOR_SIZE)
self.vector_search_profile_name = kwargs.get(
"vector_search_profile_name", "vectorSearchProfile"
@ -56,7 +59,7 @@ class AzureAISearchVectorStore(BaseVectorStore):
audience_arg = {"audience": audience} if audience and not api_key else {}
self.db_connection = SearchClient(
endpoint=url,
index_name=self.collection_name,
index_name=self.index_name if self.index_name else "",
credential=(
AzureKeyCredential(api_key) if api_key else DefaultAzureCredential()
),
@ -78,8 +81,11 @@ class AzureAISearchVectorStore(BaseVectorStore):
) -> None:
"""Load documents into an Azure AI Search index."""
if overwrite:
if self.collection_name in self.index_client.list_index_names():
self.index_client.delete_index(self.collection_name)
if (
self.index_name is not None
and self.index_name in self.index_client.list_index_names()
):
self.index_client.delete_index(self.index_name)
# Configure vector search profile
vector_search = VectorSearch(
@ -100,24 +106,26 @@ class AzureAISearchVectorStore(BaseVectorStore):
)
# Configure the index
index = SearchIndex(
name=self.collection_name,
name=self.index_name if self.index_name else "",
fields=[
SimpleField(
name="id",
name=self.id_field,
type=SearchFieldDataType.String,
key=True,
),
SearchField(
name="vector",
name=self.vector_field,
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
searchable=True,
hidden=False, # DRIFT needs to return the vector for client-side similarity
vector_search_dimensions=self.vector_size,
vector_search_profile_name=self.vector_search_profile_name,
),
SearchableField(name="text", type=SearchFieldDataType.String),
SearchableField(
name=self.text_field, type=SearchFieldDataType.String
),
SimpleField(
name="attributes",
name=self.attributes_field,
type=SearchFieldDataType.String,
),
],
@ -129,10 +137,10 @@ class AzureAISearchVectorStore(BaseVectorStore):
batch = [
{
"id": doc.id,
"vector": doc.vector,
"text": doc.text,
"attributes": json.dumps(doc.attributes),
self.id_field: doc.id,
self.vector_field: doc.vector,
self.text_field: doc.text,
self.attributes_field: json.dumps(doc.attributes),
}
for doc in documents
if doc.vector is not None
@ -151,7 +159,7 @@ class AzureAISearchVectorStore(BaseVectorStore):
# More info about odata filtering here: https://learn.microsoft.com/en-us/azure/search/search-query-odata-search-in-function
# search.in is faster that joined and/or conditions
id_filter = ",".join([f"{id!s}" for id in include_ids])
self.query_filter = f"search.in(id, '{id_filter}', ',')"
self.query_filter = f"search.in({self.id_field}, '{id_filter}', ',')"
# Returning to keep consistency with other methods, but not needed
# TODO: Refactor on a future PR
@ -162,7 +170,7 @@ class AzureAISearchVectorStore(BaseVectorStore):
) -> list[VectorStoreSearchResult]:
"""Perform a vector-based similarity search."""
vectorized_query = VectorizedQuery(
vector=query_embedding, k_nearest_neighbors=k, fields="vector"
vector=query_embedding, k_nearest_neighbors=k, fields=self.vector_field
)
response = self.db_connection.search(
@ -172,10 +180,10 @@ class AzureAISearchVectorStore(BaseVectorStore):
return [
VectorStoreSearchResult(
document=VectorStoreDocument(
id=doc.get("id", ""),
text=doc.get("text", ""),
vector=doc.get("vector", []),
attributes=(json.loads(doc.get("attributes", "{}"))),
id=doc.get(self.id_field, ""),
text=doc.get(self.text_field, ""),
vector=doc.get(self.vector_field, []),
attributes=(json.loads(doc.get(self.attributes_field, "{}"))),
),
# Cosine similarity between 0.333 and 1.000
# https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking#scores-in-a-hybrid-search-results
@ -199,8 +207,8 @@ class AzureAISearchVectorStore(BaseVectorStore):
"""Search for a document by id."""
response = self.db_connection.get_document(id)
return VectorStoreDocument(
id=response.get("id", ""),
text=response.get("text", ""),
vector=response.get("vector", []),
attributes=(json.loads(response.get("attributes", "{}"))),
id=response.get(self.id_field, ""),
text=response.get(self.text_field, ""),
vector=response.get(self.vector_field, []),
attributes=(json.loads(response.get(self.attributes_field, "{}"))),
)

View File

@ -7,10 +7,9 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.data_model.types import TextEmbedder
DEFAULT_VECTOR_SIZE: int = 1536
@dataclass
class VectorStoreDocument:
@ -42,18 +41,24 @@ class BaseVectorStore(ABC):
def __init__(
self,
collection_name: str,
vector_store_schema_config: VectorStoreSchemaConfig,
db_connection: Any | None = None,
document_collection: Any | None = None,
query_filter: Any | None = None,
**kwargs: Any,
):
self.collection_name = collection_name
self.db_connection = db_connection
self.document_collection = document_collection
self.query_filter = query_filter
self.kwargs = kwargs
self.index_name = vector_store_schema_config.index_name
self.id_field = vector_store_schema_config.id_field
self.text_field = vector_store_schema_config.text_field
self.vector_field = vector_store_schema_config.vector_field
self.attributes_field = vector_store_schema_config.attributes_field
self.vector_size = vector_store_schema_config.vector_size
@abstractmethod
def connect(self, **kwargs: Any) -> None:
"""Connect to vector storage."""

View File

@ -11,9 +11,9 @@ from azure.cosmos.exceptions import CosmosHttpResponseError
from azure.cosmos.partition_key import PartitionKey
from azure.identity import DefaultAzureCredential
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.data_model.types import TextEmbedder
from graphrag.vector_stores.base import (
DEFAULT_VECTOR_SIZE,
BaseVectorStore,
VectorStoreDocument,
VectorStoreSearchResult,
@ -27,8 +27,12 @@ class CosmosDBVectorStore(BaseVectorStore):
_database_client: DatabaseProxy
_container_client: ContainerProxy
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
def __init__(
self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any
) -> None:
super().__init__(
vector_store_schema_config=vector_store_schema_config, **kwargs
)
def connect(self, **kwargs: Any) -> Any:
"""Connect to CosmosDB vector storage."""
@ -49,13 +53,12 @@ class CosmosDBVectorStore(BaseVectorStore):
msg = "Database name must be provided."
raise ValueError(msg)
self._database_name = database_name
collection_name = self.collection_name
if collection_name is None:
msg = "Collection name is empty or not provided."
if self.index_name is None:
msg = "Index name is empty or not provided."
raise ValueError(msg)
self._container_name = collection_name
self._container_name = self.index_name
self.vector_size = kwargs.get("vector_size", DEFAULT_VECTOR_SIZE)
self.vector_size = self.vector_size
self._create_database()
self._create_container()
@ -80,13 +83,13 @@ class CosmosDBVectorStore(BaseVectorStore):
def _create_container(self) -> None:
"""Create the container if it doesn't exist."""
partition_key = PartitionKey(path="/id", kind="Hash")
partition_key = PartitionKey(path=f"/{self.id_field}", kind="Hash")
# Define the container vector policy
vector_embedding_policy = {
"vectorEmbeddings": [
{
"path": "/vector",
"path": f"/{self.vector_field}",
"dataType": "float32",
"distanceFunction": "cosine",
"dimensions": self.vector_size,
@ -99,13 +102,18 @@ class CosmosDBVectorStore(BaseVectorStore):
"indexingMode": "consistent",
"automatic": True,
"includedPaths": [{"path": "/*"}],
"excludedPaths": [{"path": "/_etag/?"}, {"path": "/vector/*"}],
"excludedPaths": [
{"path": "/_etag/?"},
{"path": f"/{self.vector_field}/*"},
],
}
# Currently, the CosmosDB emulator does not support the diskANN policy.
try:
# First try with the standard diskANN policy
indexing_policy["vectorIndexes"] = [{"path": "/vector", "type": "diskANN"}]
indexing_policy["vectorIndexes"] = [
{"path": f"/{self.vector_field}", "type": "diskANN"}
]
# Create the container and container client
self._database_client.create_container_if_not_exists(
@ -158,12 +166,16 @@ class CosmosDBVectorStore(BaseVectorStore):
# Upload documents to CosmosDB
for doc in documents:
if doc.vector is not None:
print("Document to store:") # noqa: T201
print(doc) # noqa: T201
doc_json = {
"id": doc.id,
"vector": doc.vector,
"text": doc.text,
"attributes": json.dumps(doc.attributes),
self.id_field: doc.id,
self.vector_field: doc.vector,
self.text_field: doc.text,
self.attributes_field: json.dumps(doc.attributes),
}
print("Storing document in CosmosDB:") # noqa: T201
print(doc_json) # noqa: T201
self._container_client.upsert_item(doc_json)
def similarity_search_by_vector(
@ -175,7 +187,7 @@ class CosmosDBVectorStore(BaseVectorStore):
raise ValueError(msg)
try:
query = f"SELECT TOP {k} c.id, c.text, c.vector, c.attributes, VectorDistance(c.vector, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.vector, @embedding)" # noqa: S608
query = f"SELECT TOP {k} c.{self.id_field}, c.{self.text_field}, c.{self.vector_field}, c.{self.attributes_field}, VectorDistance(c.{self.vector_field}, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.{self.vector_field}, @embedding)" # noqa: S608
query_params = [{"name": "@embedding", "value": query_embedding}]
items = list(
self._container_client.query_items(
@ -187,7 +199,7 @@ class CosmosDBVectorStore(BaseVectorStore):
except (CosmosHttpResponseError, ValueError):
# Currently, the CosmosDB emulator does not support the VectorDistance function.
# For emulator or test environments - fetch all items and calculate distance locally
query = "SELECT c.id, c.text, c.vector, c.attributes FROM c"
query = f"SELECT c.{self.id_field}, c.{self.text_field}, c.{self.vector_field}, c.{self.attributes_field} FROM c" # noqa: S608
items = list(
self._container_client.query_items(
query=query,
@ -206,7 +218,7 @@ class CosmosDBVectorStore(BaseVectorStore):
# Calculate scores for all items
for item in items:
item_vector = item.get("vector", [])
item_vector = item.get(self.vector_field, [])
similarity = cosine_similarity(query_embedding, item_vector)
item["SimilarityScore"] = similarity
@ -218,10 +230,10 @@ class CosmosDBVectorStore(BaseVectorStore):
return [
VectorStoreSearchResult(
document=VectorStoreDocument(
id=item.get("id", ""),
text=item.get("text", ""),
vector=item.get("vector", []),
attributes=(json.loads(item.get("attributes", "{}"))),
id=item.get(self.id_field, ""),
text=item.get(self.text_field, ""),
vector=item.get(self.vector_field, []),
attributes=(json.loads(item.get(self.attributes_field, "{}"))),
),
score=item.get("SimilarityScore", 0.0),
)
@ -248,7 +260,9 @@ class CosmosDBVectorStore(BaseVectorStore):
id_filter = ", ".join([f"'{id}'" for id in include_ids])
else:
id_filter = ", ".join([str(id) for id in include_ids])
self.query_filter = f"SELECT * FROM c WHERE c.id IN ({id_filter})" # noqa: S608
self.query_filter = (
f"SELECT * FROM c WHERE c.{self.id_field} IN ({id_filter})" # noqa: S608
)
return self.query_filter
def search_by_id(self, id: str) -> VectorStoreDocument:
@ -259,10 +273,10 @@ class CosmosDBVectorStore(BaseVectorStore):
item = self._container_client.read_item(item=id, partition_key=id)
return VectorStoreDocument(
id=item.get("id", ""),
vector=item.get("vector", []),
text=item.get("text", ""),
attributes=(json.loads(item.get("attributes", "{}"))),
id=item.get(self.id_field, ""),
vector=item.get(self.vector_field, []),
text=item.get(self.text_field, ""),
attributes=(json.loads(item.get(self.attributes_field, "{}"))),
)
def clear(self) -> None:

View File

@ -15,6 +15,9 @@ from graphrag.vector_stores.lancedb import LanceDBVectorStore
if TYPE_CHECKING:
from collections.abc import Callable
from graphrag.config.models.vector_store_schema_config import (
VectorStoreSchemaConfig,
)
from graphrag.vector_stores.base import BaseVectorStore
@ -47,7 +50,10 @@ class VectorStoreFactory:
@classmethod
def create_vector_store(
cls, vector_store_type: str, kwargs: dict
cls,
vector_store_type: str,
vector_store_schema_config: VectorStoreSchemaConfig,
kwargs: dict,
) -> BaseVectorStore:
"""Create a vector store object from the provided type.
@ -67,7 +73,9 @@ class VectorStoreFactory:
msg = f"Unknown vector store type: {vector_store_type}"
raise ValueError(msg)
return cls._registry[vector_store_type](**kwargs)
return cls._registry[vector_store_type](
vector_store_schema_config=vector_store_schema_config, **kwargs
)
@classmethod
def get_vector_store_types(cls) -> list[str]:

View File

@ -5,9 +5,9 @@
import json # noqa: I001
from typing import Any
import pyarrow as pa
import numpy as np
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.data_model.types import TextEmbedder
from graphrag.vector_stores.base import (
@ -21,60 +21,81 @@ import lancedb
class LanceDBVectorStore(BaseVectorStore):
"""LanceDB vector storage implementation."""
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
def __init__(
self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any
) -> None:
super().__init__(
vector_store_schema_config=vector_store_schema_config, **kwargs
)
def connect(self, **kwargs: Any) -> Any:
"""Connect to the vector storage."""
self.db_connection = lancedb.connect(kwargs["db_uri"])
if (
self.collection_name
and self.collection_name in self.db_connection.table_names()
):
self.document_collection = self.db_connection.open_table(
self.collection_name
)
if self.index_name and self.index_name in self.db_connection.table_names():
self.document_collection = self.db_connection.open_table(self.index_name)
def load_documents(
self, documents: list[VectorStoreDocument], overwrite: bool = True
) -> None:
"""Load documents into vector storage."""
data = [
{
"id": document.id,
"text": document.text,
"vector": document.vector,
"attributes": json.dumps(document.attributes),
}
for document in documents
if document.vector is not None
]
# Step 1: Prepare data columns manually
ids = []
texts = []
vectors = []
attributes = []
if len(data) == 0:
for document in documents:
self.vector_size = (
len(document.vector) if document.vector else self.vector_size
)
if document.vector is not None and len(document.vector) == self.vector_size:
ids.append(document.id)
texts.append(document.text)
vectors.append(np.array(document.vector, dtype=np.float32))
attributes.append(json.dumps(document.attributes))
# Step 2: Handle empty case
if len(ids) == 0:
data = None
else:
# Step 3: Flatten the vectors and build FixedSizeListArray manually
flat_vector = np.concatenate(vectors).astype(np.float32)
flat_array = pa.array(flat_vector, type=pa.float32())
vector_column = pa.FixedSizeListArray.from_arrays(
flat_array, self.vector_size
)
# Step 4: Create PyArrow table (let schema be inferred)
data = pa.table({
self.id_field: pa.array(ids, type=pa.string()),
self.text_field: pa.array(texts, type=pa.string()),
self.vector_field: vector_column,
self.attributes_field: pa.array(attributes, type=pa.string()),
})
schema = pa.schema([
pa.field("id", pa.string()),
pa.field("text", pa.string()),
pa.field("vector", pa.list_(pa.float64())),
pa.field("attributes", pa.string()),
])
# NOTE: If modifying the next section of code, ensure that the schema remains the same.
# The pyarrow format of the 'vector' field may change if the order of operations is changed
# and will break vector search.
if overwrite:
if data:
self.document_collection = self.db_connection.create_table(
self.collection_name, data=data, mode="overwrite"
self.index_name if self.index_name else "",
data=data,
mode="overwrite",
schema=data.schema,
)
else:
self.document_collection = self.db_connection.create_table(
self.collection_name, schema=schema, mode="overwrite"
self.index_name if self.index_name else "", mode="overwrite"
)
self.document_collection.create_index(
vector_column_name=self.vector_field, index_type="IVF_FLAT"
)
else:
# add data to existing table
self.document_collection = self.db_connection.open_table(
self.collection_name
self.index_name if self.index_name else ""
)
if data:
self.document_collection.add(data)
@ -86,30 +107,32 @@ class LanceDBVectorStore(BaseVectorStore):
else:
if isinstance(include_ids[0], str):
id_filter = ", ".join([f"'{id}'" for id in include_ids])
self.query_filter = f"id in ({id_filter})"
self.query_filter = f"{self.id_field} in ({id_filter})"
else:
self.query_filter = (
f"id in ({', '.join([str(id) for id in include_ids])})"
f"{self.id_field} in ({', '.join([str(id) for id in include_ids])})"
)
return self.query_filter
def similarity_search_by_vector(
self, query_embedding: list[float], k: int = 10, **kwargs: Any
self, query_embedding: list[float] | np.ndarray, k: int = 10, **kwargs: Any
) -> list[VectorStoreSearchResult]:
"""Perform a vector-based similarity search."""
if self.query_filter:
docs = (
self.document_collection.search(
query=query_embedding, vector_column_name="vector"
query=query_embedding, vector_column_name=self.vector_field
)
.where(self.query_filter, prefilter=True)
.limit(k)
.to_list()
)
else:
query_embedding = np.array(query_embedding, dtype=np.float32)
docs = (
self.document_collection.search(
query=query_embedding, vector_column_name="vector"
query=query_embedding, vector_column_name=self.vector_field
)
.limit(k)
.to_list()
@ -117,10 +140,10 @@ class LanceDBVectorStore(BaseVectorStore):
return [
VectorStoreSearchResult(
document=VectorStoreDocument(
id=doc["id"],
text=doc["text"],
vector=doc["vector"],
attributes=json.loads(doc["attributes"]),
id=doc[self.id_field],
text=doc[self.text_field],
vector=doc[self.vector_field],
attributes=json.loads(doc[self.attributes_field]),
),
score=1 - abs(float(doc["_distance"])),
)
@ -140,14 +163,14 @@ class LanceDBVectorStore(BaseVectorStore):
"""Search for a document by id."""
doc = (
self.document_collection.search()
.where(f"id == '{id}'", prefilter=True)
.where(f"{self.id_field} == '{id}'", prefilter=True)
.to_list()
)
if doc:
return VectorStoreDocument(
id=doc[0]["id"],
text=doc[0]["text"],
vector=doc[0]["vector"],
attributes=json.loads(doc[0]["attributes"]),
id=doc[0][self.id_field],
text=doc[0][self.text_field],
vector=doc[0][self.vector_field],
attributes=json.loads(doc[0][self.attributes_field]),
)
return VectorStoreDocument(id=id, text=None, vector=None)

View File

@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch
import pytest
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
from graphrag.vector_stores.base import VectorStoreDocument
@ -39,7 +40,35 @@ class TestAzureAISearchVectorStore:
@pytest.fixture
def vector_store(self, mock_search_client, mock_index_client):
"""Create an Azure AI Search vector store instance."""
vector_store = AzureAISearchVectorStore(collection_name="test_vectors")
vector_store = AzureAISearchVectorStore(
vector_store_schema_config=VectorStoreSchemaConfig(
index_name="test_vectors", vector_size=5
),
)
# Create the necessary mocks first
vector_store.db_connection = mock_search_client
vector_store.index_client = mock_index_client
vector_store.connect(
url=TEST_AZURE_AI_SEARCH_URL,
api_key=TEST_AZURE_AI_SEARCH_KEY,
)
return vector_store
@pytest.fixture
def vector_store_custom(self, mock_search_client, mock_index_client):
"""Create an Azure AI Search vector store instance."""
vector_store = AzureAISearchVectorStore(
vector_store_schema_config=VectorStoreSchemaConfig(
index_name="test_vectors",
id_field="id_custom",
text_field="text_custom",
attributes_field="attributes_custom",
vector_field="vector_custom",
vector_size=5,
),
)
# Create the necessary mocks first
vector_store.db_connection = mock_search_client
@ -48,7 +77,6 @@ class TestAzureAISearchVectorStore:
vector_store.connect(
url=TEST_AZURE_AI_SEARCH_URL,
api_key=TEST_AZURE_AI_SEARCH_KEY,
vector_size=5,
)
return vector_store
@ -144,3 +172,72 @@ class TestAzureAISearchVectorStore:
)
assert not mock_search_client.search.called
assert len(results) == 0
async def test_vector_store_customization(
self,
vector_store_custom,
sample_documents,
mock_search_client,
mock_index_client,
):
"""Test vector store customization with Azure AI Search."""
# Setup mock responses
mock_index_client.list_index_names.return_value = []
mock_index_client.create_or_update_index = MagicMock()
mock_search_client.upload_documents = MagicMock()
search_results = [
{
vector_store_custom.id_field: "doc1",
vector_store_custom.text_field: "This is document 1",
vector_store_custom.vector_field: [0.1, 0.2, 0.3, 0.4, 0.5],
vector_store_custom.attributes_field: '{"title": "Doc 1", "category": "test"}',
"@search.score": 0.9,
},
{
vector_store_custom.id_field: "doc2",
vector_store_custom.text_field: "This is document 2",
vector_store_custom.vector_field: [0.2, 0.3, 0.4, 0.5, 0.6],
vector_store_custom.attributes_field: '{"title": "Doc 2", "category": "test"}',
"@search.score": 0.8,
},
]
mock_search_client.search.return_value = search_results
mock_search_client.get_document.return_value = {
vector_store_custom.id_field: "doc1",
vector_store_custom.text_field: "This is document 1",
vector_store_custom.vector_field: [0.1, 0.2, 0.3, 0.4, 0.5],
vector_store_custom.attributes_field: '{"title": "Doc 1", "category": "test"}',
}
vector_store_custom.load_documents(sample_documents)
assert mock_index_client.create_or_update_index.called
assert mock_search_client.upload_documents.called
filter_query = vector_store_custom.filter_by_id(["doc1", "doc2"])
assert (
filter_query
== f"search.in({vector_store_custom.id_field}, 'doc1,doc2', ',')"
)
vector_results = vector_store_custom.similarity_search_by_vector(
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
)
assert len(vector_results) == 2
assert vector_results[0].document.id == "doc1"
assert vector_results[0].score == 0.9
# Define a simple text embedder function for testing
def mock_embedder(text: str) -> list[float]:
return [0.1, 0.2, 0.3, 0.4, 0.5]
text_results = vector_store_custom.similarity_search_by_text(
"test query", mock_embedder, k=2
)
assert len(text_results) == 2
doc = vector_store_custom.search_by_id("doc1")
assert doc.id == "doc1"
assert doc.text == "This is document 1"
assert doc.attributes["title"] == "Doc 1"

View File

@ -8,6 +8,7 @@ import sys
import numpy as np
import pytest
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.vector_stores.base import VectorStoreDocument
from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore
@ -24,7 +25,7 @@ if not sys.platform.startswith("win"):
def test_vector_store_operations():
"""Test basic vector store operations with CosmosDB."""
vector_store = CosmosDBVectorStore(
collection_name="testvector",
vector_store_schema_config=VectorStoreSchemaConfig(index_name="testvector"),
)
try:
@ -78,7 +79,7 @@ def test_vector_store_operations():
def test_clear():
"""Test clearing the vector store."""
vector_store = CosmosDBVectorStore(
collection_name="testclear",
vector_store_schema_config=VectorStoreSchemaConfig(index_name="testclear"),
)
try:
vector_store.connect(
@ -102,3 +103,64 @@ def test_clear():
assert vector_store._database_exists() is False # noqa: SLF001
finally:
pass
def test_vector_store_customization():
"""Test vector store customization with CosmosDB."""
vector_store = CosmosDBVectorStore(
vector_store_schema_config=VectorStoreSchemaConfig(
index_name="text-embeddings",
id_field="id",
text_field="text_custom",
vector_field="vector_custom",
attributes_field="attributes_custom",
vector_size=5,
),
)
try:
vector_store.connect(
connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
database_name="test_db",
)
docs = [
VectorStoreDocument(
id="doc1",
text="This is document 1",
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
attributes={"title": "Doc 1", "category": "test"},
),
VectorStoreDocument(
id="doc2",
text="This is document 2",
vector=[0.2, 0.3, 0.4, 0.5, 0.6],
attributes={"title": "Doc 2", "category": "test"},
),
]
vector_store.load_documents(docs)
vector_store.filter_by_id(["doc1"])
doc = vector_store.search_by_id("doc1")
assert doc.id == "doc1"
assert doc.text == "This is document 1"
assert doc.vector is not None
assert np.allclose(doc.vector, [0.1, 0.2, 0.3, 0.4, 0.5])
assert doc.attributes["title"] == "Doc 1"
# Define a simple text embedder function for testing
def mock_embedder(text: str) -> list[float]:
return [0.1, 0.2, 0.3, 0.4, 0.5] # Return fixed embedding
vector_results = vector_store.similarity_search_by_vector(
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
)
assert len(vector_results) > 0
text_results = vector_store.similarity_search_by_text(
"test query", mock_embedder, k=2
)
assert len(text_results) > 0
finally:
vector_store.clear()

View File

@ -8,6 +8,7 @@ These tests will test the VectorStoreFactory class and the creation of each vect
import pytest
from graphrag.config.enums import VectorStoreType
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
from graphrag.vector_stores.base import BaseVectorStore
from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore
@ -17,25 +18,31 @@ from graphrag.vector_stores.lancedb import LanceDBVectorStore
def test_create_lancedb_vector_store():
kwargs = {
"collection_name": "test_collection",
"db_uri": "/tmp/lancedb",
}
vector_store = VectorStoreFactory.create_vector_store(
VectorStoreType.LanceDB.value, kwargs
vector_store_type=VectorStoreType.LanceDB.value,
vector_store_schema_config=VectorStoreSchemaConfig(
index_name="test_collection"
),
kwargs=kwargs,
)
assert isinstance(vector_store, LanceDBVectorStore)
assert vector_store.collection_name == "test_collection"
assert vector_store.index_name == "test_collection"
@pytest.mark.skip(reason="Azure AI Search requires credentials and setup")
def test_create_azure_ai_search_vector_store():
kwargs = {
"collection_name": "test_collection",
"url": "https://test.search.windows.net",
"api_key": "test_key",
}
vector_store = VectorStoreFactory.create_vector_store(
VectorStoreType.AzureAISearch.value, kwargs
vector_store_type=VectorStoreType.AzureAISearch.value,
vector_store_schema_config=VectorStoreSchemaConfig(
index_name="test_collection"
),
kwargs=kwargs,
)
assert isinstance(vector_store, AzureAISearchVectorStore)
@ -43,13 +50,18 @@ def test_create_azure_ai_search_vector_store():
@pytest.mark.skip(reason="CosmosDB requires credentials and setup")
def test_create_cosmosdb_vector_store():
kwargs = {
"collection_name": "test_collection",
"connection_string": "AccountEndpoint=https://test.documents.azure.com:443/;AccountKey=test_key==",
"database_name": "test_db",
}
vector_store = VectorStoreFactory.create_vector_store(
VectorStoreType.CosmosDB.value, kwargs
vector_store_type=VectorStoreType.CosmosDB.value,
vector_store_schema_config=VectorStoreSchemaConfig(
index_name="test_collection"
),
kwargs=kwargs,
)
assert isinstance(vector_store, CosmosDBVectorStore)
@ -67,7 +79,12 @@ def test_register_and_create_custom_vector_store():
VectorStoreFactory.register(
"custom", lambda **kwargs: custom_vector_store_class(**kwargs)
)
vector_store = VectorStoreFactory.create_vector_store("custom", {})
vector_store = VectorStoreFactory.create_vector_store(
vector_store_type="custom",
vector_store_schema_config=VectorStoreSchemaConfig(),
kwargs={},
)
assert custom_vector_store_class.called
assert vector_store is instance
@ -89,7 +106,11 @@ def test_get_vector_store_types():
def test_create_unknown_vector_store():
with pytest.raises(ValueError, match="Unknown vector store type: unknown"):
VectorStoreFactory.create_vector_store("unknown", {})
VectorStoreFactory.create_vector_store(
vector_store_type="unknown",
vector_store_schema_config=VectorStoreSchemaConfig(),
kwargs={},
)
def test_is_supported_type():
@ -139,6 +160,9 @@ def test_register_class_directly_works():
# Test creating an instance
vector_store = VectorStoreFactory.create_vector_store(
"custom_class", {"collection_name": "test"}
vector_store_type="custom_class",
vector_store_schema_config=VectorStoreSchemaConfig(),
kwargs={"collection_name": "test"},
)
assert isinstance(vector_store, CustomVectorStore)

View File

@ -7,20 +7,20 @@ import shutil
import tempfile
import numpy as np
import pytest
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.vector_stores.base import VectorStoreDocument
from graphrag.vector_stores.lancedb import LanceDBVectorStore
def test_vector_store_operations():
"""Test basic vector store operations with LanceDB."""
# Create a temporary directory for the test database
temp_dir = tempfile.mkdtemp()
try:
vector_store = LanceDBVectorStore(collection_name="test_collection")
vector_store.connect(db_uri=temp_dir)
class TestLanceDBVectorStore:
"""Test class for TestLanceDBVectorStore."""
docs = [
@pytest.fixture
def sample_documents(self):
"""Create sample documents for testing."""
return [
VectorStoreDocument(
id="1",
text="This is document 1",
@ -40,102 +40,11 @@ def test_vector_store_operations():
attributes={"title": "Doc 3", "category": "test"},
),
]
vector_store.load_documents(docs[:2])
assert vector_store.collection_name in vector_store.db_connection.table_names()
doc = vector_store.search_by_id("1")
assert doc.id == "1"
assert doc.text == "This is document 1"
assert doc.vector is not None
assert np.allclose(doc.vector, [0.1, 0.2, 0.3, 0.4, 0.5])
assert doc.attributes["title"] == "Doc 1"
filter_query = vector_store.filter_by_id(["1"])
assert filter_query == "id in ('1')"
results = vector_store.similarity_search_by_vector(
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
)
assert 1 <= len(results) <= 2
assert isinstance(results[0].score, float)
# Test append mode
vector_store.load_documents([docs[2]], overwrite=False)
result = vector_store.search_by_id("3")
assert result.id == "3"
assert result.text == "This is document 3"
# Define a simple text embedder function for testing
def mock_embedder(text: str) -> list[float]:
return [0.1, 0.2, 0.3, 0.4, 0.5]
text_results = vector_store.similarity_search_by_text(
"test query", mock_embedder, k=2
)
assert 1 <= len(text_results) <= 2
assert isinstance(text_results[0].score, float)
# Test non-existent document
non_existent = vector_store.search_by_id("nonexistent")
assert non_existent.id == "nonexistent"
assert non_existent.text is None
assert non_existent.vector is None
finally:
shutil.rmtree(temp_dir)
def test_empty_collection():
"""Test creating an empty collection."""
# Create a temporary directory for the test database
temp_dir = tempfile.mkdtemp()
try:
vector_store = LanceDBVectorStore(collection_name="empty_collection")
vector_store.connect(db_uri=temp_dir)
# Load the vector store with a document, then delete it
sample_doc = VectorStoreDocument(
id="tmp",
text="Temporary document to create schema",
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
attributes={"title": "Tmp"},
)
vector_store.load_documents([sample_doc])
vector_store.db_connection.open_table(vector_store.collection_name).delete(
"id = 'tmp'"
)
# Should still have the collection
assert vector_store.collection_name in vector_store.db_connection.table_names()
# Add a document after creating an empty collection
doc = VectorStoreDocument(
id="1",
text="This is document 1",
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
attributes={"title": "Doc 1"},
)
vector_store.load_documents([doc], overwrite=False)
result = vector_store.search_by_id("1")
assert result.id == "1"
assert result.text == "This is document 1"
finally:
# Clean up - remove the temporary directory
shutil.rmtree(temp_dir)
def test_filter_search():
"""Test filtered search with LanceDB."""
# Create a temporary directory for the test database
temp_dir = tempfile.mkdtemp()
try:
vector_store = LanceDBVectorStore(collection_name="filter_collection")
vector_store.connect(db_uri=temp_dir)
# Create test documents with different categories
docs = [
@pytest.fixture
def sample_documents_categories(self):
"""Create sample documents with different categories for testing."""
return [
VectorStoreDocument(
id="1",
text="Document about cats",
@ -155,18 +64,201 @@ def test_filter_search():
attributes={"category": "vehicles"},
),
]
vector_store.load_documents(docs)
# Filter to include only documents about animals
vector_store.filter_by_id(["1", "2"])
results = vector_store.similarity_search_by_vector(
[0.1, 0.2, 0.3, 0.4, 0.5], k=3
)
def test_vector_store_operations(self, sample_documents):
"""Test basic vector store operations with LanceDB."""
# Create a temporary directory for the test database
temp_dir = tempfile.mkdtemp()
try:
vector_store = LanceDBVectorStore(
vector_store_schema_config=VectorStoreSchemaConfig(
index_name="test_collection", vector_size=5
)
)
vector_store.connect(db_uri=temp_dir)
vector_store.load_documents(sample_documents[:2])
# Should return at most 2 documents (the filtered ones)
assert len(results) <= 2
ids = [result.document.id for result in results]
assert "3" not in ids
assert set(ids).issubset({"1", "2"})
finally:
shutil.rmtree(temp_dir)
if vector_store.index_name:
assert (
vector_store.index_name in vector_store.db_connection.table_names()
)
doc = vector_store.search_by_id("1")
assert doc.id == "1"
assert doc.text == "This is document 1"
assert doc.vector is not None
assert np.allclose(doc.vector, [0.1, 0.2, 0.3, 0.4, 0.5])
assert doc.attributes["title"] == "Doc 1"
filter_query = vector_store.filter_by_id(["1"])
assert filter_query == "id in ('1')"
results = vector_store.similarity_search_by_vector(
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
)
assert 1 <= len(results) <= 2
assert isinstance(results[0].score, float)
# Test append mode
vector_store.load_documents([sample_documents[2]], overwrite=False)
result = vector_store.search_by_id("3")
assert result.id == "3"
assert result.text == "This is document 3"
# Define a simple text embedder function for testing
def mock_embedder(text: str) -> list[float]:
return [0.1, 0.2, 0.3, 0.4, 0.5]
text_results = vector_store.similarity_search_by_text(
"test query", mock_embedder, k=2
)
assert 1 <= len(text_results) <= 2
assert isinstance(text_results[0].score, float)
# Test non-existent document
non_existent = vector_store.search_by_id("nonexistent")
assert non_existent.id == "nonexistent"
assert non_existent.text is None
assert non_existent.vector is None
finally:
shutil.rmtree(temp_dir)
def test_empty_collection(self):
"""Test creating an empty collection."""
# Create a temporary directory for the test database
temp_dir = tempfile.mkdtemp()
try:
vector_store = LanceDBVectorStore(
vector_store_schema_config=VectorStoreSchemaConfig(
index_name="empty_collection", vector_size=5
)
)
vector_store.connect(db_uri=temp_dir)
# Load the vector store with a document, then delete it
sample_doc = VectorStoreDocument(
id="tmp",
text="Temporary document to create schema",
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
attributes={"title": "Tmp"},
)
vector_store.load_documents([sample_doc])
vector_store.db_connection.open_table(
vector_store.index_name if vector_store.index_name else ""
).delete("id = 'tmp'")
# Should still have the collection
if vector_store.index_name:
assert (
vector_store.index_name in vector_store.db_connection.table_names()
)
# Add a document after creating an empty collection
doc = VectorStoreDocument(
id="1",
text="This is document 1",
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
attributes={"title": "Doc 1"},
)
vector_store.load_documents([doc], overwrite=False)
result = vector_store.search_by_id("1")
assert result.id == "1"
assert result.text == "This is document 1"
finally:
# Clean up - remove the temporary directory
shutil.rmtree(temp_dir)
def test_filter_search(self, sample_documents_categories):
"""Test filtered search with LanceDB."""
# Create a temporary directory for the test database
temp_dir = tempfile.mkdtemp()
try:
vector_store = LanceDBVectorStore(
vector_store_schema_config=VectorStoreSchemaConfig(
index_name="filter_collection", vector_size=5
)
)
vector_store.connect(db_uri=temp_dir)
vector_store.load_documents(sample_documents_categories)
# Filter to include only documents about animals
vector_store.filter_by_id(["1", "2"])
results = vector_store.similarity_search_by_vector(
[0.1, 0.2, 0.3, 0.4, 0.5], k=3
)
# Should return at most 2 documents (the filtered ones)
assert len(results) <= 2
ids = [result.document.id for result in results]
assert "3" not in ids
assert set(ids).issubset({"1", "2"})
finally:
shutil.rmtree(temp_dir)
def test_vector_store_customization(self, sample_documents):
"""Test vector store customization with LanceDB."""
# Create a temporary directory for the test database
temp_dir = tempfile.mkdtemp()
try:
vector_store = LanceDBVectorStore(
vector_store_schema_config=VectorStoreSchemaConfig(
index_name="text-embeddings",
id_field="id_custom",
text_field="text_custom",
vector_field="vector_custom",
attributes_field="attributes_custom",
vector_size=5,
),
)
vector_store.connect(db_uri=temp_dir)
vector_store.load_documents(sample_documents[:2])
if vector_store.index_name:
assert (
vector_store.index_name in vector_store.db_connection.table_names()
)
doc = vector_store.search_by_id("1")
assert doc.id == "1"
assert doc.text == "This is document 1"
assert doc.vector is not None
assert np.allclose(doc.vector, [0.1, 0.2, 0.3, 0.4, 0.5])
assert doc.attributes["title"] == "Doc 1"
filter_query = vector_store.filter_by_id(["1"])
assert filter_query == f"{vector_store.id_field} in ('1')"
results = vector_store.similarity_search_by_vector(
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
)
assert 1 <= len(results) <= 2
assert isinstance(results[0].score, float)
# Test append mode
vector_store.load_documents([sample_documents[2]], overwrite=False)
result = vector_store.search_by_id("3")
assert result.id == "3"
assert result.text == "This is document 3"
# Define a simple text embedder function for testing
def mock_embedder(text: str) -> list[float]:
return [0.1, 0.2, 0.3, 0.4, 0.5]
text_results = vector_store.similarity_search_by_text(
"test query", mock_embedder, k=2
)
assert 1 <= len(text_results) <= 2
assert isinstance(text_results[0].score, float)
# Test non-existent document
non_existent = vector_store.search_by_id("nonexistent")
assert non_existent.id == "nonexistent"
assert non_existent.text is None
assert non_existent.vector is None
finally:
shutil.rmtree(temp_dir)

View File

@ -3,6 +3,7 @@
from typing import Any
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.data_model.entity import Entity
from graphrag.data_model.types import TextEmbedder
from graphrag.language_model.manager import ModelManager
@ -19,7 +20,9 @@ from graphrag.vector_stores.base import (
class MockBaseVectorStore(BaseVectorStore):
def __init__(self, documents: list[VectorStoreDocument]) -> None:
super().__init__("mock")
super().__init__(
vector_store_schema_config=VectorStoreSchemaConfig(index_name="mock")
)
self.documents = documents
def connect(self, **kwargs: Any) -> None:

View File

@ -3,19 +3,19 @@
import pytest
from graphrag.config.embeddings import create_collection_name
from graphrag.config.embeddings import create_index_name
def test_create_collection_name():
collection = create_collection_name("default", "entity.title")
def test_create_index_name():
collection = create_index_name("default", "entity.title")
assert collection == "default-entity-title"
def test_create_collection_name_invalid_embedding_throws():
def test_create_index_name_invalid_embedding_throws():
with pytest.raises(KeyError):
create_collection_name("default", "invalid.name")
create_index_name("default", "invalid.name")
def test_create_collection_name_invalid_embedding_does_not_throw():
collection = create_collection_name("default", "invalid.name", validate=False)
def test_create_index_name_invalid_embedding_does_not_throw():
collection = create_index_name("default", "invalid.name", validate=False)
assert collection == "default-invalid-name"