Custom vector store schema implementation (#2062)
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled

* progress on vector customization

* fix for lancedb vectors

* cosmosdb implementation

* uv run poe format

* clean test for vector store

* semversioner update

* test_factory.py integration test fixes

* fixes for cosmosdb test

* integration test fix for lancedb

* uv fix for format

* test fixes

* fixes for tests

* fix cosmosdb bug

* print statement

* test

* test

* fix cosmosdb bug

* test validation

* validation cosmosdb

* validate cosmosdb

* fix cosmosdb

* fix small feedback from PR

---------

Co-authored-by: Gaudy Blanco <gaudy-microsoft@MacBook-Pro-m4-Gaudy-For-Work.local>
This commit is contained in:
gaudyb 2025-09-19 11:11:34 -06:00 committed by GitHub
parent 075cadd59a
commit 82cd3b7df2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 774 additions and 268 deletions

View File

@ -0,0 +1,4 @@
{
"type": "minor",
"description": "add customization to vector store"
}

55
.vscode/launch.json vendored
View File

@ -6,21 +6,24 @@
"name": "Indexer", "name": "Indexer",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"module": "uv", "module": "graphrag",
"args": [ "args": [
"poe", "index", "index",
"--root", "<path_to_ragtest_root_demo>" "--root",
"<path_to_index_folder>"
], ],
"console": "integratedTerminal"
}, },
{ {
"name": "Query", "name": "Query",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"module": "uv", "module": "graphrag",
"args": [ "args": [
"poe", "query", "query",
"--root", "<path_to_ragtest_root_demo>", "--root",
"--method", "global", "<path_to_index_folder>",
"--method", "basic",
"--query", "What are the top themes in this story", "--query", "What are the top themes in this story",
] ]
}, },
@ -34,6 +37,42 @@
"--config", "--config",
"<path_to_ragtest_root_demo>/settings.yaml", "<path_to_ragtest_root_demo>/settings.yaml",
] ]
} },
{
"name": "Debug Integration Pytest",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"args": [
"./tests/integration/vector_stores",
"-k", "test_azure_ai_search"
],
"console": "integratedTerminal",
"justMyCode": false
},
{
"name": "Debug Verbs Pytest",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"args": [
"./tests/verbs",
"-k", "test_generate_text_embeddings"
],
"console": "integratedTerminal",
"justMyCode": false
},
{
"name": "Debug Smoke Pytest",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"args": [
"./tests/smoke",
"-k", "test_fixtures"
],
"console": "integratedTerminal",
"justMyCode": false
},
] ]
} }

View File

@ -394,6 +394,7 @@ class VectorStoreDefaults:
api_key: None = None api_key: None = None
audience: None = None audience: None = None
database_name: None = None database_name: None = None
schema: None = None
@dataclass @dataclass

View File

@ -29,14 +29,14 @@ default_embeddings: list[str] = [
] ]
def create_collection_name( def create_index_name(
container_name: str, embedding_name: str, validate: bool = True container_name: str, embedding_name: str, validate: bool = True
) -> str: ) -> str:
""" """
Create a collection name for the embedding store. Create a index name for the embedding store.
Within any given vector store, we can have multiple sets of embeddings organized into projects. Within any given vector store, we can have multiple sets of embeddings organized into projects.
The `container` param is used for this partitioning, and is added as a prefix to the collection name for differentiation. The `container` param is used for this partitioning, and is added as a prefix to the index name for differentiation.
The embedding name is fixed, with the available list defined in graphrag.index.config.embeddings The embedding name is fixed, with the available list defined in graphrag.index.config.embeddings

View File

@ -6,7 +6,9 @@
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
from graphrag.config.defaults import vector_store_defaults from graphrag.config.defaults import vector_store_defaults
from graphrag.config.embeddings import all_embeddings
from graphrag.config.enums import VectorStoreType from graphrag.config.enums import VectorStoreType
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
class VectorStoreConfig(BaseModel): class VectorStoreConfig(BaseModel):
@ -85,9 +87,25 @@ class VectorStoreConfig(BaseModel):
default=vector_store_defaults.overwrite, default=vector_store_defaults.overwrite,
) )
embeddings_schema: dict[str, VectorStoreSchemaConfig] = {}
def _validate_embeddings_schema(self) -> None:
"""Validate the embeddings schema."""
for name in self.embeddings_schema:
if name not in all_embeddings:
msg = f"vector_store.embeddings_schema contains an invalid embedding schema name: {name}. Please update your settings.yaml and select the correct embedding schema names."
raise ValueError(msg)
if self.type == VectorStoreType.CosmosDB:
for id_field in self.embeddings_schema:
if id_field != "id":
msg = "When using CosmosDB, the id_field in embeddings_schema must be 'id'. Please update your settings.yaml and set the id_field to 'id'."
raise ValueError(msg)
@model_validator(mode="after") @model_validator(mode="after")
def _validate_model(self): def _validate_model(self):
"""Validate the model.""" """Validate the model."""
self._validate_db_uri() self._validate_db_uri()
self._validate_url() self._validate_url()
self._validate_embeddings_schema()
return self return self

View File

@ -0,0 +1,66 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Parameterization settings for the default configuration."""
import re
from pydantic import BaseModel, Field, model_validator
DEFAULT_VECTOR_SIZE: int = 1536
VALID_IDENTIFIER_REGEX = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
def is_valid_field_name(field: str) -> bool:
"""Check if a field name is valid for CosmosDB."""
return bool(VALID_IDENTIFIER_REGEX.match(field))
class VectorStoreSchemaConfig(BaseModel):
"""The default configuration section for Vector Store Schema."""
id_field: str = Field(
description="The ID field to use.",
default="id",
)
vector_field: str = Field(
description="The vector field to use.",
default="vector",
)
text_field: str = Field(
description="The text field to use.",
default="text",
)
attributes_field: str = Field(
description="The attributes field to use.",
default="attributes",
)
vector_size: int = Field(
description="The vector size to use.",
default=DEFAULT_VECTOR_SIZE,
)
index_name: str | None = Field(description="The index name to use.", default=None)
def _validate_schema(self) -> None:
"""Validate the schema."""
for field in [
self.id_field,
self.vector_field,
self.text_field,
self.attributes_field,
]:
if not is_valid_field_name(field):
msg = f"Unsafe or invalid field name: {field}"
raise ValueError(msg)
@model_validator(mode="after")
def _validate_model(self):
"""Validate the model."""
self._validate_schema()
return self

View File

@ -12,7 +12,8 @@ import pandas as pd
from graphrag.cache.pipeline_cache import PipelineCache from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.embeddings import create_collection_name from graphrag.config.embeddings import create_index_name
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy
from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument
from graphrag.vector_stores.factory import VectorStoreFactory from graphrag.vector_stores.factory import VectorStoreFactory
@ -49,9 +50,9 @@ async def embed_text(
vector_store_config = strategy.get("vector_store") vector_store_config = strategy.get("vector_store")
if vector_store_config: if vector_store_config:
collection_name = _get_collection_name(vector_store_config, embedding_name) index_name = _get_index_name(vector_store_config, embedding_name)
vector_store: BaseVectorStore = _create_vector_store( vector_store: BaseVectorStore = _create_vector_store(
vector_store_config, collection_name vector_store_config, index_name, embedding_name
) )
vector_store_workflow_config = vector_store_config.get( vector_store_workflow_config = vector_store_config.get(
embedding_name, vector_store_config embedding_name, vector_store_config
@ -183,27 +184,46 @@ async def _text_embed_with_vector_store(
def _create_vector_store( def _create_vector_store(
vector_store_config: dict, collection_name: str vector_store_config: dict, index_name: str, embedding_name: str | None = None
) -> BaseVectorStore: ) -> BaseVectorStore:
vector_store_type: str = str(vector_store_config.get("type")) vector_store_type: str = str(vector_store_config.get("type"))
if collection_name:
vector_store_config.update({"collection_name": collection_name}) embeddings_schema: dict[str, VectorStoreSchemaConfig] = vector_store_config.get(
"embeddings_schema", {}
)
single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
if (
embeddings_schema is not None
and embedding_name is not None
and embedding_name in embeddings_schema
):
raw_config = embeddings_schema[embedding_name]
if isinstance(raw_config, dict):
single_embedding_config = VectorStoreSchemaConfig(**raw_config)
else:
single_embedding_config = raw_config
if single_embedding_config.index_name is None:
single_embedding_config.index_name = index_name
vector_store = VectorStoreFactory().create_vector_store( vector_store = VectorStoreFactory().create_vector_store(
vector_store_type, kwargs=vector_store_config vector_store_schema_config=single_embedding_config,
vector_store_type=vector_store_type,
kwargs=vector_store_config,
) )
vector_store.connect(**vector_store_config) vector_store.connect(**vector_store_config)
return vector_store return vector_store
def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str: def _get_index_name(vector_store_config: dict, embedding_name: str) -> str:
container_name = vector_store_config.get("container_name", "default") container_name = vector_store_config.get("container_name", "default")
collection_name = create_collection_name(container_name, embedding_name) index_name = create_index_name(container_name, embedding_name)
msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {collection_name}" msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {index_name}"
logger.info(msg) logger.info(msg)
return collection_name return index_name
def load_strategy(strategy: TextEmbedStrategyType) -> TextEmbeddingStrategy: def load_strategy(strategy: TextEmbedStrategyType) -> TextEmbeddingStrategy:

View File

@ -8,9 +8,10 @@ from typing import Any
from graphrag.cache.factory import CacheFactory from graphrag.cache.factory import CacheFactory
from graphrag.cache.pipeline_cache import PipelineCache from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.config.embeddings import create_collection_name from graphrag.config.embeddings import create_index_name
from graphrag.config.models.cache_config import CacheConfig from graphrag.config.models.cache_config import CacheConfig
from graphrag.config.models.storage_config import StorageConfig from graphrag.config.models.storage_config import StorageConfig
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.data_model.types import TextEmbedder from graphrag.data_model.types import TextEmbedder
from graphrag.storage.factory import StorageFactory from graphrag.storage.factory import StorageFactory
from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.storage.pipeline_storage import PipelineStorage
@ -103,12 +104,33 @@ def get_embedding_store(
index_names = [] index_names = []
for index, store in config_args.items(): for index, store in config_args.items():
vector_store_type = store["type"] vector_store_type = store["type"]
collection_name = create_collection_name( index_name = create_index_name(
store.get("container_name", "default"), embedding_name store.get("container_name", "default"), embedding_name
) )
embeddings_schema: dict[str, VectorStoreSchemaConfig] = store.get(
"embeddings_schema", {}
)
single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
if (
embeddings_schema is not None
and embedding_name is not None
and embedding_name in embeddings_schema
):
raw_config = embeddings_schema[embedding_name]
if isinstance(raw_config, dict):
single_embedding_config = VectorStoreSchemaConfig(**raw_config)
else:
single_embedding_config = raw_config
if single_embedding_config.index_name is None:
single_embedding_config.index_name = index_name
embedding_store = VectorStoreFactory().create_vector_store( embedding_store = VectorStoreFactory().create_vector_store(
vector_store_type=vector_store_type, vector_store_type=vector_store_type,
kwargs={**store, "collection_name": collection_name}, vector_store_schema_config=single_embedding_config,
kwargs={**store},
) )
embedding_store.connect(**store) embedding_store.connect(**store)
# If there is only a single index, return the embedding store directly # If there is only a single index, return the embedding store directly

View File

@ -24,9 +24,9 @@ from azure.search.documents.indexes.models import (
) )
from azure.search.documents.models import VectorizedQuery from azure.search.documents.models import VectorizedQuery
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.data_model.types import TextEmbedder from graphrag.data_model.types import TextEmbedder
from graphrag.vector_stores.base import ( from graphrag.vector_stores.base import (
DEFAULT_VECTOR_SIZE,
BaseVectorStore, BaseVectorStore,
VectorStoreDocument, VectorStoreDocument,
VectorStoreSearchResult, VectorStoreSearchResult,
@ -38,15 +38,18 @@ class AzureAISearchVectorStore(BaseVectorStore):
index_client: SearchIndexClient index_client: SearchIndexClient
def __init__(self, **kwargs: Any) -> None: def __init__(
super().__init__(**kwargs) self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any
) -> None:
super().__init__(
vector_store_schema_config=vector_store_schema_config, **kwargs
)
def connect(self, **kwargs: Any) -> Any: def connect(self, **kwargs: Any) -> Any:
"""Connect to AI search vector storage.""" """Connect to AI search vector storage."""
url = kwargs["url"] url = kwargs["url"]
api_key = kwargs.get("api_key") api_key = kwargs.get("api_key")
audience = kwargs.get("audience") audience = kwargs.get("audience")
self.vector_size = kwargs.get("vector_size", DEFAULT_VECTOR_SIZE)
self.vector_search_profile_name = kwargs.get( self.vector_search_profile_name = kwargs.get(
"vector_search_profile_name", "vectorSearchProfile" "vector_search_profile_name", "vectorSearchProfile"
@ -56,7 +59,7 @@ class AzureAISearchVectorStore(BaseVectorStore):
audience_arg = {"audience": audience} if audience and not api_key else {} audience_arg = {"audience": audience} if audience and not api_key else {}
self.db_connection = SearchClient( self.db_connection = SearchClient(
endpoint=url, endpoint=url,
index_name=self.collection_name, index_name=self.index_name if self.index_name else "",
credential=( credential=(
AzureKeyCredential(api_key) if api_key else DefaultAzureCredential() AzureKeyCredential(api_key) if api_key else DefaultAzureCredential()
), ),
@ -78,8 +81,11 @@ class AzureAISearchVectorStore(BaseVectorStore):
) -> None: ) -> None:
"""Load documents into an Azure AI Search index.""" """Load documents into an Azure AI Search index."""
if overwrite: if overwrite:
if self.collection_name in self.index_client.list_index_names(): if (
self.index_client.delete_index(self.collection_name) self.index_name is not None
and self.index_name in self.index_client.list_index_names()
):
self.index_client.delete_index(self.index_name)
# Configure vector search profile # Configure vector search profile
vector_search = VectorSearch( vector_search = VectorSearch(
@ -100,24 +106,26 @@ class AzureAISearchVectorStore(BaseVectorStore):
) )
# Configure the index # Configure the index
index = SearchIndex( index = SearchIndex(
name=self.collection_name, name=self.index_name if self.index_name else "",
fields=[ fields=[
SimpleField( SimpleField(
name="id", name=self.id_field,
type=SearchFieldDataType.String, type=SearchFieldDataType.String,
key=True, key=True,
), ),
SearchField( SearchField(
name="vector", name=self.vector_field,
type=SearchFieldDataType.Collection(SearchFieldDataType.Single), type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
searchable=True, searchable=True,
hidden=False, # DRIFT needs to return the vector for client-side similarity hidden=False, # DRIFT needs to return the vector for client-side similarity
vector_search_dimensions=self.vector_size, vector_search_dimensions=self.vector_size,
vector_search_profile_name=self.vector_search_profile_name, vector_search_profile_name=self.vector_search_profile_name,
), ),
SearchableField(name="text", type=SearchFieldDataType.String), SearchableField(
name=self.text_field, type=SearchFieldDataType.String
),
SimpleField( SimpleField(
name="attributes", name=self.attributes_field,
type=SearchFieldDataType.String, type=SearchFieldDataType.String,
), ),
], ],
@ -129,10 +137,10 @@ class AzureAISearchVectorStore(BaseVectorStore):
batch = [ batch = [
{ {
"id": doc.id, self.id_field: doc.id,
"vector": doc.vector, self.vector_field: doc.vector,
"text": doc.text, self.text_field: doc.text,
"attributes": json.dumps(doc.attributes), self.attributes_field: json.dumps(doc.attributes),
} }
for doc in documents for doc in documents
if doc.vector is not None if doc.vector is not None
@ -151,7 +159,7 @@ class AzureAISearchVectorStore(BaseVectorStore):
# More info about odata filtering here: https://learn.microsoft.com/en-us/azure/search/search-query-odata-search-in-function # More info about odata filtering here: https://learn.microsoft.com/en-us/azure/search/search-query-odata-search-in-function
# search.in is faster that joined and/or conditions # search.in is faster that joined and/or conditions
id_filter = ",".join([f"{id!s}" for id in include_ids]) id_filter = ",".join([f"{id!s}" for id in include_ids])
self.query_filter = f"search.in(id, '{id_filter}', ',')" self.query_filter = f"search.in({self.id_field}, '{id_filter}', ',')"
# Returning to keep consistency with other methods, but not needed # Returning to keep consistency with other methods, but not needed
# TODO: Refactor on a future PR # TODO: Refactor on a future PR
@ -162,7 +170,7 @@ class AzureAISearchVectorStore(BaseVectorStore):
) -> list[VectorStoreSearchResult]: ) -> list[VectorStoreSearchResult]:
"""Perform a vector-based similarity search.""" """Perform a vector-based similarity search."""
vectorized_query = VectorizedQuery( vectorized_query = VectorizedQuery(
vector=query_embedding, k_nearest_neighbors=k, fields="vector" vector=query_embedding, k_nearest_neighbors=k, fields=self.vector_field
) )
response = self.db_connection.search( response = self.db_connection.search(
@ -172,10 +180,10 @@ class AzureAISearchVectorStore(BaseVectorStore):
return [ return [
VectorStoreSearchResult( VectorStoreSearchResult(
document=VectorStoreDocument( document=VectorStoreDocument(
id=doc.get("id", ""), id=doc.get(self.id_field, ""),
text=doc.get("text", ""), text=doc.get(self.text_field, ""),
vector=doc.get("vector", []), vector=doc.get(self.vector_field, []),
attributes=(json.loads(doc.get("attributes", "{}"))), attributes=(json.loads(doc.get(self.attributes_field, "{}"))),
), ),
# Cosine similarity between 0.333 and 1.000 # Cosine similarity between 0.333 and 1.000
# https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking#scores-in-a-hybrid-search-results # https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking#scores-in-a-hybrid-search-results
@ -199,8 +207,8 @@ class AzureAISearchVectorStore(BaseVectorStore):
"""Search for a document by id.""" """Search for a document by id."""
response = self.db_connection.get_document(id) response = self.db_connection.get_document(id)
return VectorStoreDocument( return VectorStoreDocument(
id=response.get("id", ""), id=response.get(self.id_field, ""),
text=response.get("text", ""), text=response.get(self.text_field, ""),
vector=response.get("vector", []), vector=response.get(self.vector_field, []),
attributes=(json.loads(response.get("attributes", "{}"))), attributes=(json.loads(response.get(self.attributes_field, "{}"))),
) )

View File

@ -7,10 +7,9 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.data_model.types import TextEmbedder from graphrag.data_model.types import TextEmbedder
DEFAULT_VECTOR_SIZE: int = 1536
@dataclass @dataclass
class VectorStoreDocument: class VectorStoreDocument:
@ -42,18 +41,24 @@ class BaseVectorStore(ABC):
def __init__( def __init__(
self, self,
collection_name: str, vector_store_schema_config: VectorStoreSchemaConfig,
db_connection: Any | None = None, db_connection: Any | None = None,
document_collection: Any | None = None, document_collection: Any | None = None,
query_filter: Any | None = None, query_filter: Any | None = None,
**kwargs: Any, **kwargs: Any,
): ):
self.collection_name = collection_name
self.db_connection = db_connection self.db_connection = db_connection
self.document_collection = document_collection self.document_collection = document_collection
self.query_filter = query_filter self.query_filter = query_filter
self.kwargs = kwargs self.kwargs = kwargs
self.index_name = vector_store_schema_config.index_name
self.id_field = vector_store_schema_config.id_field
self.text_field = vector_store_schema_config.text_field
self.vector_field = vector_store_schema_config.vector_field
self.attributes_field = vector_store_schema_config.attributes_field
self.vector_size = vector_store_schema_config.vector_size
@abstractmethod @abstractmethod
def connect(self, **kwargs: Any) -> None: def connect(self, **kwargs: Any) -> None:
"""Connect to vector storage.""" """Connect to vector storage."""

View File

@ -11,9 +11,9 @@ from azure.cosmos.exceptions import CosmosHttpResponseError
from azure.cosmos.partition_key import PartitionKey from azure.cosmos.partition_key import PartitionKey
from azure.identity import DefaultAzureCredential from azure.identity import DefaultAzureCredential
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.data_model.types import TextEmbedder from graphrag.data_model.types import TextEmbedder
from graphrag.vector_stores.base import ( from graphrag.vector_stores.base import (
DEFAULT_VECTOR_SIZE,
BaseVectorStore, BaseVectorStore,
VectorStoreDocument, VectorStoreDocument,
VectorStoreSearchResult, VectorStoreSearchResult,
@ -27,8 +27,12 @@ class CosmosDBVectorStore(BaseVectorStore):
_database_client: DatabaseProxy _database_client: DatabaseProxy
_container_client: ContainerProxy _container_client: ContainerProxy
def __init__(self, **kwargs: Any) -> None: def __init__(
super().__init__(**kwargs) self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any
) -> None:
super().__init__(
vector_store_schema_config=vector_store_schema_config, **kwargs
)
def connect(self, **kwargs: Any) -> Any: def connect(self, **kwargs: Any) -> Any:
"""Connect to CosmosDB vector storage.""" """Connect to CosmosDB vector storage."""
@ -49,13 +53,12 @@ class CosmosDBVectorStore(BaseVectorStore):
msg = "Database name must be provided." msg = "Database name must be provided."
raise ValueError(msg) raise ValueError(msg)
self._database_name = database_name self._database_name = database_name
collection_name = self.collection_name if self.index_name is None:
if collection_name is None: msg = "Index name is empty or not provided."
msg = "Collection name is empty or not provided."
raise ValueError(msg) raise ValueError(msg)
self._container_name = collection_name self._container_name = self.index_name
self.vector_size = kwargs.get("vector_size", DEFAULT_VECTOR_SIZE) self.vector_size = self.vector_size
self._create_database() self._create_database()
self._create_container() self._create_container()
@ -80,13 +83,13 @@ class CosmosDBVectorStore(BaseVectorStore):
def _create_container(self) -> None: def _create_container(self) -> None:
"""Create the container if it doesn't exist.""" """Create the container if it doesn't exist."""
partition_key = PartitionKey(path="/id", kind="Hash") partition_key = PartitionKey(path=f"/{self.id_field}", kind="Hash")
# Define the container vector policy # Define the container vector policy
vector_embedding_policy = { vector_embedding_policy = {
"vectorEmbeddings": [ "vectorEmbeddings": [
{ {
"path": "/vector", "path": f"/{self.vector_field}",
"dataType": "float32", "dataType": "float32",
"distanceFunction": "cosine", "distanceFunction": "cosine",
"dimensions": self.vector_size, "dimensions": self.vector_size,
@ -99,13 +102,18 @@ class CosmosDBVectorStore(BaseVectorStore):
"indexingMode": "consistent", "indexingMode": "consistent",
"automatic": True, "automatic": True,
"includedPaths": [{"path": "/*"}], "includedPaths": [{"path": "/*"}],
"excludedPaths": [{"path": "/_etag/?"}, {"path": "/vector/*"}], "excludedPaths": [
{"path": "/_etag/?"},
{"path": f"/{self.vector_field}/*"},
],
} }
# Currently, the CosmosDB emulator does not support the diskANN policy. # Currently, the CosmosDB emulator does not support the diskANN policy.
try: try:
# First try with the standard diskANN policy # First try with the standard diskANN policy
indexing_policy["vectorIndexes"] = [{"path": "/vector", "type": "diskANN"}] indexing_policy["vectorIndexes"] = [
{"path": f"/{self.vector_field}", "type": "diskANN"}
]
# Create the container and container client # Create the container and container client
self._database_client.create_container_if_not_exists( self._database_client.create_container_if_not_exists(
@ -158,12 +166,16 @@ class CosmosDBVectorStore(BaseVectorStore):
# Upload documents to CosmosDB # Upload documents to CosmosDB
for doc in documents: for doc in documents:
if doc.vector is not None: if doc.vector is not None:
print("Document to store:") # noqa: T201
print(doc) # noqa: T201
doc_json = { doc_json = {
"id": doc.id, self.id_field: doc.id,
"vector": doc.vector, self.vector_field: doc.vector,
"text": doc.text, self.text_field: doc.text,
"attributes": json.dumps(doc.attributes), self.attributes_field: json.dumps(doc.attributes),
} }
print("Storing document in CosmosDB:") # noqa: T201
print(doc_json) # noqa: T201
self._container_client.upsert_item(doc_json) self._container_client.upsert_item(doc_json)
def similarity_search_by_vector( def similarity_search_by_vector(
@ -175,7 +187,7 @@ class CosmosDBVectorStore(BaseVectorStore):
raise ValueError(msg) raise ValueError(msg)
try: try:
query = f"SELECT TOP {k} c.id, c.text, c.vector, c.attributes, VectorDistance(c.vector, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.vector, @embedding)" # noqa: S608 query = f"SELECT TOP {k} c.{self.id_field}, c.{self.text_field}, c.{self.vector_field}, c.{self.attributes_field}, VectorDistance(c.{self.vector_field}, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.{self.vector_field}, @embedding)" # noqa: S608
query_params = [{"name": "@embedding", "value": query_embedding}] query_params = [{"name": "@embedding", "value": query_embedding}]
items = list( items = list(
self._container_client.query_items( self._container_client.query_items(
@ -187,7 +199,7 @@ class CosmosDBVectorStore(BaseVectorStore):
except (CosmosHttpResponseError, ValueError): except (CosmosHttpResponseError, ValueError):
# Currently, the CosmosDB emulator does not support the VectorDistance function. # Currently, the CosmosDB emulator does not support the VectorDistance function.
# For emulator or test environments - fetch all items and calculate distance locally # For emulator or test environments - fetch all items and calculate distance locally
query = "SELECT c.id, c.text, c.vector, c.attributes FROM c" query = f"SELECT c.{self.id_field}, c.{self.text_field}, c.{self.vector_field}, c.{self.attributes_field} FROM c" # noqa: S608
items = list( items = list(
self._container_client.query_items( self._container_client.query_items(
query=query, query=query,
@ -206,7 +218,7 @@ class CosmosDBVectorStore(BaseVectorStore):
# Calculate scores for all items # Calculate scores for all items
for item in items: for item in items:
item_vector = item.get("vector", []) item_vector = item.get(self.vector_field, [])
similarity = cosine_similarity(query_embedding, item_vector) similarity = cosine_similarity(query_embedding, item_vector)
item["SimilarityScore"] = similarity item["SimilarityScore"] = similarity
@ -218,10 +230,10 @@ class CosmosDBVectorStore(BaseVectorStore):
return [ return [
VectorStoreSearchResult( VectorStoreSearchResult(
document=VectorStoreDocument( document=VectorStoreDocument(
id=item.get("id", ""), id=item.get(self.id_field, ""),
text=item.get("text", ""), text=item.get(self.text_field, ""),
vector=item.get("vector", []), vector=item.get(self.vector_field, []),
attributes=(json.loads(item.get("attributes", "{}"))), attributes=(json.loads(item.get(self.attributes_field, "{}"))),
), ),
score=item.get("SimilarityScore", 0.0), score=item.get("SimilarityScore", 0.0),
) )
@ -248,7 +260,9 @@ class CosmosDBVectorStore(BaseVectorStore):
id_filter = ", ".join([f"'{id}'" for id in include_ids]) id_filter = ", ".join([f"'{id}'" for id in include_ids])
else: else:
id_filter = ", ".join([str(id) for id in include_ids]) id_filter = ", ".join([str(id) for id in include_ids])
self.query_filter = f"SELECT * FROM c WHERE c.id IN ({id_filter})" # noqa: S608 self.query_filter = (
f"SELECT * FROM c WHERE c.{self.id_field} IN ({id_filter})" # noqa: S608
)
return self.query_filter return self.query_filter
def search_by_id(self, id: str) -> VectorStoreDocument: def search_by_id(self, id: str) -> VectorStoreDocument:
@ -259,10 +273,10 @@ class CosmosDBVectorStore(BaseVectorStore):
item = self._container_client.read_item(item=id, partition_key=id) item = self._container_client.read_item(item=id, partition_key=id)
return VectorStoreDocument( return VectorStoreDocument(
id=item.get("id", ""), id=item.get(self.id_field, ""),
vector=item.get("vector", []), vector=item.get(self.vector_field, []),
text=item.get("text", ""), text=item.get(self.text_field, ""),
attributes=(json.loads(item.get("attributes", "{}"))), attributes=(json.loads(item.get(self.attributes_field, "{}"))),
) )
def clear(self) -> None: def clear(self) -> None:

View File

@ -15,6 +15,9 @@ from graphrag.vector_stores.lancedb import LanceDBVectorStore
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Callable from collections.abc import Callable
from graphrag.config.models.vector_store_schema_config import (
VectorStoreSchemaConfig,
)
from graphrag.vector_stores.base import BaseVectorStore from graphrag.vector_stores.base import BaseVectorStore
@ -47,7 +50,10 @@ class VectorStoreFactory:
@classmethod @classmethod
def create_vector_store( def create_vector_store(
cls, vector_store_type: str, kwargs: dict cls,
vector_store_type: str,
vector_store_schema_config: VectorStoreSchemaConfig,
kwargs: dict,
) -> BaseVectorStore: ) -> BaseVectorStore:
"""Create a vector store object from the provided type. """Create a vector store object from the provided type.
@ -67,7 +73,9 @@ class VectorStoreFactory:
msg = f"Unknown vector store type: {vector_store_type}" msg = f"Unknown vector store type: {vector_store_type}"
raise ValueError(msg) raise ValueError(msg)
return cls._registry[vector_store_type](**kwargs) return cls._registry[vector_store_type](
vector_store_schema_config=vector_store_schema_config, **kwargs
)
@classmethod @classmethod
def get_vector_store_types(cls) -> list[str]: def get_vector_store_types(cls) -> list[str]:

View File

@ -5,9 +5,9 @@
import json # noqa: I001 import json # noqa: I001
from typing import Any from typing import Any
import pyarrow as pa import pyarrow as pa
import numpy as np
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.data_model.types import TextEmbedder from graphrag.data_model.types import TextEmbedder
from graphrag.vector_stores.base import ( from graphrag.vector_stores.base import (
@ -21,60 +21,81 @@ import lancedb
class LanceDBVectorStore(BaseVectorStore): class LanceDBVectorStore(BaseVectorStore):
"""LanceDB vector storage implementation.""" """LanceDB vector storage implementation."""
def __init__(self, **kwargs: Any) -> None: def __init__(
super().__init__(**kwargs) self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any
) -> None:
super().__init__(
vector_store_schema_config=vector_store_schema_config, **kwargs
)
def connect(self, **kwargs: Any) -> Any: def connect(self, **kwargs: Any) -> Any:
"""Connect to the vector storage.""" """Connect to the vector storage."""
self.db_connection = lancedb.connect(kwargs["db_uri"]) self.db_connection = lancedb.connect(kwargs["db_uri"])
if (
self.collection_name if self.index_name and self.index_name in self.db_connection.table_names():
and self.collection_name in self.db_connection.table_names() self.document_collection = self.db_connection.open_table(self.index_name)
):
self.document_collection = self.db_connection.open_table(
self.collection_name
)
def load_documents( def load_documents(
self, documents: list[VectorStoreDocument], overwrite: bool = True self, documents: list[VectorStoreDocument], overwrite: bool = True
) -> None: ) -> None:
"""Load documents into vector storage.""" """Load documents into vector storage."""
data = [ # Step 1: Prepare data columns manually
{ ids = []
"id": document.id, texts = []
"text": document.text, vectors = []
"vector": document.vector, attributes = []
"attributes": json.dumps(document.attributes),
}
for document in documents
if document.vector is not None
]
if len(data) == 0: for document in documents:
self.vector_size = (
len(document.vector) if document.vector else self.vector_size
)
if document.vector is not None and len(document.vector) == self.vector_size:
ids.append(document.id)
texts.append(document.text)
vectors.append(np.array(document.vector, dtype=np.float32))
attributes.append(json.dumps(document.attributes))
# Step 2: Handle empty case
if len(ids) == 0:
data = None data = None
else:
# Step 3: Flatten the vectors and build FixedSizeListArray manually
flat_vector = np.concatenate(vectors).astype(np.float32)
flat_array = pa.array(flat_vector, type=pa.float32())
vector_column = pa.FixedSizeListArray.from_arrays(
flat_array, self.vector_size
)
# Step 4: Create PyArrow table (let schema be inferred)
data = pa.table({
self.id_field: pa.array(ids, type=pa.string()),
self.text_field: pa.array(texts, type=pa.string()),
self.vector_field: vector_column,
self.attributes_field: pa.array(attributes, type=pa.string()),
})
schema = pa.schema([
pa.field("id", pa.string()),
pa.field("text", pa.string()),
pa.field("vector", pa.list_(pa.float64())),
pa.field("attributes", pa.string()),
])
# NOTE: If modifying the next section of code, ensure that the schema remains the same. # NOTE: If modifying the next section of code, ensure that the schema remains the same.
# The pyarrow format of the 'vector' field may change if the order of operations is changed # The pyarrow format of the 'vector' field may change if the order of operations is changed
# and will break vector search. # and will break vector search.
if overwrite: if overwrite:
if data: if data:
self.document_collection = self.db_connection.create_table( self.document_collection = self.db_connection.create_table(
self.collection_name, data=data, mode="overwrite" self.index_name if self.index_name else "",
data=data,
mode="overwrite",
schema=data.schema,
) )
else: else:
self.document_collection = self.db_connection.create_table( self.document_collection = self.db_connection.create_table(
self.collection_name, schema=schema, mode="overwrite" self.index_name if self.index_name else "", mode="overwrite"
) )
self.document_collection.create_index(
vector_column_name=self.vector_field, index_type="IVF_FLAT"
)
else: else:
# add data to existing table # add data to existing table
self.document_collection = self.db_connection.open_table( self.document_collection = self.db_connection.open_table(
self.collection_name self.index_name if self.index_name else ""
) )
if data: if data:
self.document_collection.add(data) self.document_collection.add(data)
@ -86,30 +107,32 @@ class LanceDBVectorStore(BaseVectorStore):
else: else:
if isinstance(include_ids[0], str): if isinstance(include_ids[0], str):
id_filter = ", ".join([f"'{id}'" for id in include_ids]) id_filter = ", ".join([f"'{id}'" for id in include_ids])
self.query_filter = f"id in ({id_filter})" self.query_filter = f"{self.id_field} in ({id_filter})"
else: else:
self.query_filter = ( self.query_filter = (
f"id in ({', '.join([str(id) for id in include_ids])})" f"{self.id_field} in ({', '.join([str(id) for id in include_ids])})"
) )
return self.query_filter return self.query_filter
def similarity_search_by_vector( def similarity_search_by_vector(
self, query_embedding: list[float], k: int = 10, **kwargs: Any self, query_embedding: list[float] | np.ndarray, k: int = 10, **kwargs: Any
) -> list[VectorStoreSearchResult]: ) -> list[VectorStoreSearchResult]:
"""Perform a vector-based similarity search.""" """Perform a vector-based similarity search."""
if self.query_filter: if self.query_filter:
docs = ( docs = (
self.document_collection.search( self.document_collection.search(
query=query_embedding, vector_column_name="vector" query=query_embedding, vector_column_name=self.vector_field
) )
.where(self.query_filter, prefilter=True) .where(self.query_filter, prefilter=True)
.limit(k) .limit(k)
.to_list() .to_list()
) )
else: else:
query_embedding = np.array(query_embedding, dtype=np.float32)
docs = ( docs = (
self.document_collection.search( self.document_collection.search(
query=query_embedding, vector_column_name="vector" query=query_embedding, vector_column_name=self.vector_field
) )
.limit(k) .limit(k)
.to_list() .to_list()
@ -117,10 +140,10 @@ class LanceDBVectorStore(BaseVectorStore):
return [ return [
VectorStoreSearchResult( VectorStoreSearchResult(
document=VectorStoreDocument( document=VectorStoreDocument(
id=doc["id"], id=doc[self.id_field],
text=doc["text"], text=doc[self.text_field],
vector=doc["vector"], vector=doc[self.vector_field],
attributes=json.loads(doc["attributes"]), attributes=json.loads(doc[self.attributes_field]),
), ),
score=1 - abs(float(doc["_distance"])), score=1 - abs(float(doc["_distance"])),
) )
@ -140,14 +163,14 @@ class LanceDBVectorStore(BaseVectorStore):
"""Search for a document by id.""" """Search for a document by id."""
doc = ( doc = (
self.document_collection.search() self.document_collection.search()
.where(f"id == '{id}'", prefilter=True) .where(f"{self.id_field} == '{id}'", prefilter=True)
.to_list() .to_list()
) )
if doc: if doc:
return VectorStoreDocument( return VectorStoreDocument(
id=doc[0]["id"], id=doc[0][self.id_field],
text=doc[0]["text"], text=doc[0][self.text_field],
vector=doc[0]["vector"], vector=doc[0][self.vector_field],
attributes=json.loads(doc[0]["attributes"]), attributes=json.loads(doc[0][self.attributes_field]),
) )
return VectorStoreDocument(id=id, text=None, vector=None) return VectorStoreDocument(id=id, text=None, vector=None)

View File

@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
from graphrag.vector_stores.base import VectorStoreDocument from graphrag.vector_stores.base import VectorStoreDocument
@ -39,7 +40,35 @@ class TestAzureAISearchVectorStore:
@pytest.fixture @pytest.fixture
def vector_store(self, mock_search_client, mock_index_client): def vector_store(self, mock_search_client, mock_index_client):
"""Create an Azure AI Search vector store instance.""" """Create an Azure AI Search vector store instance."""
vector_store = AzureAISearchVectorStore(collection_name="test_vectors") vector_store = AzureAISearchVectorStore(
vector_store_schema_config=VectorStoreSchemaConfig(
index_name="test_vectors", vector_size=5
),
)
# Create the necessary mocks first
vector_store.db_connection = mock_search_client
vector_store.index_client = mock_index_client
vector_store.connect(
url=TEST_AZURE_AI_SEARCH_URL,
api_key=TEST_AZURE_AI_SEARCH_KEY,
)
return vector_store
@pytest.fixture
def vector_store_custom(self, mock_search_client, mock_index_client):
"""Create an Azure AI Search vector store instance."""
vector_store = AzureAISearchVectorStore(
vector_store_schema_config=VectorStoreSchemaConfig(
index_name="test_vectors",
id_field="id_custom",
text_field="text_custom",
attributes_field="attributes_custom",
vector_field="vector_custom",
vector_size=5,
),
)
# Create the necessary mocks first # Create the necessary mocks first
vector_store.db_connection = mock_search_client vector_store.db_connection = mock_search_client
@ -48,7 +77,6 @@ class TestAzureAISearchVectorStore:
vector_store.connect( vector_store.connect(
url=TEST_AZURE_AI_SEARCH_URL, url=TEST_AZURE_AI_SEARCH_URL,
api_key=TEST_AZURE_AI_SEARCH_KEY, api_key=TEST_AZURE_AI_SEARCH_KEY,
vector_size=5,
) )
return vector_store return vector_store
@ -144,3 +172,72 @@ class TestAzureAISearchVectorStore:
) )
assert not mock_search_client.search.called assert not mock_search_client.search.called
assert len(results) == 0 assert len(results) == 0
async def test_vector_store_customization(
self,
vector_store_custom,
sample_documents,
mock_search_client,
mock_index_client,
):
"""Test vector store customization with Azure AI Search."""
# Setup mock responses
mock_index_client.list_index_names.return_value = []
mock_index_client.create_or_update_index = MagicMock()
mock_search_client.upload_documents = MagicMock()
search_results = [
{
vector_store_custom.id_field: "doc1",
vector_store_custom.text_field: "This is document 1",
vector_store_custom.vector_field: [0.1, 0.2, 0.3, 0.4, 0.5],
vector_store_custom.attributes_field: '{"title": "Doc 1", "category": "test"}',
"@search.score": 0.9,
},
{
vector_store_custom.id_field: "doc2",
vector_store_custom.text_field: "This is document 2",
vector_store_custom.vector_field: [0.2, 0.3, 0.4, 0.5, 0.6],
vector_store_custom.attributes_field: '{"title": "Doc 2", "category": "test"}',
"@search.score": 0.8,
},
]
mock_search_client.search.return_value = search_results
mock_search_client.get_document.return_value = {
vector_store_custom.id_field: "doc1",
vector_store_custom.text_field: "This is document 1",
vector_store_custom.vector_field: [0.1, 0.2, 0.3, 0.4, 0.5],
vector_store_custom.attributes_field: '{"title": "Doc 1", "category": "test"}',
}
vector_store_custom.load_documents(sample_documents)
assert mock_index_client.create_or_update_index.called
assert mock_search_client.upload_documents.called
filter_query = vector_store_custom.filter_by_id(["doc1", "doc2"])
assert (
filter_query
== f"search.in({vector_store_custom.id_field}, 'doc1,doc2', ',')"
)
vector_results = vector_store_custom.similarity_search_by_vector(
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
)
assert len(vector_results) == 2
assert vector_results[0].document.id == "doc1"
assert vector_results[0].score == 0.9
# Define a simple text embedder function for testing
def mock_embedder(text: str) -> list[float]:
return [0.1, 0.2, 0.3, 0.4, 0.5]
text_results = vector_store_custom.similarity_search_by_text(
"test query", mock_embedder, k=2
)
assert len(text_results) == 2
doc = vector_store_custom.search_by_id("doc1")
assert doc.id == "doc1"
assert doc.text == "This is document 1"
assert doc.attributes["title"] == "Doc 1"

View File

@ -8,6 +8,7 @@ import sys
import numpy as np import numpy as np
import pytest import pytest
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.vector_stores.base import VectorStoreDocument from graphrag.vector_stores.base import VectorStoreDocument
from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore
@ -24,7 +25,7 @@ if not sys.platform.startswith("win"):
def test_vector_store_operations(): def test_vector_store_operations():
"""Test basic vector store operations with CosmosDB.""" """Test basic vector store operations with CosmosDB."""
vector_store = CosmosDBVectorStore( vector_store = CosmosDBVectorStore(
collection_name="testvector", vector_store_schema_config=VectorStoreSchemaConfig(index_name="testvector"),
) )
try: try:
@ -78,7 +79,7 @@ def test_vector_store_operations():
def test_clear(): def test_clear():
"""Test clearing the vector store.""" """Test clearing the vector store."""
vector_store = CosmosDBVectorStore( vector_store = CosmosDBVectorStore(
collection_name="testclear", vector_store_schema_config=VectorStoreSchemaConfig(index_name="testclear"),
) )
try: try:
vector_store.connect( vector_store.connect(
@ -102,3 +103,64 @@ def test_clear():
assert vector_store._database_exists() is False # noqa: SLF001 assert vector_store._database_exists() is False # noqa: SLF001
finally: finally:
pass pass
def test_vector_store_customization():
"""Test vector store customization with CosmosDB."""
vector_store = CosmosDBVectorStore(
vector_store_schema_config=VectorStoreSchemaConfig(
index_name="text-embeddings",
id_field="id",
text_field="text_custom",
vector_field="vector_custom",
attributes_field="attributes_custom",
vector_size=5,
),
)
try:
vector_store.connect(
connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
database_name="test_db",
)
docs = [
VectorStoreDocument(
id="doc1",
text="This is document 1",
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
attributes={"title": "Doc 1", "category": "test"},
),
VectorStoreDocument(
id="doc2",
text="This is document 2",
vector=[0.2, 0.3, 0.4, 0.5, 0.6],
attributes={"title": "Doc 2", "category": "test"},
),
]
vector_store.load_documents(docs)
vector_store.filter_by_id(["doc1"])
doc = vector_store.search_by_id("doc1")
assert doc.id == "doc1"
assert doc.text == "This is document 1"
assert doc.vector is not None
assert np.allclose(doc.vector, [0.1, 0.2, 0.3, 0.4, 0.5])
assert doc.attributes["title"] == "Doc 1"
# Define a simple text embedder function for testing
def mock_embedder(text: str) -> list[float]:
return [0.1, 0.2, 0.3, 0.4, 0.5] # Return fixed embedding
vector_results = vector_store.similarity_search_by_vector(
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
)
assert len(vector_results) > 0
text_results = vector_store.similarity_search_by_text(
"test query", mock_embedder, k=2
)
assert len(text_results) > 0
finally:
vector_store.clear()

View File

@ -8,6 +8,7 @@ These tests will test the VectorStoreFactory class and the creation of each vect
import pytest import pytest
from graphrag.config.enums import VectorStoreType from graphrag.config.enums import VectorStoreType
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
from graphrag.vector_stores.base import BaseVectorStore from graphrag.vector_stores.base import BaseVectorStore
from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore
@ -17,25 +18,31 @@ from graphrag.vector_stores.lancedb import LanceDBVectorStore
def test_create_lancedb_vector_store(): def test_create_lancedb_vector_store():
kwargs = { kwargs = {
"collection_name": "test_collection",
"db_uri": "/tmp/lancedb", "db_uri": "/tmp/lancedb",
} }
vector_store = VectorStoreFactory.create_vector_store( vector_store = VectorStoreFactory.create_vector_store(
VectorStoreType.LanceDB.value, kwargs vector_store_type=VectorStoreType.LanceDB.value,
vector_store_schema_config=VectorStoreSchemaConfig(
index_name="test_collection"
),
kwargs=kwargs,
) )
assert isinstance(vector_store, LanceDBVectorStore) assert isinstance(vector_store, LanceDBVectorStore)
assert vector_store.collection_name == "test_collection" assert vector_store.index_name == "test_collection"
@pytest.mark.skip(reason="Azure AI Search requires credentials and setup") @pytest.mark.skip(reason="Azure AI Search requires credentials and setup")
def test_create_azure_ai_search_vector_store(): def test_create_azure_ai_search_vector_store():
kwargs = { kwargs = {
"collection_name": "test_collection",
"url": "https://test.search.windows.net", "url": "https://test.search.windows.net",
"api_key": "test_key", "api_key": "test_key",
} }
vector_store = VectorStoreFactory.create_vector_store( vector_store = VectorStoreFactory.create_vector_store(
VectorStoreType.AzureAISearch.value, kwargs vector_store_type=VectorStoreType.AzureAISearch.value,
vector_store_schema_config=VectorStoreSchemaConfig(
index_name="test_collection"
),
kwargs=kwargs,
) )
assert isinstance(vector_store, AzureAISearchVectorStore) assert isinstance(vector_store, AzureAISearchVectorStore)
@ -43,13 +50,18 @@ def test_create_azure_ai_search_vector_store():
@pytest.mark.skip(reason="CosmosDB requires credentials and setup") @pytest.mark.skip(reason="CosmosDB requires credentials and setup")
def test_create_cosmosdb_vector_store(): def test_create_cosmosdb_vector_store():
kwargs = { kwargs = {
"collection_name": "test_collection",
"connection_string": "AccountEndpoint=https://test.documents.azure.com:443/;AccountKey=test_key==", "connection_string": "AccountEndpoint=https://test.documents.azure.com:443/;AccountKey=test_key==",
"database_name": "test_db", "database_name": "test_db",
} }
vector_store = VectorStoreFactory.create_vector_store( vector_store = VectorStoreFactory.create_vector_store(
VectorStoreType.CosmosDB.value, kwargs vector_store_type=VectorStoreType.CosmosDB.value,
vector_store_schema_config=VectorStoreSchemaConfig(
index_name="test_collection"
),
kwargs=kwargs,
) )
assert isinstance(vector_store, CosmosDBVectorStore) assert isinstance(vector_store, CosmosDBVectorStore)
@ -67,7 +79,12 @@ def test_register_and_create_custom_vector_store():
VectorStoreFactory.register( VectorStoreFactory.register(
"custom", lambda **kwargs: custom_vector_store_class(**kwargs) "custom", lambda **kwargs: custom_vector_store_class(**kwargs)
) )
vector_store = VectorStoreFactory.create_vector_store("custom", {})
vector_store = VectorStoreFactory.create_vector_store(
vector_store_type="custom",
vector_store_schema_config=VectorStoreSchemaConfig(),
kwargs={},
)
assert custom_vector_store_class.called assert custom_vector_store_class.called
assert vector_store is instance assert vector_store is instance
@ -89,7 +106,11 @@ def test_get_vector_store_types():
def test_create_unknown_vector_store(): def test_create_unknown_vector_store():
with pytest.raises(ValueError, match="Unknown vector store type: unknown"): with pytest.raises(ValueError, match="Unknown vector store type: unknown"):
VectorStoreFactory.create_vector_store("unknown", {}) VectorStoreFactory.create_vector_store(
vector_store_type="unknown",
vector_store_schema_config=VectorStoreSchemaConfig(),
kwargs={},
)
def test_is_supported_type(): def test_is_supported_type():
@ -139,6 +160,9 @@ def test_register_class_directly_works():
# Test creating an instance # Test creating an instance
vector_store = VectorStoreFactory.create_vector_store( vector_store = VectorStoreFactory.create_vector_store(
"custom_class", {"collection_name": "test"} vector_store_type="custom_class",
vector_store_schema_config=VectorStoreSchemaConfig(),
kwargs={"collection_name": "test"},
) )
assert isinstance(vector_store, CustomVectorStore) assert isinstance(vector_store, CustomVectorStore)

View File

@ -7,20 +7,20 @@ import shutil
import tempfile import tempfile
import numpy as np import numpy as np
import pytest
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.vector_stores.base import VectorStoreDocument from graphrag.vector_stores.base import VectorStoreDocument
from graphrag.vector_stores.lancedb import LanceDBVectorStore from graphrag.vector_stores.lancedb import LanceDBVectorStore
def test_vector_store_operations(): class TestLanceDBVectorStore:
"""Test basic vector store operations with LanceDB.""" """Test class for TestLanceDBVectorStore."""
# Create a temporary directory for the test database
temp_dir = tempfile.mkdtemp()
try:
vector_store = LanceDBVectorStore(collection_name="test_collection")
vector_store.connect(db_uri=temp_dir)
docs = [ @pytest.fixture
def sample_documents(self):
"""Create sample documents for testing."""
return [
VectorStoreDocument( VectorStoreDocument(
id="1", id="1",
text="This is document 1", text="This is document 1",
@ -40,102 +40,11 @@ def test_vector_store_operations():
attributes={"title": "Doc 3", "category": "test"}, attributes={"title": "Doc 3", "category": "test"},
), ),
] ]
vector_store.load_documents(docs[:2])
assert vector_store.collection_name in vector_store.db_connection.table_names() @pytest.fixture
def sample_documents_categories(self):
doc = vector_store.search_by_id("1") """Create sample documents with different categories for testing."""
assert doc.id == "1" return [
assert doc.text == "This is document 1"
assert doc.vector is not None
assert np.allclose(doc.vector, [0.1, 0.2, 0.3, 0.4, 0.5])
assert doc.attributes["title"] == "Doc 1"
filter_query = vector_store.filter_by_id(["1"])
assert filter_query == "id in ('1')"
results = vector_store.similarity_search_by_vector(
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
)
assert 1 <= len(results) <= 2
assert isinstance(results[0].score, float)
# Test append mode
vector_store.load_documents([docs[2]], overwrite=False)
result = vector_store.search_by_id("3")
assert result.id == "3"
assert result.text == "This is document 3"
# Define a simple text embedder function for testing
def mock_embedder(text: str) -> list[float]:
return [0.1, 0.2, 0.3, 0.4, 0.5]
text_results = vector_store.similarity_search_by_text(
"test query", mock_embedder, k=2
)
assert 1 <= len(text_results) <= 2
assert isinstance(text_results[0].score, float)
# Test non-existent document
non_existent = vector_store.search_by_id("nonexistent")
assert non_existent.id == "nonexistent"
assert non_existent.text is None
assert non_existent.vector is None
finally:
shutil.rmtree(temp_dir)
def test_empty_collection():
"""Test creating an empty collection."""
# Create a temporary directory for the test database
temp_dir = tempfile.mkdtemp()
try:
vector_store = LanceDBVectorStore(collection_name="empty_collection")
vector_store.connect(db_uri=temp_dir)
# Load the vector store with a document, then delete it
sample_doc = VectorStoreDocument(
id="tmp",
text="Temporary document to create schema",
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
attributes={"title": "Tmp"},
)
vector_store.load_documents([sample_doc])
vector_store.db_connection.open_table(vector_store.collection_name).delete(
"id = 'tmp'"
)
# Should still have the collection
assert vector_store.collection_name in vector_store.db_connection.table_names()
# Add a document after creating an empty collection
doc = VectorStoreDocument(
id="1",
text="This is document 1",
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
attributes={"title": "Doc 1"},
)
vector_store.load_documents([doc], overwrite=False)
result = vector_store.search_by_id("1")
assert result.id == "1"
assert result.text == "This is document 1"
finally:
# Clean up - remove the temporary directory
shutil.rmtree(temp_dir)
def test_filter_search():
"""Test filtered search with LanceDB."""
# Create a temporary directory for the test database
temp_dir = tempfile.mkdtemp()
try:
vector_store = LanceDBVectorStore(collection_name="filter_collection")
vector_store.connect(db_uri=temp_dir)
# Create test documents with different categories
docs = [
VectorStoreDocument( VectorStoreDocument(
id="1", id="1",
text="Document about cats", text="Document about cats",
@ -155,18 +64,201 @@ def test_filter_search():
attributes={"category": "vehicles"}, attributes={"category": "vehicles"},
), ),
] ]
vector_store.load_documents(docs)
# Filter to include only documents about animals def test_vector_store_operations(self, sample_documents):
vector_store.filter_by_id(["1", "2"]) """Test basic vector store operations with LanceDB."""
results = vector_store.similarity_search_by_vector( # Create a temporary directory for the test database
[0.1, 0.2, 0.3, 0.4, 0.5], k=3 temp_dir = tempfile.mkdtemp()
) try:
vector_store = LanceDBVectorStore(
vector_store_schema_config=VectorStoreSchemaConfig(
index_name="test_collection", vector_size=5
)
)
vector_store.connect(db_uri=temp_dir)
vector_store.load_documents(sample_documents[:2])
# Should return at most 2 documents (the filtered ones) if vector_store.index_name:
assert len(results) <= 2 assert (
ids = [result.document.id for result in results] vector_store.index_name in vector_store.db_connection.table_names()
assert "3" not in ids )
assert set(ids).issubset({"1", "2"})
finally: doc = vector_store.search_by_id("1")
shutil.rmtree(temp_dir) assert doc.id == "1"
assert doc.text == "This is document 1"
assert doc.vector is not None
assert np.allclose(doc.vector, [0.1, 0.2, 0.3, 0.4, 0.5])
assert doc.attributes["title"] == "Doc 1"
filter_query = vector_store.filter_by_id(["1"])
assert filter_query == "id in ('1')"
results = vector_store.similarity_search_by_vector(
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
)
assert 1 <= len(results) <= 2
assert isinstance(results[0].score, float)
# Test append mode
vector_store.load_documents([sample_documents[2]], overwrite=False)
result = vector_store.search_by_id("3")
assert result.id == "3"
assert result.text == "This is document 3"
# Define a simple text embedder function for testing
def mock_embedder(text: str) -> list[float]:
return [0.1, 0.2, 0.3, 0.4, 0.5]
text_results = vector_store.similarity_search_by_text(
"test query", mock_embedder, k=2
)
assert 1 <= len(text_results) <= 2
assert isinstance(text_results[0].score, float)
# Test non-existent document
non_existent = vector_store.search_by_id("nonexistent")
assert non_existent.id == "nonexistent"
assert non_existent.text is None
assert non_existent.vector is None
finally:
shutil.rmtree(temp_dir)
def test_empty_collection(self):
"""Test creating an empty collection."""
# Create a temporary directory for the test database
temp_dir = tempfile.mkdtemp()
try:
vector_store = LanceDBVectorStore(
vector_store_schema_config=VectorStoreSchemaConfig(
index_name="empty_collection", vector_size=5
)
)
vector_store.connect(db_uri=temp_dir)
# Load the vector store with a document, then delete it
sample_doc = VectorStoreDocument(
id="tmp",
text="Temporary document to create schema",
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
attributes={"title": "Tmp"},
)
vector_store.load_documents([sample_doc])
vector_store.db_connection.open_table(
vector_store.index_name if vector_store.index_name else ""
).delete("id = 'tmp'")
# Should still have the collection
if vector_store.index_name:
assert (
vector_store.index_name in vector_store.db_connection.table_names()
)
# Add a document after creating an empty collection
doc = VectorStoreDocument(
id="1",
text="This is document 1",
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
attributes={"title": "Doc 1"},
)
vector_store.load_documents([doc], overwrite=False)
result = vector_store.search_by_id("1")
assert result.id == "1"
assert result.text == "This is document 1"
finally:
# Clean up - remove the temporary directory
shutil.rmtree(temp_dir)
def test_filter_search(self, sample_documents_categories):
"""Test filtered search with LanceDB."""
# Create a temporary directory for the test database
temp_dir = tempfile.mkdtemp()
try:
vector_store = LanceDBVectorStore(
vector_store_schema_config=VectorStoreSchemaConfig(
index_name="filter_collection", vector_size=5
)
)
vector_store.connect(db_uri=temp_dir)
vector_store.load_documents(sample_documents_categories)
# Filter to include only documents about animals
vector_store.filter_by_id(["1", "2"])
results = vector_store.similarity_search_by_vector(
[0.1, 0.2, 0.3, 0.4, 0.5], k=3
)
# Should return at most 2 documents (the filtered ones)
assert len(results) <= 2
ids = [result.document.id for result in results]
assert "3" not in ids
assert set(ids).issubset({"1", "2"})
finally:
shutil.rmtree(temp_dir)
def test_vector_store_customization(self, sample_documents):
"""Test vector store customization with LanceDB."""
# Create a temporary directory for the test database
temp_dir = tempfile.mkdtemp()
try:
vector_store = LanceDBVectorStore(
vector_store_schema_config=VectorStoreSchemaConfig(
index_name="text-embeddings",
id_field="id_custom",
text_field="text_custom",
vector_field="vector_custom",
attributes_field="attributes_custom",
vector_size=5,
),
)
vector_store.connect(db_uri=temp_dir)
vector_store.load_documents(sample_documents[:2])
if vector_store.index_name:
assert (
vector_store.index_name in vector_store.db_connection.table_names()
)
doc = vector_store.search_by_id("1")
assert doc.id == "1"
assert doc.text == "This is document 1"
assert doc.vector is not None
assert np.allclose(doc.vector, [0.1, 0.2, 0.3, 0.4, 0.5])
assert doc.attributes["title"] == "Doc 1"
filter_query = vector_store.filter_by_id(["1"])
assert filter_query == f"{vector_store.id_field} in ('1')"
results = vector_store.similarity_search_by_vector(
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
)
assert 1 <= len(results) <= 2
assert isinstance(results[0].score, float)
# Test append mode
vector_store.load_documents([sample_documents[2]], overwrite=False)
result = vector_store.search_by_id("3")
assert result.id == "3"
assert result.text == "This is document 3"
# Define a simple text embedder function for testing
def mock_embedder(text: str) -> list[float]:
return [0.1, 0.2, 0.3, 0.4, 0.5]
text_results = vector_store.similarity_search_by_text(
"test query", mock_embedder, k=2
)
assert 1 <= len(text_results) <= 2
assert isinstance(text_results[0].score, float)
# Test non-existent document
non_existent = vector_store.search_by_id("nonexistent")
assert non_existent.id == "nonexistent"
assert non_existent.text is None
assert non_existent.vector is None
finally:
shutil.rmtree(temp_dir)

View File

@ -3,6 +3,7 @@
from typing import Any from typing import Any
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.data_model.entity import Entity from graphrag.data_model.entity import Entity
from graphrag.data_model.types import TextEmbedder from graphrag.data_model.types import TextEmbedder
from graphrag.language_model.manager import ModelManager from graphrag.language_model.manager import ModelManager
@ -19,7 +20,9 @@ from graphrag.vector_stores.base import (
class MockBaseVectorStore(BaseVectorStore): class MockBaseVectorStore(BaseVectorStore):
def __init__(self, documents: list[VectorStoreDocument]) -> None: def __init__(self, documents: list[VectorStoreDocument]) -> None:
super().__init__("mock") super().__init__(
vector_store_schema_config=VectorStoreSchemaConfig(index_name="mock")
)
self.documents = documents self.documents = documents
def connect(self, **kwargs: Any) -> None: def connect(self, **kwargs: Any) -> None:

View File

@ -3,19 +3,19 @@
import pytest import pytest
from graphrag.config.embeddings import create_collection_name from graphrag.config.embeddings import create_index_name
def test_create_collection_name(): def test_create_index_name():
collection = create_collection_name("default", "entity.title") collection = create_index_name("default", "entity.title")
assert collection == "default-entity-title" assert collection == "default-entity-title"
def test_create_collection_name_invalid_embedding_throws(): def test_create_index_name_invalid_embedding_throws():
with pytest.raises(KeyError): with pytest.raises(KeyError):
create_collection_name("default", "invalid.name") create_index_name("default", "invalid.name")
def test_create_collection_name_invalid_embedding_does_not_throw(): def test_create_index_name_invalid_embedding_does_not_throw():
collection = create_collection_name("default", "invalid.name", validate=False) collection = create_index_name("default", "invalid.name", validate=False)
assert collection == "default-invalid-name" assert collection == "default-invalid-name"