mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 00:57:23 +08:00
Custom vector store schema implementation (#2062)
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
* progress on vector customization * fix for lancedb vectors * cosmosdb implementation * uv run poe format * clean test for vector store * semversioner update * test_factory.py integration test fixes * fixes for cosmosdb test * integration test fix for lancedb * uv fix for format * test fixes * fixes for tests * fix cosmosdb bug * print statement * test * test * fix cosmosdb bug * test validation * validation cosmosdb * validate cosmosdb * fix cosmosdb * fix small feedback from PR --------- Co-authored-by: Gaudy Blanco <gaudy-microsoft@MacBook-Pro-m4-Gaudy-For-Work.local>
This commit is contained in:
parent
075cadd59a
commit
82cd3b7df2
@ -0,0 +1,4 @@
|
|||||||
|
{
|
||||||
|
"type": "minor",
|
||||||
|
"description": "add customization to vector store"
|
||||||
|
}
|
||||||
55
.vscode/launch.json
vendored
55
.vscode/launch.json
vendored
@ -6,21 +6,24 @@
|
|||||||
"name": "Indexer",
|
"name": "Indexer",
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"module": "uv",
|
"module": "graphrag",
|
||||||
"args": [
|
"args": [
|
||||||
"poe", "index",
|
"index",
|
||||||
"--root", "<path_to_ragtest_root_demo>"
|
"--root",
|
||||||
|
"<path_to_index_folder>"
|
||||||
],
|
],
|
||||||
|
"console": "integratedTerminal"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Query",
|
"name": "Query",
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"module": "uv",
|
"module": "graphrag",
|
||||||
"args": [
|
"args": [
|
||||||
"poe", "query",
|
"query",
|
||||||
"--root", "<path_to_ragtest_root_demo>",
|
"--root",
|
||||||
"--method", "global",
|
"<path_to_index_folder>",
|
||||||
|
"--method", "basic",
|
||||||
"--query", "What are the top themes in this story",
|
"--query", "What are the top themes in this story",
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -34,6 +37,42 @@
|
|||||||
"--config",
|
"--config",
|
||||||
"<path_to_ragtest_root_demo>/settings.yaml",
|
"<path_to_ragtest_root_demo>/settings.yaml",
|
||||||
]
|
]
|
||||||
}
|
},
|
||||||
|
{
|
||||||
|
"name": "Debug Integration Pytest",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"module": "pytest",
|
||||||
|
"args": [
|
||||||
|
"./tests/integration/vector_stores",
|
||||||
|
"-k", "test_azure_ai_search"
|
||||||
|
],
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"justMyCode": false
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Debug Verbs Pytest",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"module": "pytest",
|
||||||
|
"args": [
|
||||||
|
"./tests/verbs",
|
||||||
|
"-k", "test_generate_text_embeddings"
|
||||||
|
],
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"justMyCode": false
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Debug Smoke Pytest",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"module": "pytest",
|
||||||
|
"args": [
|
||||||
|
"./tests/smoke",
|
||||||
|
"-k", "test_fixtures"
|
||||||
|
],
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"justMyCode": false
|
||||||
|
},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@ -394,6 +394,7 @@ class VectorStoreDefaults:
|
|||||||
api_key: None = None
|
api_key: None = None
|
||||||
audience: None = None
|
audience: None = None
|
||||||
database_name: None = None
|
database_name: None = None
|
||||||
|
schema: None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@ -29,14 +29,14 @@ default_embeddings: list[str] = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def create_collection_name(
|
def create_index_name(
|
||||||
container_name: str, embedding_name: str, validate: bool = True
|
container_name: str, embedding_name: str, validate: bool = True
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Create a collection name for the embedding store.
|
Create a index name for the embedding store.
|
||||||
|
|
||||||
Within any given vector store, we can have multiple sets of embeddings organized into projects.
|
Within any given vector store, we can have multiple sets of embeddings organized into projects.
|
||||||
The `container` param is used for this partitioning, and is added as a prefix to the collection name for differentiation.
|
The `container` param is used for this partitioning, and is added as a prefix to the index name for differentiation.
|
||||||
|
|
||||||
The embedding name is fixed, with the available list defined in graphrag.index.config.embeddings
|
The embedding name is fixed, with the available list defined in graphrag.index.config.embeddings
|
||||||
|
|
||||||
|
|||||||
@ -6,7 +6,9 @@
|
|||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
from graphrag.config.defaults import vector_store_defaults
|
from graphrag.config.defaults import vector_store_defaults
|
||||||
|
from graphrag.config.embeddings import all_embeddings
|
||||||
from graphrag.config.enums import VectorStoreType
|
from graphrag.config.enums import VectorStoreType
|
||||||
|
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||||
|
|
||||||
|
|
||||||
class VectorStoreConfig(BaseModel):
|
class VectorStoreConfig(BaseModel):
|
||||||
@ -85,9 +87,25 @@ class VectorStoreConfig(BaseModel):
|
|||||||
default=vector_store_defaults.overwrite,
|
default=vector_store_defaults.overwrite,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
embeddings_schema: dict[str, VectorStoreSchemaConfig] = {}
|
||||||
|
|
||||||
|
def _validate_embeddings_schema(self) -> None:
|
||||||
|
"""Validate the embeddings schema."""
|
||||||
|
for name in self.embeddings_schema:
|
||||||
|
if name not in all_embeddings:
|
||||||
|
msg = f"vector_store.embeddings_schema contains an invalid embedding schema name: {name}. Please update your settings.yaml and select the correct embedding schema names."
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
if self.type == VectorStoreType.CosmosDB:
|
||||||
|
for id_field in self.embeddings_schema:
|
||||||
|
if id_field != "id":
|
||||||
|
msg = "When using CosmosDB, the id_field in embeddings_schema must be 'id'. Please update your settings.yaml and set the id_field to 'id'."
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def _validate_model(self):
|
def _validate_model(self):
|
||||||
"""Validate the model."""
|
"""Validate the model."""
|
||||||
self._validate_db_uri()
|
self._validate_db_uri()
|
||||||
self._validate_url()
|
self._validate_url()
|
||||||
|
self._validate_embeddings_schema()
|
||||||
return self
|
return self
|
||||||
|
|||||||
66
graphrag/config/models/vector_store_schema_config.py
Normal file
66
graphrag/config/models/vector_store_schema_config.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
# Copyright (c) 2024 Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License
|
||||||
|
|
||||||
|
"""Parameterization settings for the default configuration."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
DEFAULT_VECTOR_SIZE: int = 1536
|
||||||
|
|
||||||
|
VALID_IDENTIFIER_REGEX = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_field_name(field: str) -> bool:
|
||||||
|
"""Check if a field name is valid for CosmosDB."""
|
||||||
|
return bool(VALID_IDENTIFIER_REGEX.match(field))
|
||||||
|
|
||||||
|
|
||||||
|
class VectorStoreSchemaConfig(BaseModel):
|
||||||
|
"""The default configuration section for Vector Store Schema."""
|
||||||
|
|
||||||
|
id_field: str = Field(
|
||||||
|
description="The ID field to use.",
|
||||||
|
default="id",
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_field: str = Field(
|
||||||
|
description="The vector field to use.",
|
||||||
|
default="vector",
|
||||||
|
)
|
||||||
|
|
||||||
|
text_field: str = Field(
|
||||||
|
description="The text field to use.",
|
||||||
|
default="text",
|
||||||
|
)
|
||||||
|
|
||||||
|
attributes_field: str = Field(
|
||||||
|
description="The attributes field to use.",
|
||||||
|
default="attributes",
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_size: int = Field(
|
||||||
|
description="The vector size to use.",
|
||||||
|
default=DEFAULT_VECTOR_SIZE,
|
||||||
|
)
|
||||||
|
|
||||||
|
index_name: str | None = Field(description="The index name to use.", default=None)
|
||||||
|
|
||||||
|
def _validate_schema(self) -> None:
|
||||||
|
"""Validate the schema."""
|
||||||
|
for field in [
|
||||||
|
self.id_field,
|
||||||
|
self.vector_field,
|
||||||
|
self.text_field,
|
||||||
|
self.attributes_field,
|
||||||
|
]:
|
||||||
|
if not is_valid_field_name(field):
|
||||||
|
msg = f"Unsafe or invalid field name: {field}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def _validate_model(self):
|
||||||
|
"""Validate the model."""
|
||||||
|
self._validate_schema()
|
||||||
|
return self
|
||||||
@ -12,7 +12,8 @@ import pandas as pd
|
|||||||
|
|
||||||
from graphrag.cache.pipeline_cache import PipelineCache
|
from graphrag.cache.pipeline_cache import PipelineCache
|
||||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||||
from graphrag.config.embeddings import create_collection_name
|
from graphrag.config.embeddings import create_index_name
|
||||||
|
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||||
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy
|
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy
|
||||||
from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument
|
from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument
|
||||||
from graphrag.vector_stores.factory import VectorStoreFactory
|
from graphrag.vector_stores.factory import VectorStoreFactory
|
||||||
@ -49,9 +50,9 @@ async def embed_text(
|
|||||||
vector_store_config = strategy.get("vector_store")
|
vector_store_config = strategy.get("vector_store")
|
||||||
|
|
||||||
if vector_store_config:
|
if vector_store_config:
|
||||||
collection_name = _get_collection_name(vector_store_config, embedding_name)
|
index_name = _get_index_name(vector_store_config, embedding_name)
|
||||||
vector_store: BaseVectorStore = _create_vector_store(
|
vector_store: BaseVectorStore = _create_vector_store(
|
||||||
vector_store_config, collection_name
|
vector_store_config, index_name, embedding_name
|
||||||
)
|
)
|
||||||
vector_store_workflow_config = vector_store_config.get(
|
vector_store_workflow_config = vector_store_config.get(
|
||||||
embedding_name, vector_store_config
|
embedding_name, vector_store_config
|
||||||
@ -183,27 +184,46 @@ async def _text_embed_with_vector_store(
|
|||||||
|
|
||||||
|
|
||||||
def _create_vector_store(
|
def _create_vector_store(
|
||||||
vector_store_config: dict, collection_name: str
|
vector_store_config: dict, index_name: str, embedding_name: str | None = None
|
||||||
) -> BaseVectorStore:
|
) -> BaseVectorStore:
|
||||||
vector_store_type: str = str(vector_store_config.get("type"))
|
vector_store_type: str = str(vector_store_config.get("type"))
|
||||||
if collection_name:
|
|
||||||
vector_store_config.update({"collection_name": collection_name})
|
embeddings_schema: dict[str, VectorStoreSchemaConfig] = vector_store_config.get(
|
||||||
|
"embeddings_schema", {}
|
||||||
|
)
|
||||||
|
single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
|
||||||
|
|
||||||
|
if (
|
||||||
|
embeddings_schema is not None
|
||||||
|
and embedding_name is not None
|
||||||
|
and embedding_name in embeddings_schema
|
||||||
|
):
|
||||||
|
raw_config = embeddings_schema[embedding_name]
|
||||||
|
if isinstance(raw_config, dict):
|
||||||
|
single_embedding_config = VectorStoreSchemaConfig(**raw_config)
|
||||||
|
else:
|
||||||
|
single_embedding_config = raw_config
|
||||||
|
|
||||||
|
if single_embedding_config.index_name is None:
|
||||||
|
single_embedding_config.index_name = index_name
|
||||||
|
|
||||||
vector_store = VectorStoreFactory().create_vector_store(
|
vector_store = VectorStoreFactory().create_vector_store(
|
||||||
vector_store_type, kwargs=vector_store_config
|
vector_store_schema_config=single_embedding_config,
|
||||||
|
vector_store_type=vector_store_type,
|
||||||
|
kwargs=vector_store_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
vector_store.connect(**vector_store_config)
|
vector_store.connect(**vector_store_config)
|
||||||
return vector_store
|
return vector_store
|
||||||
|
|
||||||
|
|
||||||
def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str:
|
def _get_index_name(vector_store_config: dict, embedding_name: str) -> str:
|
||||||
container_name = vector_store_config.get("container_name", "default")
|
container_name = vector_store_config.get("container_name", "default")
|
||||||
collection_name = create_collection_name(container_name, embedding_name)
|
index_name = create_index_name(container_name, embedding_name)
|
||||||
|
|
||||||
msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {collection_name}"
|
msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {index_name}"
|
||||||
logger.info(msg)
|
logger.info(msg)
|
||||||
return collection_name
|
return index_name
|
||||||
|
|
||||||
|
|
||||||
def load_strategy(strategy: TextEmbedStrategyType) -> TextEmbeddingStrategy:
|
def load_strategy(strategy: TextEmbedStrategyType) -> TextEmbeddingStrategy:
|
||||||
|
|||||||
@ -8,9 +8,10 @@ from typing import Any
|
|||||||
|
|
||||||
from graphrag.cache.factory import CacheFactory
|
from graphrag.cache.factory import CacheFactory
|
||||||
from graphrag.cache.pipeline_cache import PipelineCache
|
from graphrag.cache.pipeline_cache import PipelineCache
|
||||||
from graphrag.config.embeddings import create_collection_name
|
from graphrag.config.embeddings import create_index_name
|
||||||
from graphrag.config.models.cache_config import CacheConfig
|
from graphrag.config.models.cache_config import CacheConfig
|
||||||
from graphrag.config.models.storage_config import StorageConfig
|
from graphrag.config.models.storage_config import StorageConfig
|
||||||
|
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||||
from graphrag.data_model.types import TextEmbedder
|
from graphrag.data_model.types import TextEmbedder
|
||||||
from graphrag.storage.factory import StorageFactory
|
from graphrag.storage.factory import StorageFactory
|
||||||
from graphrag.storage.pipeline_storage import PipelineStorage
|
from graphrag.storage.pipeline_storage import PipelineStorage
|
||||||
@ -103,12 +104,33 @@ def get_embedding_store(
|
|||||||
index_names = []
|
index_names = []
|
||||||
for index, store in config_args.items():
|
for index, store in config_args.items():
|
||||||
vector_store_type = store["type"]
|
vector_store_type = store["type"]
|
||||||
collection_name = create_collection_name(
|
index_name = create_index_name(
|
||||||
store.get("container_name", "default"), embedding_name
|
store.get("container_name", "default"), embedding_name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
embeddings_schema: dict[str, VectorStoreSchemaConfig] = store.get(
|
||||||
|
"embeddings_schema", {}
|
||||||
|
)
|
||||||
|
single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
|
||||||
|
|
||||||
|
if (
|
||||||
|
embeddings_schema is not None
|
||||||
|
and embedding_name is not None
|
||||||
|
and embedding_name in embeddings_schema
|
||||||
|
):
|
||||||
|
raw_config = embeddings_schema[embedding_name]
|
||||||
|
if isinstance(raw_config, dict):
|
||||||
|
single_embedding_config = VectorStoreSchemaConfig(**raw_config)
|
||||||
|
else:
|
||||||
|
single_embedding_config = raw_config
|
||||||
|
|
||||||
|
if single_embedding_config.index_name is None:
|
||||||
|
single_embedding_config.index_name = index_name
|
||||||
|
|
||||||
embedding_store = VectorStoreFactory().create_vector_store(
|
embedding_store = VectorStoreFactory().create_vector_store(
|
||||||
vector_store_type=vector_store_type,
|
vector_store_type=vector_store_type,
|
||||||
kwargs={**store, "collection_name": collection_name},
|
vector_store_schema_config=single_embedding_config,
|
||||||
|
kwargs={**store},
|
||||||
)
|
)
|
||||||
embedding_store.connect(**store)
|
embedding_store.connect(**store)
|
||||||
# If there is only a single index, return the embedding store directly
|
# If there is only a single index, return the embedding store directly
|
||||||
|
|||||||
@ -24,9 +24,9 @@ from azure.search.documents.indexes.models import (
|
|||||||
)
|
)
|
||||||
from azure.search.documents.models import VectorizedQuery
|
from azure.search.documents.models import VectorizedQuery
|
||||||
|
|
||||||
|
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||||
from graphrag.data_model.types import TextEmbedder
|
from graphrag.data_model.types import TextEmbedder
|
||||||
from graphrag.vector_stores.base import (
|
from graphrag.vector_stores.base import (
|
||||||
DEFAULT_VECTOR_SIZE,
|
|
||||||
BaseVectorStore,
|
BaseVectorStore,
|
||||||
VectorStoreDocument,
|
VectorStoreDocument,
|
||||||
VectorStoreSearchResult,
|
VectorStoreSearchResult,
|
||||||
@ -38,15 +38,18 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
|||||||
|
|
||||||
index_client: SearchIndexClient
|
index_client: SearchIndexClient
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(
|
||||||
super().__init__(**kwargs)
|
self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
vector_store_schema_config=vector_store_schema_config, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
def connect(self, **kwargs: Any) -> Any:
|
def connect(self, **kwargs: Any) -> Any:
|
||||||
"""Connect to AI search vector storage."""
|
"""Connect to AI search vector storage."""
|
||||||
url = kwargs["url"]
|
url = kwargs["url"]
|
||||||
api_key = kwargs.get("api_key")
|
api_key = kwargs.get("api_key")
|
||||||
audience = kwargs.get("audience")
|
audience = kwargs.get("audience")
|
||||||
self.vector_size = kwargs.get("vector_size", DEFAULT_VECTOR_SIZE)
|
|
||||||
|
|
||||||
self.vector_search_profile_name = kwargs.get(
|
self.vector_search_profile_name = kwargs.get(
|
||||||
"vector_search_profile_name", "vectorSearchProfile"
|
"vector_search_profile_name", "vectorSearchProfile"
|
||||||
@ -56,7 +59,7 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
|||||||
audience_arg = {"audience": audience} if audience and not api_key else {}
|
audience_arg = {"audience": audience} if audience and not api_key else {}
|
||||||
self.db_connection = SearchClient(
|
self.db_connection = SearchClient(
|
||||||
endpoint=url,
|
endpoint=url,
|
||||||
index_name=self.collection_name,
|
index_name=self.index_name if self.index_name else "",
|
||||||
credential=(
|
credential=(
|
||||||
AzureKeyCredential(api_key) if api_key else DefaultAzureCredential()
|
AzureKeyCredential(api_key) if api_key else DefaultAzureCredential()
|
||||||
),
|
),
|
||||||
@ -78,8 +81,11 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Load documents into an Azure AI Search index."""
|
"""Load documents into an Azure AI Search index."""
|
||||||
if overwrite:
|
if overwrite:
|
||||||
if self.collection_name in self.index_client.list_index_names():
|
if (
|
||||||
self.index_client.delete_index(self.collection_name)
|
self.index_name is not None
|
||||||
|
and self.index_name in self.index_client.list_index_names()
|
||||||
|
):
|
||||||
|
self.index_client.delete_index(self.index_name)
|
||||||
|
|
||||||
# Configure vector search profile
|
# Configure vector search profile
|
||||||
vector_search = VectorSearch(
|
vector_search = VectorSearch(
|
||||||
@ -100,24 +106,26 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
|||||||
)
|
)
|
||||||
# Configure the index
|
# Configure the index
|
||||||
index = SearchIndex(
|
index = SearchIndex(
|
||||||
name=self.collection_name,
|
name=self.index_name if self.index_name else "",
|
||||||
fields=[
|
fields=[
|
||||||
SimpleField(
|
SimpleField(
|
||||||
name="id",
|
name=self.id_field,
|
||||||
type=SearchFieldDataType.String,
|
type=SearchFieldDataType.String,
|
||||||
key=True,
|
key=True,
|
||||||
),
|
),
|
||||||
SearchField(
|
SearchField(
|
||||||
name="vector",
|
name=self.vector_field,
|
||||||
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
|
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
|
||||||
searchable=True,
|
searchable=True,
|
||||||
hidden=False, # DRIFT needs to return the vector for client-side similarity
|
hidden=False, # DRIFT needs to return the vector for client-side similarity
|
||||||
vector_search_dimensions=self.vector_size,
|
vector_search_dimensions=self.vector_size,
|
||||||
vector_search_profile_name=self.vector_search_profile_name,
|
vector_search_profile_name=self.vector_search_profile_name,
|
||||||
),
|
),
|
||||||
SearchableField(name="text", type=SearchFieldDataType.String),
|
SearchableField(
|
||||||
|
name=self.text_field, type=SearchFieldDataType.String
|
||||||
|
),
|
||||||
SimpleField(
|
SimpleField(
|
||||||
name="attributes",
|
name=self.attributes_field,
|
||||||
type=SearchFieldDataType.String,
|
type=SearchFieldDataType.String,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@ -129,10 +137,10 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
|||||||
|
|
||||||
batch = [
|
batch = [
|
||||||
{
|
{
|
||||||
"id": doc.id,
|
self.id_field: doc.id,
|
||||||
"vector": doc.vector,
|
self.vector_field: doc.vector,
|
||||||
"text": doc.text,
|
self.text_field: doc.text,
|
||||||
"attributes": json.dumps(doc.attributes),
|
self.attributes_field: json.dumps(doc.attributes),
|
||||||
}
|
}
|
||||||
for doc in documents
|
for doc in documents
|
||||||
if doc.vector is not None
|
if doc.vector is not None
|
||||||
@ -151,7 +159,7 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
|||||||
# More info about odata filtering here: https://learn.microsoft.com/en-us/azure/search/search-query-odata-search-in-function
|
# More info about odata filtering here: https://learn.microsoft.com/en-us/azure/search/search-query-odata-search-in-function
|
||||||
# search.in is faster that joined and/or conditions
|
# search.in is faster that joined and/or conditions
|
||||||
id_filter = ",".join([f"{id!s}" for id in include_ids])
|
id_filter = ",".join([f"{id!s}" for id in include_ids])
|
||||||
self.query_filter = f"search.in(id, '{id_filter}', ',')"
|
self.query_filter = f"search.in({self.id_field}, '{id_filter}', ',')"
|
||||||
|
|
||||||
# Returning to keep consistency with other methods, but not needed
|
# Returning to keep consistency with other methods, but not needed
|
||||||
# TODO: Refactor on a future PR
|
# TODO: Refactor on a future PR
|
||||||
@ -162,7 +170,7 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
|||||||
) -> list[VectorStoreSearchResult]:
|
) -> list[VectorStoreSearchResult]:
|
||||||
"""Perform a vector-based similarity search."""
|
"""Perform a vector-based similarity search."""
|
||||||
vectorized_query = VectorizedQuery(
|
vectorized_query = VectorizedQuery(
|
||||||
vector=query_embedding, k_nearest_neighbors=k, fields="vector"
|
vector=query_embedding, k_nearest_neighbors=k, fields=self.vector_field
|
||||||
)
|
)
|
||||||
|
|
||||||
response = self.db_connection.search(
|
response = self.db_connection.search(
|
||||||
@ -172,10 +180,10 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
|||||||
return [
|
return [
|
||||||
VectorStoreSearchResult(
|
VectorStoreSearchResult(
|
||||||
document=VectorStoreDocument(
|
document=VectorStoreDocument(
|
||||||
id=doc.get("id", ""),
|
id=doc.get(self.id_field, ""),
|
||||||
text=doc.get("text", ""),
|
text=doc.get(self.text_field, ""),
|
||||||
vector=doc.get("vector", []),
|
vector=doc.get(self.vector_field, []),
|
||||||
attributes=(json.loads(doc.get("attributes", "{}"))),
|
attributes=(json.loads(doc.get(self.attributes_field, "{}"))),
|
||||||
),
|
),
|
||||||
# Cosine similarity between 0.333 and 1.000
|
# Cosine similarity between 0.333 and 1.000
|
||||||
# https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking#scores-in-a-hybrid-search-results
|
# https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking#scores-in-a-hybrid-search-results
|
||||||
@ -199,8 +207,8 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
|||||||
"""Search for a document by id."""
|
"""Search for a document by id."""
|
||||||
response = self.db_connection.get_document(id)
|
response = self.db_connection.get_document(id)
|
||||||
return VectorStoreDocument(
|
return VectorStoreDocument(
|
||||||
id=response.get("id", ""),
|
id=response.get(self.id_field, ""),
|
||||||
text=response.get("text", ""),
|
text=response.get(self.text_field, ""),
|
||||||
vector=response.get("vector", []),
|
vector=response.get(self.vector_field, []),
|
||||||
attributes=(json.loads(response.get("attributes", "{}"))),
|
attributes=(json.loads(response.get(self.attributes_field, "{}"))),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -7,10 +7,9 @@ from abc import ABC, abstractmethod
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||||
from graphrag.data_model.types import TextEmbedder
|
from graphrag.data_model.types import TextEmbedder
|
||||||
|
|
||||||
DEFAULT_VECTOR_SIZE: int = 1536
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class VectorStoreDocument:
|
class VectorStoreDocument:
|
||||||
@ -42,18 +41,24 @@ class BaseVectorStore(ABC):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
collection_name: str,
|
vector_store_schema_config: VectorStoreSchemaConfig,
|
||||||
db_connection: Any | None = None,
|
db_connection: Any | None = None,
|
||||||
document_collection: Any | None = None,
|
document_collection: Any | None = None,
|
||||||
query_filter: Any | None = None,
|
query_filter: Any | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
self.collection_name = collection_name
|
|
||||||
self.db_connection = db_connection
|
self.db_connection = db_connection
|
||||||
self.document_collection = document_collection
|
self.document_collection = document_collection
|
||||||
self.query_filter = query_filter
|
self.query_filter = query_filter
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
self.index_name = vector_store_schema_config.index_name
|
||||||
|
self.id_field = vector_store_schema_config.id_field
|
||||||
|
self.text_field = vector_store_schema_config.text_field
|
||||||
|
self.vector_field = vector_store_schema_config.vector_field
|
||||||
|
self.attributes_field = vector_store_schema_config.attributes_field
|
||||||
|
self.vector_size = vector_store_schema_config.vector_size
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def connect(self, **kwargs: Any) -> None:
|
def connect(self, **kwargs: Any) -> None:
|
||||||
"""Connect to vector storage."""
|
"""Connect to vector storage."""
|
||||||
|
|||||||
@ -11,9 +11,9 @@ from azure.cosmos.exceptions import CosmosHttpResponseError
|
|||||||
from azure.cosmos.partition_key import PartitionKey
|
from azure.cosmos.partition_key import PartitionKey
|
||||||
from azure.identity import DefaultAzureCredential
|
from azure.identity import DefaultAzureCredential
|
||||||
|
|
||||||
|
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||||
from graphrag.data_model.types import TextEmbedder
|
from graphrag.data_model.types import TextEmbedder
|
||||||
from graphrag.vector_stores.base import (
|
from graphrag.vector_stores.base import (
|
||||||
DEFAULT_VECTOR_SIZE,
|
|
||||||
BaseVectorStore,
|
BaseVectorStore,
|
||||||
VectorStoreDocument,
|
VectorStoreDocument,
|
||||||
VectorStoreSearchResult,
|
VectorStoreSearchResult,
|
||||||
@ -27,8 +27,12 @@ class CosmosDBVectorStore(BaseVectorStore):
|
|||||||
_database_client: DatabaseProxy
|
_database_client: DatabaseProxy
|
||||||
_container_client: ContainerProxy
|
_container_client: ContainerProxy
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(
|
||||||
super().__init__(**kwargs)
|
self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
vector_store_schema_config=vector_store_schema_config, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
def connect(self, **kwargs: Any) -> Any:
|
def connect(self, **kwargs: Any) -> Any:
|
||||||
"""Connect to CosmosDB vector storage."""
|
"""Connect to CosmosDB vector storage."""
|
||||||
@ -49,13 +53,12 @@ class CosmosDBVectorStore(BaseVectorStore):
|
|||||||
msg = "Database name must be provided."
|
msg = "Database name must be provided."
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
self._database_name = database_name
|
self._database_name = database_name
|
||||||
collection_name = self.collection_name
|
if self.index_name is None:
|
||||||
if collection_name is None:
|
msg = "Index name is empty or not provided."
|
||||||
msg = "Collection name is empty or not provided."
|
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
self._container_name = collection_name
|
self._container_name = self.index_name
|
||||||
|
|
||||||
self.vector_size = kwargs.get("vector_size", DEFAULT_VECTOR_SIZE)
|
self.vector_size = self.vector_size
|
||||||
self._create_database()
|
self._create_database()
|
||||||
self._create_container()
|
self._create_container()
|
||||||
|
|
||||||
@ -80,13 +83,13 @@ class CosmosDBVectorStore(BaseVectorStore):
|
|||||||
|
|
||||||
def _create_container(self) -> None:
|
def _create_container(self) -> None:
|
||||||
"""Create the container if it doesn't exist."""
|
"""Create the container if it doesn't exist."""
|
||||||
partition_key = PartitionKey(path="/id", kind="Hash")
|
partition_key = PartitionKey(path=f"/{self.id_field}", kind="Hash")
|
||||||
|
|
||||||
# Define the container vector policy
|
# Define the container vector policy
|
||||||
vector_embedding_policy = {
|
vector_embedding_policy = {
|
||||||
"vectorEmbeddings": [
|
"vectorEmbeddings": [
|
||||||
{
|
{
|
||||||
"path": "/vector",
|
"path": f"/{self.vector_field}",
|
||||||
"dataType": "float32",
|
"dataType": "float32",
|
||||||
"distanceFunction": "cosine",
|
"distanceFunction": "cosine",
|
||||||
"dimensions": self.vector_size,
|
"dimensions": self.vector_size,
|
||||||
@ -99,13 +102,18 @@ class CosmosDBVectorStore(BaseVectorStore):
|
|||||||
"indexingMode": "consistent",
|
"indexingMode": "consistent",
|
||||||
"automatic": True,
|
"automatic": True,
|
||||||
"includedPaths": [{"path": "/*"}],
|
"includedPaths": [{"path": "/*"}],
|
||||||
"excludedPaths": [{"path": "/_etag/?"}, {"path": "/vector/*"}],
|
"excludedPaths": [
|
||||||
|
{"path": "/_etag/?"},
|
||||||
|
{"path": f"/{self.vector_field}/*"},
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Currently, the CosmosDB emulator does not support the diskANN policy.
|
# Currently, the CosmosDB emulator does not support the diskANN policy.
|
||||||
try:
|
try:
|
||||||
# First try with the standard diskANN policy
|
# First try with the standard diskANN policy
|
||||||
indexing_policy["vectorIndexes"] = [{"path": "/vector", "type": "diskANN"}]
|
indexing_policy["vectorIndexes"] = [
|
||||||
|
{"path": f"/{self.vector_field}", "type": "diskANN"}
|
||||||
|
]
|
||||||
|
|
||||||
# Create the container and container client
|
# Create the container and container client
|
||||||
self._database_client.create_container_if_not_exists(
|
self._database_client.create_container_if_not_exists(
|
||||||
@ -158,12 +166,16 @@ class CosmosDBVectorStore(BaseVectorStore):
|
|||||||
# Upload documents to CosmosDB
|
# Upload documents to CosmosDB
|
||||||
for doc in documents:
|
for doc in documents:
|
||||||
if doc.vector is not None:
|
if doc.vector is not None:
|
||||||
|
print("Document to store:") # noqa: T201
|
||||||
|
print(doc) # noqa: T201
|
||||||
doc_json = {
|
doc_json = {
|
||||||
"id": doc.id,
|
self.id_field: doc.id,
|
||||||
"vector": doc.vector,
|
self.vector_field: doc.vector,
|
||||||
"text": doc.text,
|
self.text_field: doc.text,
|
||||||
"attributes": json.dumps(doc.attributes),
|
self.attributes_field: json.dumps(doc.attributes),
|
||||||
}
|
}
|
||||||
|
print("Storing document in CosmosDB:") # noqa: T201
|
||||||
|
print(doc_json) # noqa: T201
|
||||||
self._container_client.upsert_item(doc_json)
|
self._container_client.upsert_item(doc_json)
|
||||||
|
|
||||||
def similarity_search_by_vector(
|
def similarity_search_by_vector(
|
||||||
@ -175,7 +187,7 @@ class CosmosDBVectorStore(BaseVectorStore):
|
|||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
query = f"SELECT TOP {k} c.id, c.text, c.vector, c.attributes, VectorDistance(c.vector, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.vector, @embedding)" # noqa: S608
|
query = f"SELECT TOP {k} c.{self.id_field}, c.{self.text_field}, c.{self.vector_field}, c.{self.attributes_field}, VectorDistance(c.{self.vector_field}, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.{self.vector_field}, @embedding)" # noqa: S608
|
||||||
query_params = [{"name": "@embedding", "value": query_embedding}]
|
query_params = [{"name": "@embedding", "value": query_embedding}]
|
||||||
items = list(
|
items = list(
|
||||||
self._container_client.query_items(
|
self._container_client.query_items(
|
||||||
@ -187,7 +199,7 @@ class CosmosDBVectorStore(BaseVectorStore):
|
|||||||
except (CosmosHttpResponseError, ValueError):
|
except (CosmosHttpResponseError, ValueError):
|
||||||
# Currently, the CosmosDB emulator does not support the VectorDistance function.
|
# Currently, the CosmosDB emulator does not support the VectorDistance function.
|
||||||
# For emulator or test environments - fetch all items and calculate distance locally
|
# For emulator or test environments - fetch all items and calculate distance locally
|
||||||
query = "SELECT c.id, c.text, c.vector, c.attributes FROM c"
|
query = f"SELECT c.{self.id_field}, c.{self.text_field}, c.{self.vector_field}, c.{self.attributes_field} FROM c" # noqa: S608
|
||||||
items = list(
|
items = list(
|
||||||
self._container_client.query_items(
|
self._container_client.query_items(
|
||||||
query=query,
|
query=query,
|
||||||
@ -206,7 +218,7 @@ class CosmosDBVectorStore(BaseVectorStore):
|
|||||||
|
|
||||||
# Calculate scores for all items
|
# Calculate scores for all items
|
||||||
for item in items:
|
for item in items:
|
||||||
item_vector = item.get("vector", [])
|
item_vector = item.get(self.vector_field, [])
|
||||||
similarity = cosine_similarity(query_embedding, item_vector)
|
similarity = cosine_similarity(query_embedding, item_vector)
|
||||||
item["SimilarityScore"] = similarity
|
item["SimilarityScore"] = similarity
|
||||||
|
|
||||||
@ -218,10 +230,10 @@ class CosmosDBVectorStore(BaseVectorStore):
|
|||||||
return [
|
return [
|
||||||
VectorStoreSearchResult(
|
VectorStoreSearchResult(
|
||||||
document=VectorStoreDocument(
|
document=VectorStoreDocument(
|
||||||
id=item.get("id", ""),
|
id=item.get(self.id_field, ""),
|
||||||
text=item.get("text", ""),
|
text=item.get(self.text_field, ""),
|
||||||
vector=item.get("vector", []),
|
vector=item.get(self.vector_field, []),
|
||||||
attributes=(json.loads(item.get("attributes", "{}"))),
|
attributes=(json.loads(item.get(self.attributes_field, "{}"))),
|
||||||
),
|
),
|
||||||
score=item.get("SimilarityScore", 0.0),
|
score=item.get("SimilarityScore", 0.0),
|
||||||
)
|
)
|
||||||
@ -248,7 +260,9 @@ class CosmosDBVectorStore(BaseVectorStore):
|
|||||||
id_filter = ", ".join([f"'{id}'" for id in include_ids])
|
id_filter = ", ".join([f"'{id}'" for id in include_ids])
|
||||||
else:
|
else:
|
||||||
id_filter = ", ".join([str(id) for id in include_ids])
|
id_filter = ", ".join([str(id) for id in include_ids])
|
||||||
self.query_filter = f"SELECT * FROM c WHERE c.id IN ({id_filter})" # noqa: S608
|
self.query_filter = (
|
||||||
|
f"SELECT * FROM c WHERE c.{self.id_field} IN ({id_filter})" # noqa: S608
|
||||||
|
)
|
||||||
return self.query_filter
|
return self.query_filter
|
||||||
|
|
||||||
def search_by_id(self, id: str) -> VectorStoreDocument:
|
def search_by_id(self, id: str) -> VectorStoreDocument:
|
||||||
@ -259,10 +273,10 @@ class CosmosDBVectorStore(BaseVectorStore):
|
|||||||
|
|
||||||
item = self._container_client.read_item(item=id, partition_key=id)
|
item = self._container_client.read_item(item=id, partition_key=id)
|
||||||
return VectorStoreDocument(
|
return VectorStoreDocument(
|
||||||
id=item.get("id", ""),
|
id=item.get(self.id_field, ""),
|
||||||
vector=item.get("vector", []),
|
vector=item.get(self.vector_field, []),
|
||||||
text=item.get("text", ""),
|
text=item.get(self.text_field, ""),
|
||||||
attributes=(json.loads(item.get("attributes", "{}"))),
|
attributes=(json.loads(item.get(self.attributes_field, "{}"))),
|
||||||
)
|
)
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
|
|||||||
@ -15,6 +15,9 @@ from graphrag.vector_stores.lancedb import LanceDBVectorStore
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
from graphrag.config.models.vector_store_schema_config import (
|
||||||
|
VectorStoreSchemaConfig,
|
||||||
|
)
|
||||||
from graphrag.vector_stores.base import BaseVectorStore
|
from graphrag.vector_stores.base import BaseVectorStore
|
||||||
|
|
||||||
|
|
||||||
@ -47,7 +50,10 @@ class VectorStoreFactory:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_vector_store(
|
def create_vector_store(
|
||||||
cls, vector_store_type: str, kwargs: dict
|
cls,
|
||||||
|
vector_store_type: str,
|
||||||
|
vector_store_schema_config: VectorStoreSchemaConfig,
|
||||||
|
kwargs: dict,
|
||||||
) -> BaseVectorStore:
|
) -> BaseVectorStore:
|
||||||
"""Create a vector store object from the provided type.
|
"""Create a vector store object from the provided type.
|
||||||
|
|
||||||
@ -67,7 +73,9 @@ class VectorStoreFactory:
|
|||||||
msg = f"Unknown vector store type: {vector_store_type}"
|
msg = f"Unknown vector store type: {vector_store_type}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
return cls._registry[vector_store_type](**kwargs)
|
return cls._registry[vector_store_type](
|
||||||
|
vector_store_schema_config=vector_store_schema_config, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_vector_store_types(cls) -> list[str]:
|
def get_vector_store_types(cls) -> list[str]:
|
||||||
|
|||||||
@ -5,9 +5,9 @@
|
|||||||
|
|
||||||
import json # noqa: I001
|
import json # noqa: I001
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
|
import numpy as np
|
||||||
|
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||||
from graphrag.data_model.types import TextEmbedder
|
from graphrag.data_model.types import TextEmbedder
|
||||||
|
|
||||||
from graphrag.vector_stores.base import (
|
from graphrag.vector_stores.base import (
|
||||||
@ -21,60 +21,81 @@ import lancedb
|
|||||||
class LanceDBVectorStore(BaseVectorStore):
|
class LanceDBVectorStore(BaseVectorStore):
|
||||||
"""LanceDB vector storage implementation."""
|
"""LanceDB vector storage implementation."""
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(
|
||||||
super().__init__(**kwargs)
|
self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
vector_store_schema_config=vector_store_schema_config, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
def connect(self, **kwargs: Any) -> Any:
|
def connect(self, **kwargs: Any) -> Any:
|
||||||
"""Connect to the vector storage."""
|
"""Connect to the vector storage."""
|
||||||
self.db_connection = lancedb.connect(kwargs["db_uri"])
|
self.db_connection = lancedb.connect(kwargs["db_uri"])
|
||||||
if (
|
|
||||||
self.collection_name
|
if self.index_name and self.index_name in self.db_connection.table_names():
|
||||||
and self.collection_name in self.db_connection.table_names()
|
self.document_collection = self.db_connection.open_table(self.index_name)
|
||||||
):
|
|
||||||
self.document_collection = self.db_connection.open_table(
|
|
||||||
self.collection_name
|
|
||||||
)
|
|
||||||
|
|
||||||
def load_documents(
|
def load_documents(
|
||||||
self, documents: list[VectorStoreDocument], overwrite: bool = True
|
self, documents: list[VectorStoreDocument], overwrite: bool = True
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Load documents into vector storage."""
|
"""Load documents into vector storage."""
|
||||||
data = [
|
# Step 1: Prepare data columns manually
|
||||||
{
|
ids = []
|
||||||
"id": document.id,
|
texts = []
|
||||||
"text": document.text,
|
vectors = []
|
||||||
"vector": document.vector,
|
attributes = []
|
||||||
"attributes": json.dumps(document.attributes),
|
|
||||||
}
|
|
||||||
for document in documents
|
|
||||||
if document.vector is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
if len(data) == 0:
|
for document in documents:
|
||||||
|
self.vector_size = (
|
||||||
|
len(document.vector) if document.vector else self.vector_size
|
||||||
|
)
|
||||||
|
if document.vector is not None and len(document.vector) == self.vector_size:
|
||||||
|
ids.append(document.id)
|
||||||
|
texts.append(document.text)
|
||||||
|
vectors.append(np.array(document.vector, dtype=np.float32))
|
||||||
|
attributes.append(json.dumps(document.attributes))
|
||||||
|
|
||||||
|
# Step 2: Handle empty case
|
||||||
|
if len(ids) == 0:
|
||||||
data = None
|
data = None
|
||||||
|
else:
|
||||||
|
# Step 3: Flatten the vectors and build FixedSizeListArray manually
|
||||||
|
flat_vector = np.concatenate(vectors).astype(np.float32)
|
||||||
|
flat_array = pa.array(flat_vector, type=pa.float32())
|
||||||
|
vector_column = pa.FixedSizeListArray.from_arrays(
|
||||||
|
flat_array, self.vector_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 4: Create PyArrow table (let schema be inferred)
|
||||||
|
data = pa.table({
|
||||||
|
self.id_field: pa.array(ids, type=pa.string()),
|
||||||
|
self.text_field: pa.array(texts, type=pa.string()),
|
||||||
|
self.vector_field: vector_column,
|
||||||
|
self.attributes_field: pa.array(attributes, type=pa.string()),
|
||||||
|
})
|
||||||
|
|
||||||
schema = pa.schema([
|
|
||||||
pa.field("id", pa.string()),
|
|
||||||
pa.field("text", pa.string()),
|
|
||||||
pa.field("vector", pa.list_(pa.float64())),
|
|
||||||
pa.field("attributes", pa.string()),
|
|
||||||
])
|
|
||||||
# NOTE: If modifying the next section of code, ensure that the schema remains the same.
|
# NOTE: If modifying the next section of code, ensure that the schema remains the same.
|
||||||
# The pyarrow format of the 'vector' field may change if the order of operations is changed
|
# The pyarrow format of the 'vector' field may change if the order of operations is changed
|
||||||
# and will break vector search.
|
# and will break vector search.
|
||||||
if overwrite:
|
if overwrite:
|
||||||
if data:
|
if data:
|
||||||
self.document_collection = self.db_connection.create_table(
|
self.document_collection = self.db_connection.create_table(
|
||||||
self.collection_name, data=data, mode="overwrite"
|
self.index_name if self.index_name else "",
|
||||||
|
data=data,
|
||||||
|
mode="overwrite",
|
||||||
|
schema=data.schema,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.document_collection = self.db_connection.create_table(
|
self.document_collection = self.db_connection.create_table(
|
||||||
self.collection_name, schema=schema, mode="overwrite"
|
self.index_name if self.index_name else "", mode="overwrite"
|
||||||
)
|
)
|
||||||
|
self.document_collection.create_index(
|
||||||
|
vector_column_name=self.vector_field, index_type="IVF_FLAT"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# add data to existing table
|
# add data to existing table
|
||||||
self.document_collection = self.db_connection.open_table(
|
self.document_collection = self.db_connection.open_table(
|
||||||
self.collection_name
|
self.index_name if self.index_name else ""
|
||||||
)
|
)
|
||||||
if data:
|
if data:
|
||||||
self.document_collection.add(data)
|
self.document_collection.add(data)
|
||||||
@ -86,30 +107,32 @@ class LanceDBVectorStore(BaseVectorStore):
|
|||||||
else:
|
else:
|
||||||
if isinstance(include_ids[0], str):
|
if isinstance(include_ids[0], str):
|
||||||
id_filter = ", ".join([f"'{id}'" for id in include_ids])
|
id_filter = ", ".join([f"'{id}'" for id in include_ids])
|
||||||
self.query_filter = f"id in ({id_filter})"
|
self.query_filter = f"{self.id_field} in ({id_filter})"
|
||||||
else:
|
else:
|
||||||
self.query_filter = (
|
self.query_filter = (
|
||||||
f"id in ({', '.join([str(id) for id in include_ids])})"
|
f"{self.id_field} in ({', '.join([str(id) for id in include_ids])})"
|
||||||
)
|
)
|
||||||
return self.query_filter
|
return self.query_filter
|
||||||
|
|
||||||
def similarity_search_by_vector(
|
def similarity_search_by_vector(
|
||||||
self, query_embedding: list[float], k: int = 10, **kwargs: Any
|
self, query_embedding: list[float] | np.ndarray, k: int = 10, **kwargs: Any
|
||||||
) -> list[VectorStoreSearchResult]:
|
) -> list[VectorStoreSearchResult]:
|
||||||
"""Perform a vector-based similarity search."""
|
"""Perform a vector-based similarity search."""
|
||||||
if self.query_filter:
|
if self.query_filter:
|
||||||
docs = (
|
docs = (
|
||||||
self.document_collection.search(
|
self.document_collection.search(
|
||||||
query=query_embedding, vector_column_name="vector"
|
query=query_embedding, vector_column_name=self.vector_field
|
||||||
)
|
)
|
||||||
.where(self.query_filter, prefilter=True)
|
.where(self.query_filter, prefilter=True)
|
||||||
.limit(k)
|
.limit(k)
|
||||||
.to_list()
|
.to_list()
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
query_embedding = np.array(query_embedding, dtype=np.float32)
|
||||||
|
|
||||||
docs = (
|
docs = (
|
||||||
self.document_collection.search(
|
self.document_collection.search(
|
||||||
query=query_embedding, vector_column_name="vector"
|
query=query_embedding, vector_column_name=self.vector_field
|
||||||
)
|
)
|
||||||
.limit(k)
|
.limit(k)
|
||||||
.to_list()
|
.to_list()
|
||||||
@ -117,10 +140,10 @@ class LanceDBVectorStore(BaseVectorStore):
|
|||||||
return [
|
return [
|
||||||
VectorStoreSearchResult(
|
VectorStoreSearchResult(
|
||||||
document=VectorStoreDocument(
|
document=VectorStoreDocument(
|
||||||
id=doc["id"],
|
id=doc[self.id_field],
|
||||||
text=doc["text"],
|
text=doc[self.text_field],
|
||||||
vector=doc["vector"],
|
vector=doc[self.vector_field],
|
||||||
attributes=json.loads(doc["attributes"]),
|
attributes=json.loads(doc[self.attributes_field]),
|
||||||
),
|
),
|
||||||
score=1 - abs(float(doc["_distance"])),
|
score=1 - abs(float(doc["_distance"])),
|
||||||
)
|
)
|
||||||
@ -140,14 +163,14 @@ class LanceDBVectorStore(BaseVectorStore):
|
|||||||
"""Search for a document by id."""
|
"""Search for a document by id."""
|
||||||
doc = (
|
doc = (
|
||||||
self.document_collection.search()
|
self.document_collection.search()
|
||||||
.where(f"id == '{id}'", prefilter=True)
|
.where(f"{self.id_field} == '{id}'", prefilter=True)
|
||||||
.to_list()
|
.to_list()
|
||||||
)
|
)
|
||||||
if doc:
|
if doc:
|
||||||
return VectorStoreDocument(
|
return VectorStoreDocument(
|
||||||
id=doc[0]["id"],
|
id=doc[0][self.id_field],
|
||||||
text=doc[0]["text"],
|
text=doc[0][self.text_field],
|
||||||
vector=doc[0]["vector"],
|
vector=doc[0][self.vector_field],
|
||||||
attributes=json.loads(doc[0]["attributes"]),
|
attributes=json.loads(doc[0][self.attributes_field]),
|
||||||
)
|
)
|
||||||
return VectorStoreDocument(id=id, text=None, vector=None)
|
return VectorStoreDocument(id=id, text=None, vector=None)
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||||
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
|
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
|
||||||
from graphrag.vector_stores.base import VectorStoreDocument
|
from graphrag.vector_stores.base import VectorStoreDocument
|
||||||
|
|
||||||
@ -39,7 +40,35 @@ class TestAzureAISearchVectorStore:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def vector_store(self, mock_search_client, mock_index_client):
|
def vector_store(self, mock_search_client, mock_index_client):
|
||||||
"""Create an Azure AI Search vector store instance."""
|
"""Create an Azure AI Search vector store instance."""
|
||||||
vector_store = AzureAISearchVectorStore(collection_name="test_vectors")
|
vector_store = AzureAISearchVectorStore(
|
||||||
|
vector_store_schema_config=VectorStoreSchemaConfig(
|
||||||
|
index_name="test_vectors", vector_size=5
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the necessary mocks first
|
||||||
|
vector_store.db_connection = mock_search_client
|
||||||
|
vector_store.index_client = mock_index_client
|
||||||
|
|
||||||
|
vector_store.connect(
|
||||||
|
url=TEST_AZURE_AI_SEARCH_URL,
|
||||||
|
api_key=TEST_AZURE_AI_SEARCH_KEY,
|
||||||
|
)
|
||||||
|
return vector_store
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def vector_store_custom(self, mock_search_client, mock_index_client):
|
||||||
|
"""Create an Azure AI Search vector store instance."""
|
||||||
|
vector_store = AzureAISearchVectorStore(
|
||||||
|
vector_store_schema_config=VectorStoreSchemaConfig(
|
||||||
|
index_name="test_vectors",
|
||||||
|
id_field="id_custom",
|
||||||
|
text_field="text_custom",
|
||||||
|
attributes_field="attributes_custom",
|
||||||
|
vector_field="vector_custom",
|
||||||
|
vector_size=5,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
# Create the necessary mocks first
|
# Create the necessary mocks first
|
||||||
vector_store.db_connection = mock_search_client
|
vector_store.db_connection = mock_search_client
|
||||||
@ -48,7 +77,6 @@ class TestAzureAISearchVectorStore:
|
|||||||
vector_store.connect(
|
vector_store.connect(
|
||||||
url=TEST_AZURE_AI_SEARCH_URL,
|
url=TEST_AZURE_AI_SEARCH_URL,
|
||||||
api_key=TEST_AZURE_AI_SEARCH_KEY,
|
api_key=TEST_AZURE_AI_SEARCH_KEY,
|
||||||
vector_size=5,
|
|
||||||
)
|
)
|
||||||
return vector_store
|
return vector_store
|
||||||
|
|
||||||
@ -144,3 +172,72 @@ class TestAzureAISearchVectorStore:
|
|||||||
)
|
)
|
||||||
assert not mock_search_client.search.called
|
assert not mock_search_client.search.called
|
||||||
assert len(results) == 0
|
assert len(results) == 0
|
||||||
|
|
||||||
|
async def test_vector_store_customization(
|
||||||
|
self,
|
||||||
|
vector_store_custom,
|
||||||
|
sample_documents,
|
||||||
|
mock_search_client,
|
||||||
|
mock_index_client,
|
||||||
|
):
|
||||||
|
"""Test vector store customization with Azure AI Search."""
|
||||||
|
# Setup mock responses
|
||||||
|
mock_index_client.list_index_names.return_value = []
|
||||||
|
mock_index_client.create_or_update_index = MagicMock()
|
||||||
|
mock_search_client.upload_documents = MagicMock()
|
||||||
|
|
||||||
|
search_results = [
|
||||||
|
{
|
||||||
|
vector_store_custom.id_field: "doc1",
|
||||||
|
vector_store_custom.text_field: "This is document 1",
|
||||||
|
vector_store_custom.vector_field: [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||||
|
vector_store_custom.attributes_field: '{"title": "Doc 1", "category": "test"}',
|
||||||
|
"@search.score": 0.9,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
vector_store_custom.id_field: "doc2",
|
||||||
|
vector_store_custom.text_field: "This is document 2",
|
||||||
|
vector_store_custom.vector_field: [0.2, 0.3, 0.4, 0.5, 0.6],
|
||||||
|
vector_store_custom.attributes_field: '{"title": "Doc 2", "category": "test"}',
|
||||||
|
"@search.score": 0.8,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
mock_search_client.search.return_value = search_results
|
||||||
|
|
||||||
|
mock_search_client.get_document.return_value = {
|
||||||
|
vector_store_custom.id_field: "doc1",
|
||||||
|
vector_store_custom.text_field: "This is document 1",
|
||||||
|
vector_store_custom.vector_field: [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||||
|
vector_store_custom.attributes_field: '{"title": "Doc 1", "category": "test"}',
|
||||||
|
}
|
||||||
|
|
||||||
|
vector_store_custom.load_documents(sample_documents)
|
||||||
|
assert mock_index_client.create_or_update_index.called
|
||||||
|
assert mock_search_client.upload_documents.called
|
||||||
|
|
||||||
|
filter_query = vector_store_custom.filter_by_id(["doc1", "doc2"])
|
||||||
|
assert (
|
||||||
|
filter_query
|
||||||
|
== f"search.in({vector_store_custom.id_field}, 'doc1,doc2', ',')"
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_results = vector_store_custom.similarity_search_by_vector(
|
||||||
|
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
|
||||||
|
)
|
||||||
|
assert len(vector_results) == 2
|
||||||
|
assert vector_results[0].document.id == "doc1"
|
||||||
|
assert vector_results[0].score == 0.9
|
||||||
|
|
||||||
|
# Define a simple text embedder function for testing
|
||||||
|
def mock_embedder(text: str) -> list[float]:
|
||||||
|
return [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||||
|
|
||||||
|
text_results = vector_store_custom.similarity_search_by_text(
|
||||||
|
"test query", mock_embedder, k=2
|
||||||
|
)
|
||||||
|
assert len(text_results) == 2
|
||||||
|
|
||||||
|
doc = vector_store_custom.search_by_id("doc1")
|
||||||
|
assert doc.id == "doc1"
|
||||||
|
assert doc.text == "This is document 1"
|
||||||
|
assert doc.attributes["title"] == "Doc 1"
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import sys
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||||
from graphrag.vector_stores.base import VectorStoreDocument
|
from graphrag.vector_stores.base import VectorStoreDocument
|
||||||
from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore
|
from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore
|
||||||
|
|
||||||
@ -24,7 +25,7 @@ if not sys.platform.startswith("win"):
|
|||||||
def test_vector_store_operations():
|
def test_vector_store_operations():
|
||||||
"""Test basic vector store operations with CosmosDB."""
|
"""Test basic vector store operations with CosmosDB."""
|
||||||
vector_store = CosmosDBVectorStore(
|
vector_store = CosmosDBVectorStore(
|
||||||
collection_name="testvector",
|
vector_store_schema_config=VectorStoreSchemaConfig(index_name="testvector"),
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -78,7 +79,7 @@ def test_vector_store_operations():
|
|||||||
def test_clear():
|
def test_clear():
|
||||||
"""Test clearing the vector store."""
|
"""Test clearing the vector store."""
|
||||||
vector_store = CosmosDBVectorStore(
|
vector_store = CosmosDBVectorStore(
|
||||||
collection_name="testclear",
|
vector_store_schema_config=VectorStoreSchemaConfig(index_name="testclear"),
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
vector_store.connect(
|
vector_store.connect(
|
||||||
@ -102,3 +103,64 @@ def test_clear():
|
|||||||
assert vector_store._database_exists() is False # noqa: SLF001
|
assert vector_store._database_exists() is False # noqa: SLF001
|
||||||
finally:
|
finally:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_vector_store_customization():
|
||||||
|
"""Test vector store customization with CosmosDB."""
|
||||||
|
vector_store = CosmosDBVectorStore(
|
||||||
|
vector_store_schema_config=VectorStoreSchemaConfig(
|
||||||
|
index_name="text-embeddings",
|
||||||
|
id_field="id",
|
||||||
|
text_field="text_custom",
|
||||||
|
vector_field="vector_custom",
|
||||||
|
attributes_field="attributes_custom",
|
||||||
|
vector_size=5,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
vector_store.connect(
|
||||||
|
connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
|
||||||
|
database_name="test_db",
|
||||||
|
)
|
||||||
|
|
||||||
|
docs = [
|
||||||
|
VectorStoreDocument(
|
||||||
|
id="doc1",
|
||||||
|
text="This is document 1",
|
||||||
|
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
|
||||||
|
attributes={"title": "Doc 1", "category": "test"},
|
||||||
|
),
|
||||||
|
VectorStoreDocument(
|
||||||
|
id="doc2",
|
||||||
|
text="This is document 2",
|
||||||
|
vector=[0.2, 0.3, 0.4, 0.5, 0.6],
|
||||||
|
attributes={"title": "Doc 2", "category": "test"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
vector_store.load_documents(docs)
|
||||||
|
|
||||||
|
vector_store.filter_by_id(["doc1"])
|
||||||
|
|
||||||
|
doc = vector_store.search_by_id("doc1")
|
||||||
|
assert doc.id == "doc1"
|
||||||
|
assert doc.text == "This is document 1"
|
||||||
|
assert doc.vector is not None
|
||||||
|
assert np.allclose(doc.vector, [0.1, 0.2, 0.3, 0.4, 0.5])
|
||||||
|
assert doc.attributes["title"] == "Doc 1"
|
||||||
|
|
||||||
|
# Define a simple text embedder function for testing
|
||||||
|
def mock_embedder(text: str) -> list[float]:
|
||||||
|
return [0.1, 0.2, 0.3, 0.4, 0.5] # Return fixed embedding
|
||||||
|
|
||||||
|
vector_results = vector_store.similarity_search_by_vector(
|
||||||
|
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
|
||||||
|
)
|
||||||
|
assert len(vector_results) > 0
|
||||||
|
|
||||||
|
text_results = vector_store.similarity_search_by_text(
|
||||||
|
"test query", mock_embedder, k=2
|
||||||
|
)
|
||||||
|
assert len(text_results) > 0
|
||||||
|
finally:
|
||||||
|
vector_store.clear()
|
||||||
|
|||||||
@ -8,6 +8,7 @@ These tests will test the VectorStoreFactory class and the creation of each vect
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from graphrag.config.enums import VectorStoreType
|
from graphrag.config.enums import VectorStoreType
|
||||||
|
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||||
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
|
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
|
||||||
from graphrag.vector_stores.base import BaseVectorStore
|
from graphrag.vector_stores.base import BaseVectorStore
|
||||||
from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore
|
from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore
|
||||||
@ -17,25 +18,31 @@ from graphrag.vector_stores.lancedb import LanceDBVectorStore
|
|||||||
|
|
||||||
def test_create_lancedb_vector_store():
|
def test_create_lancedb_vector_store():
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"collection_name": "test_collection",
|
|
||||||
"db_uri": "/tmp/lancedb",
|
"db_uri": "/tmp/lancedb",
|
||||||
}
|
}
|
||||||
vector_store = VectorStoreFactory.create_vector_store(
|
vector_store = VectorStoreFactory.create_vector_store(
|
||||||
VectorStoreType.LanceDB.value, kwargs
|
vector_store_type=VectorStoreType.LanceDB.value,
|
||||||
|
vector_store_schema_config=VectorStoreSchemaConfig(
|
||||||
|
index_name="test_collection"
|
||||||
|
),
|
||||||
|
kwargs=kwargs,
|
||||||
)
|
)
|
||||||
assert isinstance(vector_store, LanceDBVectorStore)
|
assert isinstance(vector_store, LanceDBVectorStore)
|
||||||
assert vector_store.collection_name == "test_collection"
|
assert vector_store.index_name == "test_collection"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="Azure AI Search requires credentials and setup")
|
@pytest.mark.skip(reason="Azure AI Search requires credentials and setup")
|
||||||
def test_create_azure_ai_search_vector_store():
|
def test_create_azure_ai_search_vector_store():
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"collection_name": "test_collection",
|
|
||||||
"url": "https://test.search.windows.net",
|
"url": "https://test.search.windows.net",
|
||||||
"api_key": "test_key",
|
"api_key": "test_key",
|
||||||
}
|
}
|
||||||
vector_store = VectorStoreFactory.create_vector_store(
|
vector_store = VectorStoreFactory.create_vector_store(
|
||||||
VectorStoreType.AzureAISearch.value, kwargs
|
vector_store_type=VectorStoreType.AzureAISearch.value,
|
||||||
|
vector_store_schema_config=VectorStoreSchemaConfig(
|
||||||
|
index_name="test_collection"
|
||||||
|
),
|
||||||
|
kwargs=kwargs,
|
||||||
)
|
)
|
||||||
assert isinstance(vector_store, AzureAISearchVectorStore)
|
assert isinstance(vector_store, AzureAISearchVectorStore)
|
||||||
|
|
||||||
@ -43,13 +50,18 @@ def test_create_azure_ai_search_vector_store():
|
|||||||
@pytest.mark.skip(reason="CosmosDB requires credentials and setup")
|
@pytest.mark.skip(reason="CosmosDB requires credentials and setup")
|
||||||
def test_create_cosmosdb_vector_store():
|
def test_create_cosmosdb_vector_store():
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"collection_name": "test_collection",
|
|
||||||
"connection_string": "AccountEndpoint=https://test.documents.azure.com:443/;AccountKey=test_key==",
|
"connection_string": "AccountEndpoint=https://test.documents.azure.com:443/;AccountKey=test_key==",
|
||||||
"database_name": "test_db",
|
"database_name": "test_db",
|
||||||
}
|
}
|
||||||
|
|
||||||
vector_store = VectorStoreFactory.create_vector_store(
|
vector_store = VectorStoreFactory.create_vector_store(
|
||||||
VectorStoreType.CosmosDB.value, kwargs
|
vector_store_type=VectorStoreType.CosmosDB.value,
|
||||||
|
vector_store_schema_config=VectorStoreSchemaConfig(
|
||||||
|
index_name="test_collection"
|
||||||
|
),
|
||||||
|
kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(vector_store, CosmosDBVectorStore)
|
assert isinstance(vector_store, CosmosDBVectorStore)
|
||||||
|
|
||||||
|
|
||||||
@ -67,7 +79,12 @@ def test_register_and_create_custom_vector_store():
|
|||||||
VectorStoreFactory.register(
|
VectorStoreFactory.register(
|
||||||
"custom", lambda **kwargs: custom_vector_store_class(**kwargs)
|
"custom", lambda **kwargs: custom_vector_store_class(**kwargs)
|
||||||
)
|
)
|
||||||
vector_store = VectorStoreFactory.create_vector_store("custom", {})
|
|
||||||
|
vector_store = VectorStoreFactory.create_vector_store(
|
||||||
|
vector_store_type="custom",
|
||||||
|
vector_store_schema_config=VectorStoreSchemaConfig(),
|
||||||
|
kwargs={},
|
||||||
|
)
|
||||||
|
|
||||||
assert custom_vector_store_class.called
|
assert custom_vector_store_class.called
|
||||||
assert vector_store is instance
|
assert vector_store is instance
|
||||||
@ -89,7 +106,11 @@ def test_get_vector_store_types():
|
|||||||
|
|
||||||
def test_create_unknown_vector_store():
|
def test_create_unknown_vector_store():
|
||||||
with pytest.raises(ValueError, match="Unknown vector store type: unknown"):
|
with pytest.raises(ValueError, match="Unknown vector store type: unknown"):
|
||||||
VectorStoreFactory.create_vector_store("unknown", {})
|
VectorStoreFactory.create_vector_store(
|
||||||
|
vector_store_type="unknown",
|
||||||
|
vector_store_schema_config=VectorStoreSchemaConfig(),
|
||||||
|
kwargs={},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_is_supported_type():
|
def test_is_supported_type():
|
||||||
@ -139,6 +160,9 @@ def test_register_class_directly_works():
|
|||||||
|
|
||||||
# Test creating an instance
|
# Test creating an instance
|
||||||
vector_store = VectorStoreFactory.create_vector_store(
|
vector_store = VectorStoreFactory.create_vector_store(
|
||||||
"custom_class", {"collection_name": "test"}
|
vector_store_type="custom_class",
|
||||||
|
vector_store_schema_config=VectorStoreSchemaConfig(),
|
||||||
|
kwargs={"collection_name": "test"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(vector_store, CustomVectorStore)
|
assert isinstance(vector_store, CustomVectorStore)
|
||||||
|
|||||||
@ -7,20 +7,20 @@ import shutil
|
|||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||||
from graphrag.vector_stores.base import VectorStoreDocument
|
from graphrag.vector_stores.base import VectorStoreDocument
|
||||||
from graphrag.vector_stores.lancedb import LanceDBVectorStore
|
from graphrag.vector_stores.lancedb import LanceDBVectorStore
|
||||||
|
|
||||||
|
|
||||||
def test_vector_store_operations():
|
class TestLanceDBVectorStore:
|
||||||
"""Test basic vector store operations with LanceDB."""
|
"""Test class for TestLanceDBVectorStore."""
|
||||||
# Create a temporary directory for the test database
|
|
||||||
temp_dir = tempfile.mkdtemp()
|
|
||||||
try:
|
|
||||||
vector_store = LanceDBVectorStore(collection_name="test_collection")
|
|
||||||
vector_store.connect(db_uri=temp_dir)
|
|
||||||
|
|
||||||
docs = [
|
@pytest.fixture
|
||||||
|
def sample_documents(self):
|
||||||
|
"""Create sample documents for testing."""
|
||||||
|
return [
|
||||||
VectorStoreDocument(
|
VectorStoreDocument(
|
||||||
id="1",
|
id="1",
|
||||||
text="This is document 1",
|
text="This is document 1",
|
||||||
@ -40,102 +40,11 @@ def test_vector_store_operations():
|
|||||||
attributes={"title": "Doc 3", "category": "test"},
|
attributes={"title": "Doc 3", "category": "test"},
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
vector_store.load_documents(docs[:2])
|
|
||||||
|
|
||||||
assert vector_store.collection_name in vector_store.db_connection.table_names()
|
@pytest.fixture
|
||||||
|
def sample_documents_categories(self):
|
||||||
doc = vector_store.search_by_id("1")
|
"""Create sample documents with different categories for testing."""
|
||||||
assert doc.id == "1"
|
return [
|
||||||
assert doc.text == "This is document 1"
|
|
||||||
|
|
||||||
assert doc.vector is not None
|
|
||||||
assert np.allclose(doc.vector, [0.1, 0.2, 0.3, 0.4, 0.5])
|
|
||||||
assert doc.attributes["title"] == "Doc 1"
|
|
||||||
|
|
||||||
filter_query = vector_store.filter_by_id(["1"])
|
|
||||||
assert filter_query == "id in ('1')"
|
|
||||||
|
|
||||||
results = vector_store.similarity_search_by_vector(
|
|
||||||
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
|
|
||||||
)
|
|
||||||
assert 1 <= len(results) <= 2
|
|
||||||
assert isinstance(results[0].score, float)
|
|
||||||
|
|
||||||
# Test append mode
|
|
||||||
vector_store.load_documents([docs[2]], overwrite=False)
|
|
||||||
result = vector_store.search_by_id("3")
|
|
||||||
assert result.id == "3"
|
|
||||||
assert result.text == "This is document 3"
|
|
||||||
|
|
||||||
# Define a simple text embedder function for testing
|
|
||||||
def mock_embedder(text: str) -> list[float]:
|
|
||||||
return [0.1, 0.2, 0.3, 0.4, 0.5]
|
|
||||||
|
|
||||||
text_results = vector_store.similarity_search_by_text(
|
|
||||||
"test query", mock_embedder, k=2
|
|
||||||
)
|
|
||||||
assert 1 <= len(text_results) <= 2
|
|
||||||
assert isinstance(text_results[0].score, float)
|
|
||||||
|
|
||||||
# Test non-existent document
|
|
||||||
non_existent = vector_store.search_by_id("nonexistent")
|
|
||||||
assert non_existent.id == "nonexistent"
|
|
||||||
assert non_existent.text is None
|
|
||||||
assert non_existent.vector is None
|
|
||||||
finally:
|
|
||||||
shutil.rmtree(temp_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def test_empty_collection():
|
|
||||||
"""Test creating an empty collection."""
|
|
||||||
# Create a temporary directory for the test database
|
|
||||||
temp_dir = tempfile.mkdtemp()
|
|
||||||
try:
|
|
||||||
vector_store = LanceDBVectorStore(collection_name="empty_collection")
|
|
||||||
vector_store.connect(db_uri=temp_dir)
|
|
||||||
|
|
||||||
# Load the vector store with a document, then delete it
|
|
||||||
sample_doc = VectorStoreDocument(
|
|
||||||
id="tmp",
|
|
||||||
text="Temporary document to create schema",
|
|
||||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
|
|
||||||
attributes={"title": "Tmp"},
|
|
||||||
)
|
|
||||||
vector_store.load_documents([sample_doc])
|
|
||||||
vector_store.db_connection.open_table(vector_store.collection_name).delete(
|
|
||||||
"id = 'tmp'"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should still have the collection
|
|
||||||
assert vector_store.collection_name in vector_store.db_connection.table_names()
|
|
||||||
|
|
||||||
# Add a document after creating an empty collection
|
|
||||||
doc = VectorStoreDocument(
|
|
||||||
id="1",
|
|
||||||
text="This is document 1",
|
|
||||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
|
|
||||||
attributes={"title": "Doc 1"},
|
|
||||||
)
|
|
||||||
vector_store.load_documents([doc], overwrite=False)
|
|
||||||
|
|
||||||
result = vector_store.search_by_id("1")
|
|
||||||
assert result.id == "1"
|
|
||||||
assert result.text == "This is document 1"
|
|
||||||
finally:
|
|
||||||
# Clean up - remove the temporary directory
|
|
||||||
shutil.rmtree(temp_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def test_filter_search():
|
|
||||||
"""Test filtered search with LanceDB."""
|
|
||||||
# Create a temporary directory for the test database
|
|
||||||
temp_dir = tempfile.mkdtemp()
|
|
||||||
try:
|
|
||||||
vector_store = LanceDBVectorStore(collection_name="filter_collection")
|
|
||||||
vector_store.connect(db_uri=temp_dir)
|
|
||||||
|
|
||||||
# Create test documents with different categories
|
|
||||||
docs = [
|
|
||||||
VectorStoreDocument(
|
VectorStoreDocument(
|
||||||
id="1",
|
id="1",
|
||||||
text="Document about cats",
|
text="Document about cats",
|
||||||
@ -155,18 +64,201 @@ def test_filter_search():
|
|||||||
attributes={"category": "vehicles"},
|
attributes={"category": "vehicles"},
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
vector_store.load_documents(docs)
|
|
||||||
|
|
||||||
# Filter to include only documents about animals
|
def test_vector_store_operations(self, sample_documents):
|
||||||
vector_store.filter_by_id(["1", "2"])
|
"""Test basic vector store operations with LanceDB."""
|
||||||
results = vector_store.similarity_search_by_vector(
|
# Create a temporary directory for the test database
|
||||||
[0.1, 0.2, 0.3, 0.4, 0.5], k=3
|
temp_dir = tempfile.mkdtemp()
|
||||||
)
|
try:
|
||||||
|
vector_store = LanceDBVectorStore(
|
||||||
|
vector_store_schema_config=VectorStoreSchemaConfig(
|
||||||
|
index_name="test_collection", vector_size=5
|
||||||
|
)
|
||||||
|
)
|
||||||
|
vector_store.connect(db_uri=temp_dir)
|
||||||
|
vector_store.load_documents(sample_documents[:2])
|
||||||
|
|
||||||
# Should return at most 2 documents (the filtered ones)
|
if vector_store.index_name:
|
||||||
assert len(results) <= 2
|
assert (
|
||||||
ids = [result.document.id for result in results]
|
vector_store.index_name in vector_store.db_connection.table_names()
|
||||||
assert "3" not in ids
|
)
|
||||||
assert set(ids).issubset({"1", "2"})
|
|
||||||
finally:
|
doc = vector_store.search_by_id("1")
|
||||||
shutil.rmtree(temp_dir)
|
assert doc.id == "1"
|
||||||
|
assert doc.text == "This is document 1"
|
||||||
|
|
||||||
|
assert doc.vector is not None
|
||||||
|
assert np.allclose(doc.vector, [0.1, 0.2, 0.3, 0.4, 0.5])
|
||||||
|
assert doc.attributes["title"] == "Doc 1"
|
||||||
|
|
||||||
|
filter_query = vector_store.filter_by_id(["1"])
|
||||||
|
assert filter_query == "id in ('1')"
|
||||||
|
|
||||||
|
results = vector_store.similarity_search_by_vector(
|
||||||
|
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
|
||||||
|
)
|
||||||
|
assert 1 <= len(results) <= 2
|
||||||
|
assert isinstance(results[0].score, float)
|
||||||
|
|
||||||
|
# Test append mode
|
||||||
|
vector_store.load_documents([sample_documents[2]], overwrite=False)
|
||||||
|
result = vector_store.search_by_id("3")
|
||||||
|
assert result.id == "3"
|
||||||
|
assert result.text == "This is document 3"
|
||||||
|
|
||||||
|
# Define a simple text embedder function for testing
|
||||||
|
def mock_embedder(text: str) -> list[float]:
|
||||||
|
return [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||||
|
|
||||||
|
text_results = vector_store.similarity_search_by_text(
|
||||||
|
"test query", mock_embedder, k=2
|
||||||
|
)
|
||||||
|
assert 1 <= len(text_results) <= 2
|
||||||
|
assert isinstance(text_results[0].score, float)
|
||||||
|
|
||||||
|
# Test non-existent document
|
||||||
|
non_existent = vector_store.search_by_id("nonexistent")
|
||||||
|
assert non_existent.id == "nonexistent"
|
||||||
|
assert non_existent.text is None
|
||||||
|
assert non_existent.vector is None
|
||||||
|
finally:
|
||||||
|
shutil.rmtree(temp_dir)
|
||||||
|
|
||||||
|
def test_empty_collection(self):
|
||||||
|
"""Test creating an empty collection."""
|
||||||
|
# Create a temporary directory for the test database
|
||||||
|
temp_dir = tempfile.mkdtemp()
|
||||||
|
try:
|
||||||
|
vector_store = LanceDBVectorStore(
|
||||||
|
vector_store_schema_config=VectorStoreSchemaConfig(
|
||||||
|
index_name="empty_collection", vector_size=5
|
||||||
|
)
|
||||||
|
)
|
||||||
|
vector_store.connect(db_uri=temp_dir)
|
||||||
|
|
||||||
|
# Load the vector store with a document, then delete it
|
||||||
|
sample_doc = VectorStoreDocument(
|
||||||
|
id="tmp",
|
||||||
|
text="Temporary document to create schema",
|
||||||
|
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
|
||||||
|
attributes={"title": "Tmp"},
|
||||||
|
)
|
||||||
|
vector_store.load_documents([sample_doc])
|
||||||
|
vector_store.db_connection.open_table(
|
||||||
|
vector_store.index_name if vector_store.index_name else ""
|
||||||
|
).delete("id = 'tmp'")
|
||||||
|
|
||||||
|
# Should still have the collection
|
||||||
|
if vector_store.index_name:
|
||||||
|
assert (
|
||||||
|
vector_store.index_name in vector_store.db_connection.table_names()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add a document after creating an empty collection
|
||||||
|
doc = VectorStoreDocument(
|
||||||
|
id="1",
|
||||||
|
text="This is document 1",
|
||||||
|
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
|
||||||
|
attributes={"title": "Doc 1"},
|
||||||
|
)
|
||||||
|
vector_store.load_documents([doc], overwrite=False)
|
||||||
|
|
||||||
|
result = vector_store.search_by_id("1")
|
||||||
|
assert result.id == "1"
|
||||||
|
assert result.text == "This is document 1"
|
||||||
|
finally:
|
||||||
|
# Clean up - remove the temporary directory
|
||||||
|
shutil.rmtree(temp_dir)
|
||||||
|
|
||||||
|
def test_filter_search(self, sample_documents_categories):
|
||||||
|
"""Test filtered search with LanceDB."""
|
||||||
|
# Create a temporary directory for the test database
|
||||||
|
temp_dir = tempfile.mkdtemp()
|
||||||
|
try:
|
||||||
|
vector_store = LanceDBVectorStore(
|
||||||
|
vector_store_schema_config=VectorStoreSchemaConfig(
|
||||||
|
index_name="filter_collection", vector_size=5
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_store.connect(db_uri=temp_dir)
|
||||||
|
|
||||||
|
vector_store.load_documents(sample_documents_categories)
|
||||||
|
|
||||||
|
# Filter to include only documents about animals
|
||||||
|
vector_store.filter_by_id(["1", "2"])
|
||||||
|
results = vector_store.similarity_search_by_vector(
|
||||||
|
[0.1, 0.2, 0.3, 0.4, 0.5], k=3
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return at most 2 documents (the filtered ones)
|
||||||
|
assert len(results) <= 2
|
||||||
|
ids = [result.document.id for result in results]
|
||||||
|
assert "3" not in ids
|
||||||
|
assert set(ids).issubset({"1", "2"})
|
||||||
|
finally:
|
||||||
|
shutil.rmtree(temp_dir)
|
||||||
|
|
||||||
|
def test_vector_store_customization(self, sample_documents):
|
||||||
|
"""Test vector store customization with LanceDB."""
|
||||||
|
# Create a temporary directory for the test database
|
||||||
|
temp_dir = tempfile.mkdtemp()
|
||||||
|
try:
|
||||||
|
vector_store = LanceDBVectorStore(
|
||||||
|
vector_store_schema_config=VectorStoreSchemaConfig(
|
||||||
|
index_name="text-embeddings",
|
||||||
|
id_field="id_custom",
|
||||||
|
text_field="text_custom",
|
||||||
|
vector_field="vector_custom",
|
||||||
|
attributes_field="attributes_custom",
|
||||||
|
vector_size=5,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
vector_store.connect(db_uri=temp_dir)
|
||||||
|
vector_store.load_documents(sample_documents[:2])
|
||||||
|
|
||||||
|
if vector_store.index_name:
|
||||||
|
assert (
|
||||||
|
vector_store.index_name in vector_store.db_connection.table_names()
|
||||||
|
)
|
||||||
|
|
||||||
|
doc = vector_store.search_by_id("1")
|
||||||
|
assert doc.id == "1"
|
||||||
|
assert doc.text == "This is document 1"
|
||||||
|
|
||||||
|
assert doc.vector is not None
|
||||||
|
assert np.allclose(doc.vector, [0.1, 0.2, 0.3, 0.4, 0.5])
|
||||||
|
assert doc.attributes["title"] == "Doc 1"
|
||||||
|
|
||||||
|
filter_query = vector_store.filter_by_id(["1"])
|
||||||
|
assert filter_query == f"{vector_store.id_field} in ('1')"
|
||||||
|
|
||||||
|
results = vector_store.similarity_search_by_vector(
|
||||||
|
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
|
||||||
|
)
|
||||||
|
assert 1 <= len(results) <= 2
|
||||||
|
assert isinstance(results[0].score, float)
|
||||||
|
|
||||||
|
# Test append mode
|
||||||
|
vector_store.load_documents([sample_documents[2]], overwrite=False)
|
||||||
|
result = vector_store.search_by_id("3")
|
||||||
|
assert result.id == "3"
|
||||||
|
assert result.text == "This is document 3"
|
||||||
|
|
||||||
|
# Define a simple text embedder function for testing
|
||||||
|
def mock_embedder(text: str) -> list[float]:
|
||||||
|
return [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||||
|
|
||||||
|
text_results = vector_store.similarity_search_by_text(
|
||||||
|
"test query", mock_embedder, k=2
|
||||||
|
)
|
||||||
|
assert 1 <= len(text_results) <= 2
|
||||||
|
assert isinstance(text_results[0].score, float)
|
||||||
|
|
||||||
|
# Test non-existent document
|
||||||
|
non_existent = vector_store.search_by_id("nonexistent")
|
||||||
|
assert non_existent.id == "nonexistent"
|
||||||
|
assert non_existent.text is None
|
||||||
|
assert non_existent.vector is None
|
||||||
|
finally:
|
||||||
|
shutil.rmtree(temp_dir)
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||||
from graphrag.data_model.entity import Entity
|
from graphrag.data_model.entity import Entity
|
||||||
from graphrag.data_model.types import TextEmbedder
|
from graphrag.data_model.types import TextEmbedder
|
||||||
from graphrag.language_model.manager import ModelManager
|
from graphrag.language_model.manager import ModelManager
|
||||||
@ -19,7 +20,9 @@ from graphrag.vector_stores.base import (
|
|||||||
|
|
||||||
class MockBaseVectorStore(BaseVectorStore):
|
class MockBaseVectorStore(BaseVectorStore):
|
||||||
def __init__(self, documents: list[VectorStoreDocument]) -> None:
|
def __init__(self, documents: list[VectorStoreDocument]) -> None:
|
||||||
super().__init__("mock")
|
super().__init__(
|
||||||
|
vector_store_schema_config=VectorStoreSchemaConfig(index_name="mock")
|
||||||
|
)
|
||||||
self.documents = documents
|
self.documents = documents
|
||||||
|
|
||||||
def connect(self, **kwargs: Any) -> None:
|
def connect(self, **kwargs: Any) -> None:
|
||||||
|
|||||||
@ -3,19 +3,19 @@
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from graphrag.config.embeddings import create_collection_name
|
from graphrag.config.embeddings import create_index_name
|
||||||
|
|
||||||
|
|
||||||
def test_create_collection_name():
|
def test_create_index_name():
|
||||||
collection = create_collection_name("default", "entity.title")
|
collection = create_index_name("default", "entity.title")
|
||||||
assert collection == "default-entity-title"
|
assert collection == "default-entity-title"
|
||||||
|
|
||||||
|
|
||||||
def test_create_collection_name_invalid_embedding_throws():
|
def test_create_index_name_invalid_embedding_throws():
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
create_collection_name("default", "invalid.name")
|
create_index_name("default", "invalid.name")
|
||||||
|
|
||||||
|
|
||||||
def test_create_collection_name_invalid_embedding_does_not_throw():
|
def test_create_index_name_invalid_embedding_does_not_throw():
|
||||||
collection = create_collection_name("default", "invalid.name", validate=False)
|
collection = create_index_name("default", "invalid.name", validate=False)
|
||||||
assert collection == "default-invalid-name"
|
assert collection == "default-invalid-name"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user