mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-13 16:47:20 +08:00
Custom vector store schema implementation (#2062)
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
* progress on vector customization * fix for lancedb vectors * cosmosdb implementation * uv run poe format * clean test for vector store * semversioner update * test_factory.py integration test fixes * fixes for cosmosdb test * integration test fix for lancedb * uv fix for format * test fixes * fixes for tests * fix cosmosdb bug * print statement * test * test * fix cosmosdb bug * test validation * validation cosmosdb * validate cosmosdb * fix cosmosdb * fix small feedback from PR --------- Co-authored-by: Gaudy Blanco <gaudy-microsoft@MacBook-Pro-m4-Gaudy-For-Work.local>
This commit is contained in:
parent
075cadd59a
commit
82cd3b7df2
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "minor",
|
||||
"description": "add customization to vector store"
|
||||
}
|
||||
55
.vscode/launch.json
vendored
55
.vscode/launch.json
vendored
@ -6,21 +6,24 @@
|
||||
"name": "Indexer",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uv",
|
||||
"module": "graphrag",
|
||||
"args": [
|
||||
"poe", "index",
|
||||
"--root", "<path_to_ragtest_root_demo>"
|
||||
"index",
|
||||
"--root",
|
||||
"<path_to_index_folder>"
|
||||
],
|
||||
"console": "integratedTerminal"
|
||||
},
|
||||
{
|
||||
"name": "Query",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uv",
|
||||
"module": "graphrag",
|
||||
"args": [
|
||||
"poe", "query",
|
||||
"--root", "<path_to_ragtest_root_demo>",
|
||||
"--method", "global",
|
||||
"query",
|
||||
"--root",
|
||||
"<path_to_index_folder>",
|
||||
"--method", "basic",
|
||||
"--query", "What are the top themes in this story",
|
||||
]
|
||||
},
|
||||
@ -34,6 +37,42 @@
|
||||
"--config",
|
||||
"<path_to_ragtest_root_demo>/settings.yaml",
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Debug Integration Pytest",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"args": [
|
||||
"./tests/integration/vector_stores",
|
||||
"-k", "test_azure_ai_search"
|
||||
],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Debug Verbs Pytest",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"args": [
|
||||
"./tests/verbs",
|
||||
"-k", "test_generate_text_embeddings"
|
||||
],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Debug Smoke Pytest",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"args": [
|
||||
"./tests/smoke",
|
||||
"-k", "test_fixtures"
|
||||
],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false
|
||||
},
|
||||
]
|
||||
}
|
||||
@ -394,6 +394,7 @@ class VectorStoreDefaults:
|
||||
api_key: None = None
|
||||
audience: None = None
|
||||
database_name: None = None
|
||||
schema: None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -29,14 +29,14 @@ default_embeddings: list[str] = [
|
||||
]
|
||||
|
||||
|
||||
def create_collection_name(
|
||||
def create_index_name(
|
||||
container_name: str, embedding_name: str, validate: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Create a collection name for the embedding store.
|
||||
Create a index name for the embedding store.
|
||||
|
||||
Within any given vector store, we can have multiple sets of embeddings organized into projects.
|
||||
The `container` param is used for this partitioning, and is added as a prefix to the collection name for differentiation.
|
||||
The `container` param is used for this partitioning, and is added as a prefix to the index name for differentiation.
|
||||
|
||||
The embedding name is fixed, with the available list defined in graphrag.index.config.embeddings
|
||||
|
||||
|
||||
@ -6,7 +6,9 @@
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from graphrag.config.defaults import vector_store_defaults
|
||||
from graphrag.config.embeddings import all_embeddings
|
||||
from graphrag.config.enums import VectorStoreType
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
|
||||
|
||||
class VectorStoreConfig(BaseModel):
|
||||
@ -85,9 +87,25 @@ class VectorStoreConfig(BaseModel):
|
||||
default=vector_store_defaults.overwrite,
|
||||
)
|
||||
|
||||
embeddings_schema: dict[str, VectorStoreSchemaConfig] = {}
|
||||
|
||||
def _validate_embeddings_schema(self) -> None:
|
||||
"""Validate the embeddings schema."""
|
||||
for name in self.embeddings_schema:
|
||||
if name not in all_embeddings:
|
||||
msg = f"vector_store.embeddings_schema contains an invalid embedding schema name: {name}. Please update your settings.yaml and select the correct embedding schema names."
|
||||
raise ValueError(msg)
|
||||
|
||||
if self.type == VectorStoreType.CosmosDB:
|
||||
for id_field in self.embeddings_schema:
|
||||
if id_field != "id":
|
||||
msg = "When using CosmosDB, the id_field in embeddings_schema must be 'id'. Please update your settings.yaml and set the id_field to 'id'."
|
||||
raise ValueError(msg)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_model(self):
|
||||
"""Validate the model."""
|
||||
self._validate_db_uri()
|
||||
self._validate_url()
|
||||
self._validate_embeddings_schema()
|
||||
return self
|
||||
|
||||
66
graphrag/config/models/vector_store_schema_config.py
Normal file
66
graphrag/config/models/vector_store_schema_config.py
Normal file
@ -0,0 +1,66 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Parameterization settings for the default configuration."""
|
||||
|
||||
import re
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
DEFAULT_VECTOR_SIZE: int = 1536
|
||||
|
||||
VALID_IDENTIFIER_REGEX = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||
|
||||
|
||||
def is_valid_field_name(field: str) -> bool:
|
||||
"""Check if a field name is valid for CosmosDB."""
|
||||
return bool(VALID_IDENTIFIER_REGEX.match(field))
|
||||
|
||||
|
||||
class VectorStoreSchemaConfig(BaseModel):
|
||||
"""The default configuration section for Vector Store Schema."""
|
||||
|
||||
id_field: str = Field(
|
||||
description="The ID field to use.",
|
||||
default="id",
|
||||
)
|
||||
|
||||
vector_field: str = Field(
|
||||
description="The vector field to use.",
|
||||
default="vector",
|
||||
)
|
||||
|
||||
text_field: str = Field(
|
||||
description="The text field to use.",
|
||||
default="text",
|
||||
)
|
||||
|
||||
attributes_field: str = Field(
|
||||
description="The attributes field to use.",
|
||||
default="attributes",
|
||||
)
|
||||
|
||||
vector_size: int = Field(
|
||||
description="The vector size to use.",
|
||||
default=DEFAULT_VECTOR_SIZE,
|
||||
)
|
||||
|
||||
index_name: str | None = Field(description="The index name to use.", default=None)
|
||||
|
||||
def _validate_schema(self) -> None:
|
||||
"""Validate the schema."""
|
||||
for field in [
|
||||
self.id_field,
|
||||
self.vector_field,
|
||||
self.text_field,
|
||||
self.attributes_field,
|
||||
]:
|
||||
if not is_valid_field_name(field):
|
||||
msg = f"Unsafe or invalid field name: {field}"
|
||||
raise ValueError(msg)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_model(self):
|
||||
"""Validate the model."""
|
||||
self._validate_schema()
|
||||
return self
|
||||
@ -12,7 +12,8 @@ import pandas as pd
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.embeddings import create_collection_name
|
||||
from graphrag.config.embeddings import create_index_name
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy
|
||||
from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument
|
||||
from graphrag.vector_stores.factory import VectorStoreFactory
|
||||
@ -49,9 +50,9 @@ async def embed_text(
|
||||
vector_store_config = strategy.get("vector_store")
|
||||
|
||||
if vector_store_config:
|
||||
collection_name = _get_collection_name(vector_store_config, embedding_name)
|
||||
index_name = _get_index_name(vector_store_config, embedding_name)
|
||||
vector_store: BaseVectorStore = _create_vector_store(
|
||||
vector_store_config, collection_name
|
||||
vector_store_config, index_name, embedding_name
|
||||
)
|
||||
vector_store_workflow_config = vector_store_config.get(
|
||||
embedding_name, vector_store_config
|
||||
@ -183,27 +184,46 @@ async def _text_embed_with_vector_store(
|
||||
|
||||
|
||||
def _create_vector_store(
|
||||
vector_store_config: dict, collection_name: str
|
||||
vector_store_config: dict, index_name: str, embedding_name: str | None = None
|
||||
) -> BaseVectorStore:
|
||||
vector_store_type: str = str(vector_store_config.get("type"))
|
||||
if collection_name:
|
||||
vector_store_config.update({"collection_name": collection_name})
|
||||
|
||||
embeddings_schema: dict[str, VectorStoreSchemaConfig] = vector_store_config.get(
|
||||
"embeddings_schema", {}
|
||||
)
|
||||
single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
|
||||
|
||||
if (
|
||||
embeddings_schema is not None
|
||||
and embedding_name is not None
|
||||
and embedding_name in embeddings_schema
|
||||
):
|
||||
raw_config = embeddings_schema[embedding_name]
|
||||
if isinstance(raw_config, dict):
|
||||
single_embedding_config = VectorStoreSchemaConfig(**raw_config)
|
||||
else:
|
||||
single_embedding_config = raw_config
|
||||
|
||||
if single_embedding_config.index_name is None:
|
||||
single_embedding_config.index_name = index_name
|
||||
|
||||
vector_store = VectorStoreFactory().create_vector_store(
|
||||
vector_store_type, kwargs=vector_store_config
|
||||
vector_store_schema_config=single_embedding_config,
|
||||
vector_store_type=vector_store_type,
|
||||
kwargs=vector_store_config,
|
||||
)
|
||||
|
||||
vector_store.connect(**vector_store_config)
|
||||
return vector_store
|
||||
|
||||
|
||||
def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str:
|
||||
def _get_index_name(vector_store_config: dict, embedding_name: str) -> str:
|
||||
container_name = vector_store_config.get("container_name", "default")
|
||||
collection_name = create_collection_name(container_name, embedding_name)
|
||||
index_name = create_index_name(container_name, embedding_name)
|
||||
|
||||
msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {collection_name}"
|
||||
msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {index_name}"
|
||||
logger.info(msg)
|
||||
return collection_name
|
||||
return index_name
|
||||
|
||||
|
||||
def load_strategy(strategy: TextEmbedStrategyType) -> TextEmbeddingStrategy:
|
||||
|
||||
@ -8,9 +8,10 @@ from typing import Any
|
||||
|
||||
from graphrag.cache.factory import CacheFactory
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.config.embeddings import create_collection_name
|
||||
from graphrag.config.embeddings import create_index_name
|
||||
from graphrag.config.models.cache_config import CacheConfig
|
||||
from graphrag.config.models.storage_config import StorageConfig
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
from graphrag.data_model.types import TextEmbedder
|
||||
from graphrag.storage.factory import StorageFactory
|
||||
from graphrag.storage.pipeline_storage import PipelineStorage
|
||||
@ -103,12 +104,33 @@ def get_embedding_store(
|
||||
index_names = []
|
||||
for index, store in config_args.items():
|
||||
vector_store_type = store["type"]
|
||||
collection_name = create_collection_name(
|
||||
index_name = create_index_name(
|
||||
store.get("container_name", "default"), embedding_name
|
||||
)
|
||||
|
||||
embeddings_schema: dict[str, VectorStoreSchemaConfig] = store.get(
|
||||
"embeddings_schema", {}
|
||||
)
|
||||
single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
|
||||
|
||||
if (
|
||||
embeddings_schema is not None
|
||||
and embedding_name is not None
|
||||
and embedding_name in embeddings_schema
|
||||
):
|
||||
raw_config = embeddings_schema[embedding_name]
|
||||
if isinstance(raw_config, dict):
|
||||
single_embedding_config = VectorStoreSchemaConfig(**raw_config)
|
||||
else:
|
||||
single_embedding_config = raw_config
|
||||
|
||||
if single_embedding_config.index_name is None:
|
||||
single_embedding_config.index_name = index_name
|
||||
|
||||
embedding_store = VectorStoreFactory().create_vector_store(
|
||||
vector_store_type=vector_store_type,
|
||||
kwargs={**store, "collection_name": collection_name},
|
||||
vector_store_schema_config=single_embedding_config,
|
||||
kwargs={**store},
|
||||
)
|
||||
embedding_store.connect(**store)
|
||||
# If there is only a single index, return the embedding store directly
|
||||
|
||||
@ -24,9 +24,9 @@ from azure.search.documents.indexes.models import (
|
||||
)
|
||||
from azure.search.documents.models import VectorizedQuery
|
||||
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
from graphrag.data_model.types import TextEmbedder
|
||||
from graphrag.vector_stores.base import (
|
||||
DEFAULT_VECTOR_SIZE,
|
||||
BaseVectorStore,
|
||||
VectorStoreDocument,
|
||||
VectorStoreSearchResult,
|
||||
@ -38,15 +38,18 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
||||
|
||||
index_client: SearchIndexClient
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
def __init__(
|
||||
self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(
|
||||
vector_store_schema_config=vector_store_schema_config, **kwargs
|
||||
)
|
||||
|
||||
def connect(self, **kwargs: Any) -> Any:
|
||||
"""Connect to AI search vector storage."""
|
||||
url = kwargs["url"]
|
||||
api_key = kwargs.get("api_key")
|
||||
audience = kwargs.get("audience")
|
||||
self.vector_size = kwargs.get("vector_size", DEFAULT_VECTOR_SIZE)
|
||||
|
||||
self.vector_search_profile_name = kwargs.get(
|
||||
"vector_search_profile_name", "vectorSearchProfile"
|
||||
@ -56,7 +59,7 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
||||
audience_arg = {"audience": audience} if audience and not api_key else {}
|
||||
self.db_connection = SearchClient(
|
||||
endpoint=url,
|
||||
index_name=self.collection_name,
|
||||
index_name=self.index_name if self.index_name else "",
|
||||
credential=(
|
||||
AzureKeyCredential(api_key) if api_key else DefaultAzureCredential()
|
||||
),
|
||||
@ -78,8 +81,11 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
||||
) -> None:
|
||||
"""Load documents into an Azure AI Search index."""
|
||||
if overwrite:
|
||||
if self.collection_name in self.index_client.list_index_names():
|
||||
self.index_client.delete_index(self.collection_name)
|
||||
if (
|
||||
self.index_name is not None
|
||||
and self.index_name in self.index_client.list_index_names()
|
||||
):
|
||||
self.index_client.delete_index(self.index_name)
|
||||
|
||||
# Configure vector search profile
|
||||
vector_search = VectorSearch(
|
||||
@ -100,24 +106,26 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
||||
)
|
||||
# Configure the index
|
||||
index = SearchIndex(
|
||||
name=self.collection_name,
|
||||
name=self.index_name if self.index_name else "",
|
||||
fields=[
|
||||
SimpleField(
|
||||
name="id",
|
||||
name=self.id_field,
|
||||
type=SearchFieldDataType.String,
|
||||
key=True,
|
||||
),
|
||||
SearchField(
|
||||
name="vector",
|
||||
name=self.vector_field,
|
||||
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
|
||||
searchable=True,
|
||||
hidden=False, # DRIFT needs to return the vector for client-side similarity
|
||||
vector_search_dimensions=self.vector_size,
|
||||
vector_search_profile_name=self.vector_search_profile_name,
|
||||
),
|
||||
SearchableField(name="text", type=SearchFieldDataType.String),
|
||||
SearchableField(
|
||||
name=self.text_field, type=SearchFieldDataType.String
|
||||
),
|
||||
SimpleField(
|
||||
name="attributes",
|
||||
name=self.attributes_field,
|
||||
type=SearchFieldDataType.String,
|
||||
),
|
||||
],
|
||||
@ -129,10 +137,10 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
||||
|
||||
batch = [
|
||||
{
|
||||
"id": doc.id,
|
||||
"vector": doc.vector,
|
||||
"text": doc.text,
|
||||
"attributes": json.dumps(doc.attributes),
|
||||
self.id_field: doc.id,
|
||||
self.vector_field: doc.vector,
|
||||
self.text_field: doc.text,
|
||||
self.attributes_field: json.dumps(doc.attributes),
|
||||
}
|
||||
for doc in documents
|
||||
if doc.vector is not None
|
||||
@ -151,7 +159,7 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
||||
# More info about odata filtering here: https://learn.microsoft.com/en-us/azure/search/search-query-odata-search-in-function
|
||||
# search.in is faster that joined and/or conditions
|
||||
id_filter = ",".join([f"{id!s}" for id in include_ids])
|
||||
self.query_filter = f"search.in(id, '{id_filter}', ',')"
|
||||
self.query_filter = f"search.in({self.id_field}, '{id_filter}', ',')"
|
||||
|
||||
# Returning to keep consistency with other methods, but not needed
|
||||
# TODO: Refactor on a future PR
|
||||
@ -162,7 +170,7 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
||||
) -> list[VectorStoreSearchResult]:
|
||||
"""Perform a vector-based similarity search."""
|
||||
vectorized_query = VectorizedQuery(
|
||||
vector=query_embedding, k_nearest_neighbors=k, fields="vector"
|
||||
vector=query_embedding, k_nearest_neighbors=k, fields=self.vector_field
|
||||
)
|
||||
|
||||
response = self.db_connection.search(
|
||||
@ -172,10 +180,10 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
||||
return [
|
||||
VectorStoreSearchResult(
|
||||
document=VectorStoreDocument(
|
||||
id=doc.get("id", ""),
|
||||
text=doc.get("text", ""),
|
||||
vector=doc.get("vector", []),
|
||||
attributes=(json.loads(doc.get("attributes", "{}"))),
|
||||
id=doc.get(self.id_field, ""),
|
||||
text=doc.get(self.text_field, ""),
|
||||
vector=doc.get(self.vector_field, []),
|
||||
attributes=(json.loads(doc.get(self.attributes_field, "{}"))),
|
||||
),
|
||||
# Cosine similarity between 0.333 and 1.000
|
||||
# https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking#scores-in-a-hybrid-search-results
|
||||
@ -199,8 +207,8 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
||||
"""Search for a document by id."""
|
||||
response = self.db_connection.get_document(id)
|
||||
return VectorStoreDocument(
|
||||
id=response.get("id", ""),
|
||||
text=response.get("text", ""),
|
||||
vector=response.get("vector", []),
|
||||
attributes=(json.loads(response.get("attributes", "{}"))),
|
||||
id=response.get(self.id_field, ""),
|
||||
text=response.get(self.text_field, ""),
|
||||
vector=response.get(self.vector_field, []),
|
||||
attributes=(json.loads(response.get(self.attributes_field, "{}"))),
|
||||
)
|
||||
|
||||
@ -7,10 +7,9 @@ from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
from graphrag.data_model.types import TextEmbedder
|
||||
|
||||
DEFAULT_VECTOR_SIZE: int = 1536
|
||||
|
||||
|
||||
@dataclass
|
||||
class VectorStoreDocument:
|
||||
@ -42,18 +41,24 @@ class BaseVectorStore(ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
vector_store_schema_config: VectorStoreSchemaConfig,
|
||||
db_connection: Any | None = None,
|
||||
document_collection: Any | None = None,
|
||||
query_filter: Any | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self.collection_name = collection_name
|
||||
self.db_connection = db_connection
|
||||
self.document_collection = document_collection
|
||||
self.query_filter = query_filter
|
||||
self.kwargs = kwargs
|
||||
|
||||
self.index_name = vector_store_schema_config.index_name
|
||||
self.id_field = vector_store_schema_config.id_field
|
||||
self.text_field = vector_store_schema_config.text_field
|
||||
self.vector_field = vector_store_schema_config.vector_field
|
||||
self.attributes_field = vector_store_schema_config.attributes_field
|
||||
self.vector_size = vector_store_schema_config.vector_size
|
||||
|
||||
@abstractmethod
|
||||
def connect(self, **kwargs: Any) -> None:
|
||||
"""Connect to vector storage."""
|
||||
|
||||
@ -11,9 +11,9 @@ from azure.cosmos.exceptions import CosmosHttpResponseError
|
||||
from azure.cosmos.partition_key import PartitionKey
|
||||
from azure.identity import DefaultAzureCredential
|
||||
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
from graphrag.data_model.types import TextEmbedder
|
||||
from graphrag.vector_stores.base import (
|
||||
DEFAULT_VECTOR_SIZE,
|
||||
BaseVectorStore,
|
||||
VectorStoreDocument,
|
||||
VectorStoreSearchResult,
|
||||
@ -27,8 +27,12 @@ class CosmosDBVectorStore(BaseVectorStore):
|
||||
_database_client: DatabaseProxy
|
||||
_container_client: ContainerProxy
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
def __init__(
|
||||
self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(
|
||||
vector_store_schema_config=vector_store_schema_config, **kwargs
|
||||
)
|
||||
|
||||
def connect(self, **kwargs: Any) -> Any:
|
||||
"""Connect to CosmosDB vector storage."""
|
||||
@ -49,13 +53,12 @@ class CosmosDBVectorStore(BaseVectorStore):
|
||||
msg = "Database name must be provided."
|
||||
raise ValueError(msg)
|
||||
self._database_name = database_name
|
||||
collection_name = self.collection_name
|
||||
if collection_name is None:
|
||||
msg = "Collection name is empty or not provided."
|
||||
if self.index_name is None:
|
||||
msg = "Index name is empty or not provided."
|
||||
raise ValueError(msg)
|
||||
self._container_name = collection_name
|
||||
self._container_name = self.index_name
|
||||
|
||||
self.vector_size = kwargs.get("vector_size", DEFAULT_VECTOR_SIZE)
|
||||
self.vector_size = self.vector_size
|
||||
self._create_database()
|
||||
self._create_container()
|
||||
|
||||
@ -80,13 +83,13 @@ class CosmosDBVectorStore(BaseVectorStore):
|
||||
|
||||
def _create_container(self) -> None:
|
||||
"""Create the container if it doesn't exist."""
|
||||
partition_key = PartitionKey(path="/id", kind="Hash")
|
||||
partition_key = PartitionKey(path=f"/{self.id_field}", kind="Hash")
|
||||
|
||||
# Define the container vector policy
|
||||
vector_embedding_policy = {
|
||||
"vectorEmbeddings": [
|
||||
{
|
||||
"path": "/vector",
|
||||
"path": f"/{self.vector_field}",
|
||||
"dataType": "float32",
|
||||
"distanceFunction": "cosine",
|
||||
"dimensions": self.vector_size,
|
||||
@ -99,13 +102,18 @@ class CosmosDBVectorStore(BaseVectorStore):
|
||||
"indexingMode": "consistent",
|
||||
"automatic": True,
|
||||
"includedPaths": [{"path": "/*"}],
|
||||
"excludedPaths": [{"path": "/_etag/?"}, {"path": "/vector/*"}],
|
||||
"excludedPaths": [
|
||||
{"path": "/_etag/?"},
|
||||
{"path": f"/{self.vector_field}/*"},
|
||||
],
|
||||
}
|
||||
|
||||
# Currently, the CosmosDB emulator does not support the diskANN policy.
|
||||
try:
|
||||
# First try with the standard diskANN policy
|
||||
indexing_policy["vectorIndexes"] = [{"path": "/vector", "type": "diskANN"}]
|
||||
indexing_policy["vectorIndexes"] = [
|
||||
{"path": f"/{self.vector_field}", "type": "diskANN"}
|
||||
]
|
||||
|
||||
# Create the container and container client
|
||||
self._database_client.create_container_if_not_exists(
|
||||
@ -158,12 +166,16 @@ class CosmosDBVectorStore(BaseVectorStore):
|
||||
# Upload documents to CosmosDB
|
||||
for doc in documents:
|
||||
if doc.vector is not None:
|
||||
print("Document to store:") # noqa: T201
|
||||
print(doc) # noqa: T201
|
||||
doc_json = {
|
||||
"id": doc.id,
|
||||
"vector": doc.vector,
|
||||
"text": doc.text,
|
||||
"attributes": json.dumps(doc.attributes),
|
||||
self.id_field: doc.id,
|
||||
self.vector_field: doc.vector,
|
||||
self.text_field: doc.text,
|
||||
self.attributes_field: json.dumps(doc.attributes),
|
||||
}
|
||||
print("Storing document in CosmosDB:") # noqa: T201
|
||||
print(doc_json) # noqa: T201
|
||||
self._container_client.upsert_item(doc_json)
|
||||
|
||||
def similarity_search_by_vector(
|
||||
@ -175,7 +187,7 @@ class CosmosDBVectorStore(BaseVectorStore):
|
||||
raise ValueError(msg)
|
||||
|
||||
try:
|
||||
query = f"SELECT TOP {k} c.id, c.text, c.vector, c.attributes, VectorDistance(c.vector, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.vector, @embedding)" # noqa: S608
|
||||
query = f"SELECT TOP {k} c.{self.id_field}, c.{self.text_field}, c.{self.vector_field}, c.{self.attributes_field}, VectorDistance(c.{self.vector_field}, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.{self.vector_field}, @embedding)" # noqa: S608
|
||||
query_params = [{"name": "@embedding", "value": query_embedding}]
|
||||
items = list(
|
||||
self._container_client.query_items(
|
||||
@ -187,7 +199,7 @@ class CosmosDBVectorStore(BaseVectorStore):
|
||||
except (CosmosHttpResponseError, ValueError):
|
||||
# Currently, the CosmosDB emulator does not support the VectorDistance function.
|
||||
# For emulator or test environments - fetch all items and calculate distance locally
|
||||
query = "SELECT c.id, c.text, c.vector, c.attributes FROM c"
|
||||
query = f"SELECT c.{self.id_field}, c.{self.text_field}, c.{self.vector_field}, c.{self.attributes_field} FROM c" # noqa: S608
|
||||
items = list(
|
||||
self._container_client.query_items(
|
||||
query=query,
|
||||
@ -206,7 +218,7 @@ class CosmosDBVectorStore(BaseVectorStore):
|
||||
|
||||
# Calculate scores for all items
|
||||
for item in items:
|
||||
item_vector = item.get("vector", [])
|
||||
item_vector = item.get(self.vector_field, [])
|
||||
similarity = cosine_similarity(query_embedding, item_vector)
|
||||
item["SimilarityScore"] = similarity
|
||||
|
||||
@ -218,10 +230,10 @@ class CosmosDBVectorStore(BaseVectorStore):
|
||||
return [
|
||||
VectorStoreSearchResult(
|
||||
document=VectorStoreDocument(
|
||||
id=item.get("id", ""),
|
||||
text=item.get("text", ""),
|
||||
vector=item.get("vector", []),
|
||||
attributes=(json.loads(item.get("attributes", "{}"))),
|
||||
id=item.get(self.id_field, ""),
|
||||
text=item.get(self.text_field, ""),
|
||||
vector=item.get(self.vector_field, []),
|
||||
attributes=(json.loads(item.get(self.attributes_field, "{}"))),
|
||||
),
|
||||
score=item.get("SimilarityScore", 0.0),
|
||||
)
|
||||
@ -248,7 +260,9 @@ class CosmosDBVectorStore(BaseVectorStore):
|
||||
id_filter = ", ".join([f"'{id}'" for id in include_ids])
|
||||
else:
|
||||
id_filter = ", ".join([str(id) for id in include_ids])
|
||||
self.query_filter = f"SELECT * FROM c WHERE c.id IN ({id_filter})" # noqa: S608
|
||||
self.query_filter = (
|
||||
f"SELECT * FROM c WHERE c.{self.id_field} IN ({id_filter})" # noqa: S608
|
||||
)
|
||||
return self.query_filter
|
||||
|
||||
def search_by_id(self, id: str) -> VectorStoreDocument:
|
||||
@ -259,10 +273,10 @@ class CosmosDBVectorStore(BaseVectorStore):
|
||||
|
||||
item = self._container_client.read_item(item=id, partition_key=id)
|
||||
return VectorStoreDocument(
|
||||
id=item.get("id", ""),
|
||||
vector=item.get("vector", []),
|
||||
text=item.get("text", ""),
|
||||
attributes=(json.loads(item.get("attributes", "{}"))),
|
||||
id=item.get(self.id_field, ""),
|
||||
vector=item.get(self.vector_field, []),
|
||||
text=item.get(self.text_field, ""),
|
||||
attributes=(json.loads(item.get(self.attributes_field, "{}"))),
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
|
||||
@ -15,6 +15,9 @@ from graphrag.vector_stores.lancedb import LanceDBVectorStore
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from graphrag.config.models.vector_store_schema_config import (
|
||||
VectorStoreSchemaConfig,
|
||||
)
|
||||
from graphrag.vector_stores.base import BaseVectorStore
|
||||
|
||||
|
||||
@ -47,7 +50,10 @@ class VectorStoreFactory:
|
||||
|
||||
@classmethod
|
||||
def create_vector_store(
|
||||
cls, vector_store_type: str, kwargs: dict
|
||||
cls,
|
||||
vector_store_type: str,
|
||||
vector_store_schema_config: VectorStoreSchemaConfig,
|
||||
kwargs: dict,
|
||||
) -> BaseVectorStore:
|
||||
"""Create a vector store object from the provided type.
|
||||
|
||||
@ -67,7 +73,9 @@ class VectorStoreFactory:
|
||||
msg = f"Unknown vector store type: {vector_store_type}"
|
||||
raise ValueError(msg)
|
||||
|
||||
return cls._registry[vector_store_type](**kwargs)
|
||||
return cls._registry[vector_store_type](
|
||||
vector_store_schema_config=vector_store_schema_config, **kwargs
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_vector_store_types(cls) -> list[str]:
|
||||
|
||||
@ -5,9 +5,9 @@
|
||||
|
||||
import json # noqa: I001
|
||||
from typing import Any
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
import numpy as np
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
from graphrag.data_model.types import TextEmbedder
|
||||
|
||||
from graphrag.vector_stores.base import (
|
||||
@ -21,60 +21,81 @@ import lancedb
|
||||
class LanceDBVectorStore(BaseVectorStore):
|
||||
"""LanceDB vector storage implementation."""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
def __init__(
|
||||
self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(
|
||||
vector_store_schema_config=vector_store_schema_config, **kwargs
|
||||
)
|
||||
|
||||
def connect(self, **kwargs: Any) -> Any:
|
||||
"""Connect to the vector storage."""
|
||||
self.db_connection = lancedb.connect(kwargs["db_uri"])
|
||||
if (
|
||||
self.collection_name
|
||||
and self.collection_name in self.db_connection.table_names()
|
||||
):
|
||||
self.document_collection = self.db_connection.open_table(
|
||||
self.collection_name
|
||||
)
|
||||
|
||||
if self.index_name and self.index_name in self.db_connection.table_names():
|
||||
self.document_collection = self.db_connection.open_table(self.index_name)
|
||||
|
||||
def load_documents(
|
||||
self, documents: list[VectorStoreDocument], overwrite: bool = True
|
||||
) -> None:
|
||||
"""Load documents into vector storage."""
|
||||
data = [
|
||||
{
|
||||
"id": document.id,
|
||||
"text": document.text,
|
||||
"vector": document.vector,
|
||||
"attributes": json.dumps(document.attributes),
|
||||
}
|
||||
for document in documents
|
||||
if document.vector is not None
|
||||
]
|
||||
# Step 1: Prepare data columns manually
|
||||
ids = []
|
||||
texts = []
|
||||
vectors = []
|
||||
attributes = []
|
||||
|
||||
if len(data) == 0:
|
||||
for document in documents:
|
||||
self.vector_size = (
|
||||
len(document.vector) if document.vector else self.vector_size
|
||||
)
|
||||
if document.vector is not None and len(document.vector) == self.vector_size:
|
||||
ids.append(document.id)
|
||||
texts.append(document.text)
|
||||
vectors.append(np.array(document.vector, dtype=np.float32))
|
||||
attributes.append(json.dumps(document.attributes))
|
||||
|
||||
# Step 2: Handle empty case
|
||||
if len(ids) == 0:
|
||||
data = None
|
||||
else:
|
||||
# Step 3: Flatten the vectors and build FixedSizeListArray manually
|
||||
flat_vector = np.concatenate(vectors).astype(np.float32)
|
||||
flat_array = pa.array(flat_vector, type=pa.float32())
|
||||
vector_column = pa.FixedSizeListArray.from_arrays(
|
||||
flat_array, self.vector_size
|
||||
)
|
||||
|
||||
# Step 4: Create PyArrow table (let schema be inferred)
|
||||
data = pa.table({
|
||||
self.id_field: pa.array(ids, type=pa.string()),
|
||||
self.text_field: pa.array(texts, type=pa.string()),
|
||||
self.vector_field: vector_column,
|
||||
self.attributes_field: pa.array(attributes, type=pa.string()),
|
||||
})
|
||||
|
||||
schema = pa.schema([
|
||||
pa.field("id", pa.string()),
|
||||
pa.field("text", pa.string()),
|
||||
pa.field("vector", pa.list_(pa.float64())),
|
||||
pa.field("attributes", pa.string()),
|
||||
])
|
||||
# NOTE: If modifying the next section of code, ensure that the schema remains the same.
|
||||
# The pyarrow format of the 'vector' field may change if the order of operations is changed
|
||||
# and will break vector search.
|
||||
if overwrite:
|
||||
if data:
|
||||
self.document_collection = self.db_connection.create_table(
|
||||
self.collection_name, data=data, mode="overwrite"
|
||||
self.index_name if self.index_name else "",
|
||||
data=data,
|
||||
mode="overwrite",
|
||||
schema=data.schema,
|
||||
)
|
||||
else:
|
||||
self.document_collection = self.db_connection.create_table(
|
||||
self.collection_name, schema=schema, mode="overwrite"
|
||||
self.index_name if self.index_name else "", mode="overwrite"
|
||||
)
|
||||
self.document_collection.create_index(
|
||||
vector_column_name=self.vector_field, index_type="IVF_FLAT"
|
||||
)
|
||||
else:
|
||||
# add data to existing table
|
||||
self.document_collection = self.db_connection.open_table(
|
||||
self.collection_name
|
||||
self.index_name if self.index_name else ""
|
||||
)
|
||||
if data:
|
||||
self.document_collection.add(data)
|
||||
@ -86,30 +107,32 @@ class LanceDBVectorStore(BaseVectorStore):
|
||||
else:
|
||||
if isinstance(include_ids[0], str):
|
||||
id_filter = ", ".join([f"'{id}'" for id in include_ids])
|
||||
self.query_filter = f"id in ({id_filter})"
|
||||
self.query_filter = f"{self.id_field} in ({id_filter})"
|
||||
else:
|
||||
self.query_filter = (
|
||||
f"id in ({', '.join([str(id) for id in include_ids])})"
|
||||
f"{self.id_field} in ({', '.join([str(id) for id in include_ids])})"
|
||||
)
|
||||
return self.query_filter
|
||||
|
||||
def similarity_search_by_vector(
|
||||
self, query_embedding: list[float], k: int = 10, **kwargs: Any
|
||||
self, query_embedding: list[float] | np.ndarray, k: int = 10, **kwargs: Any
|
||||
) -> list[VectorStoreSearchResult]:
|
||||
"""Perform a vector-based similarity search."""
|
||||
if self.query_filter:
|
||||
docs = (
|
||||
self.document_collection.search(
|
||||
query=query_embedding, vector_column_name="vector"
|
||||
query=query_embedding, vector_column_name=self.vector_field
|
||||
)
|
||||
.where(self.query_filter, prefilter=True)
|
||||
.limit(k)
|
||||
.to_list()
|
||||
)
|
||||
else:
|
||||
query_embedding = np.array(query_embedding, dtype=np.float32)
|
||||
|
||||
docs = (
|
||||
self.document_collection.search(
|
||||
query=query_embedding, vector_column_name="vector"
|
||||
query=query_embedding, vector_column_name=self.vector_field
|
||||
)
|
||||
.limit(k)
|
||||
.to_list()
|
||||
@ -117,10 +140,10 @@ class LanceDBVectorStore(BaseVectorStore):
|
||||
return [
|
||||
VectorStoreSearchResult(
|
||||
document=VectorStoreDocument(
|
||||
id=doc["id"],
|
||||
text=doc["text"],
|
||||
vector=doc["vector"],
|
||||
attributes=json.loads(doc["attributes"]),
|
||||
id=doc[self.id_field],
|
||||
text=doc[self.text_field],
|
||||
vector=doc[self.vector_field],
|
||||
attributes=json.loads(doc[self.attributes_field]),
|
||||
),
|
||||
score=1 - abs(float(doc["_distance"])),
|
||||
)
|
||||
@ -140,14 +163,14 @@ class LanceDBVectorStore(BaseVectorStore):
|
||||
"""Search for a document by id."""
|
||||
doc = (
|
||||
self.document_collection.search()
|
||||
.where(f"id == '{id}'", prefilter=True)
|
||||
.where(f"{self.id_field} == '{id}'", prefilter=True)
|
||||
.to_list()
|
||||
)
|
||||
if doc:
|
||||
return VectorStoreDocument(
|
||||
id=doc[0]["id"],
|
||||
text=doc[0]["text"],
|
||||
vector=doc[0]["vector"],
|
||||
attributes=json.loads(doc[0]["attributes"]),
|
||||
id=doc[0][self.id_field],
|
||||
text=doc[0][self.text_field],
|
||||
vector=doc[0][self.vector_field],
|
||||
attributes=json.loads(doc[0][self.attributes_field]),
|
||||
)
|
||||
return VectorStoreDocument(id=id, text=None, vector=None)
|
||||
|
||||
@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
|
||||
from graphrag.vector_stores.base import VectorStoreDocument
|
||||
|
||||
@ -39,7 +40,35 @@ class TestAzureAISearchVectorStore:
|
||||
@pytest.fixture
|
||||
def vector_store(self, mock_search_client, mock_index_client):
|
||||
"""Create an Azure AI Search vector store instance."""
|
||||
vector_store = AzureAISearchVectorStore(collection_name="test_vectors")
|
||||
vector_store = AzureAISearchVectorStore(
|
||||
vector_store_schema_config=VectorStoreSchemaConfig(
|
||||
index_name="test_vectors", vector_size=5
|
||||
),
|
||||
)
|
||||
|
||||
# Create the necessary mocks first
|
||||
vector_store.db_connection = mock_search_client
|
||||
vector_store.index_client = mock_index_client
|
||||
|
||||
vector_store.connect(
|
||||
url=TEST_AZURE_AI_SEARCH_URL,
|
||||
api_key=TEST_AZURE_AI_SEARCH_KEY,
|
||||
)
|
||||
return vector_store
|
||||
|
||||
@pytest.fixture
|
||||
def vector_store_custom(self, mock_search_client, mock_index_client):
|
||||
"""Create an Azure AI Search vector store instance."""
|
||||
vector_store = AzureAISearchVectorStore(
|
||||
vector_store_schema_config=VectorStoreSchemaConfig(
|
||||
index_name="test_vectors",
|
||||
id_field="id_custom",
|
||||
text_field="text_custom",
|
||||
attributes_field="attributes_custom",
|
||||
vector_field="vector_custom",
|
||||
vector_size=5,
|
||||
),
|
||||
)
|
||||
|
||||
# Create the necessary mocks first
|
||||
vector_store.db_connection = mock_search_client
|
||||
@ -48,7 +77,6 @@ class TestAzureAISearchVectorStore:
|
||||
vector_store.connect(
|
||||
url=TEST_AZURE_AI_SEARCH_URL,
|
||||
api_key=TEST_AZURE_AI_SEARCH_KEY,
|
||||
vector_size=5,
|
||||
)
|
||||
return vector_store
|
||||
|
||||
@ -144,3 +172,72 @@ class TestAzureAISearchVectorStore:
|
||||
)
|
||||
assert not mock_search_client.search.called
|
||||
assert len(results) == 0
|
||||
|
||||
async def test_vector_store_customization(
|
||||
self,
|
||||
vector_store_custom,
|
||||
sample_documents,
|
||||
mock_search_client,
|
||||
mock_index_client,
|
||||
):
|
||||
"""Test vector store customization with Azure AI Search."""
|
||||
# Setup mock responses
|
||||
mock_index_client.list_index_names.return_value = []
|
||||
mock_index_client.create_or_update_index = MagicMock()
|
||||
mock_search_client.upload_documents = MagicMock()
|
||||
|
||||
search_results = [
|
||||
{
|
||||
vector_store_custom.id_field: "doc1",
|
||||
vector_store_custom.text_field: "This is document 1",
|
||||
vector_store_custom.vector_field: [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
vector_store_custom.attributes_field: '{"title": "Doc 1", "category": "test"}',
|
||||
"@search.score": 0.9,
|
||||
},
|
||||
{
|
||||
vector_store_custom.id_field: "doc2",
|
||||
vector_store_custom.text_field: "This is document 2",
|
||||
vector_store_custom.vector_field: [0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
vector_store_custom.attributes_field: '{"title": "Doc 2", "category": "test"}',
|
||||
"@search.score": 0.8,
|
||||
},
|
||||
]
|
||||
mock_search_client.search.return_value = search_results
|
||||
|
||||
mock_search_client.get_document.return_value = {
|
||||
vector_store_custom.id_field: "doc1",
|
||||
vector_store_custom.text_field: "This is document 1",
|
||||
vector_store_custom.vector_field: [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
vector_store_custom.attributes_field: '{"title": "Doc 1", "category": "test"}',
|
||||
}
|
||||
|
||||
vector_store_custom.load_documents(sample_documents)
|
||||
assert mock_index_client.create_or_update_index.called
|
||||
assert mock_search_client.upload_documents.called
|
||||
|
||||
filter_query = vector_store_custom.filter_by_id(["doc1", "doc2"])
|
||||
assert (
|
||||
filter_query
|
||||
== f"search.in({vector_store_custom.id_field}, 'doc1,doc2', ',')"
|
||||
)
|
||||
|
||||
vector_results = vector_store_custom.similarity_search_by_vector(
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
|
||||
)
|
||||
assert len(vector_results) == 2
|
||||
assert vector_results[0].document.id == "doc1"
|
||||
assert vector_results[0].score == 0.9
|
||||
|
||||
# Define a simple text embedder function for testing
|
||||
def mock_embedder(text: str) -> list[float]:
|
||||
return [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
|
||||
text_results = vector_store_custom.similarity_search_by_text(
|
||||
"test query", mock_embedder, k=2
|
||||
)
|
||||
assert len(text_results) == 2
|
||||
|
||||
doc = vector_store_custom.search_by_id("doc1")
|
||||
assert doc.id == "doc1"
|
||||
assert doc.text == "This is document 1"
|
||||
assert doc.attributes["title"] == "Doc 1"
|
||||
|
||||
@ -8,6 +8,7 @@ import sys
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
from graphrag.vector_stores.base import VectorStoreDocument
|
||||
from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore
|
||||
|
||||
@ -24,7 +25,7 @@ if not sys.platform.startswith("win"):
|
||||
def test_vector_store_operations():
|
||||
"""Test basic vector store operations with CosmosDB."""
|
||||
vector_store = CosmosDBVectorStore(
|
||||
collection_name="testvector",
|
||||
vector_store_schema_config=VectorStoreSchemaConfig(index_name="testvector"),
|
||||
)
|
||||
|
||||
try:
|
||||
@ -78,7 +79,7 @@ def test_vector_store_operations():
|
||||
def test_clear():
|
||||
"""Test clearing the vector store."""
|
||||
vector_store = CosmosDBVectorStore(
|
||||
collection_name="testclear",
|
||||
vector_store_schema_config=VectorStoreSchemaConfig(index_name="testclear"),
|
||||
)
|
||||
try:
|
||||
vector_store.connect(
|
||||
@ -102,3 +103,64 @@ def test_clear():
|
||||
assert vector_store._database_exists() is False # noqa: SLF001
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
def test_vector_store_customization():
|
||||
"""Test vector store customization with CosmosDB."""
|
||||
vector_store = CosmosDBVectorStore(
|
||||
vector_store_schema_config=VectorStoreSchemaConfig(
|
||||
index_name="text-embeddings",
|
||||
id_field="id",
|
||||
text_field="text_custom",
|
||||
vector_field="vector_custom",
|
||||
attributes_field="attributes_custom",
|
||||
vector_size=5,
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
vector_store.connect(
|
||||
connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
|
||||
database_name="test_db",
|
||||
)
|
||||
|
||||
docs = [
|
||||
VectorStoreDocument(
|
||||
id="doc1",
|
||||
text="This is document 1",
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
attributes={"title": "Doc 1", "category": "test"},
|
||||
),
|
||||
VectorStoreDocument(
|
||||
id="doc2",
|
||||
text="This is document 2",
|
||||
vector=[0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
attributes={"title": "Doc 2", "category": "test"},
|
||||
),
|
||||
]
|
||||
vector_store.load_documents(docs)
|
||||
|
||||
vector_store.filter_by_id(["doc1"])
|
||||
|
||||
doc = vector_store.search_by_id("doc1")
|
||||
assert doc.id == "doc1"
|
||||
assert doc.text == "This is document 1"
|
||||
assert doc.vector is not None
|
||||
assert np.allclose(doc.vector, [0.1, 0.2, 0.3, 0.4, 0.5])
|
||||
assert doc.attributes["title"] == "Doc 1"
|
||||
|
||||
# Define a simple text embedder function for testing
|
||||
def mock_embedder(text: str) -> list[float]:
|
||||
return [0.1, 0.2, 0.3, 0.4, 0.5] # Return fixed embedding
|
||||
|
||||
vector_results = vector_store.similarity_search_by_vector(
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
|
||||
)
|
||||
assert len(vector_results) > 0
|
||||
|
||||
text_results = vector_store.similarity_search_by_text(
|
||||
"test query", mock_embedder, k=2
|
||||
)
|
||||
assert len(text_results) > 0
|
||||
finally:
|
||||
vector_store.clear()
|
||||
|
||||
@ -8,6 +8,7 @@ These tests will test the VectorStoreFactory class and the creation of each vect
|
||||
import pytest
|
||||
|
||||
from graphrag.config.enums import VectorStoreType
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
|
||||
from graphrag.vector_stores.base import BaseVectorStore
|
||||
from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore
|
||||
@ -17,25 +18,31 @@ from graphrag.vector_stores.lancedb import LanceDBVectorStore
|
||||
|
||||
def test_create_lancedb_vector_store():
|
||||
kwargs = {
|
||||
"collection_name": "test_collection",
|
||||
"db_uri": "/tmp/lancedb",
|
||||
}
|
||||
vector_store = VectorStoreFactory.create_vector_store(
|
||||
VectorStoreType.LanceDB.value, kwargs
|
||||
vector_store_type=VectorStoreType.LanceDB.value,
|
||||
vector_store_schema_config=VectorStoreSchemaConfig(
|
||||
index_name="test_collection"
|
||||
),
|
||||
kwargs=kwargs,
|
||||
)
|
||||
assert isinstance(vector_store, LanceDBVectorStore)
|
||||
assert vector_store.collection_name == "test_collection"
|
||||
assert vector_store.index_name == "test_collection"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Azure AI Search requires credentials and setup")
|
||||
def test_create_azure_ai_search_vector_store():
|
||||
kwargs = {
|
||||
"collection_name": "test_collection",
|
||||
"url": "https://test.search.windows.net",
|
||||
"api_key": "test_key",
|
||||
}
|
||||
vector_store = VectorStoreFactory.create_vector_store(
|
||||
VectorStoreType.AzureAISearch.value, kwargs
|
||||
vector_store_type=VectorStoreType.AzureAISearch.value,
|
||||
vector_store_schema_config=VectorStoreSchemaConfig(
|
||||
index_name="test_collection"
|
||||
),
|
||||
kwargs=kwargs,
|
||||
)
|
||||
assert isinstance(vector_store, AzureAISearchVectorStore)
|
||||
|
||||
@ -43,13 +50,18 @@ def test_create_azure_ai_search_vector_store():
|
||||
@pytest.mark.skip(reason="CosmosDB requires credentials and setup")
|
||||
def test_create_cosmosdb_vector_store():
|
||||
kwargs = {
|
||||
"collection_name": "test_collection",
|
||||
"connection_string": "AccountEndpoint=https://test.documents.azure.com:443/;AccountKey=test_key==",
|
||||
"database_name": "test_db",
|
||||
}
|
||||
|
||||
vector_store = VectorStoreFactory.create_vector_store(
|
||||
VectorStoreType.CosmosDB.value, kwargs
|
||||
vector_store_type=VectorStoreType.CosmosDB.value,
|
||||
vector_store_schema_config=VectorStoreSchemaConfig(
|
||||
index_name="test_collection"
|
||||
),
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
assert isinstance(vector_store, CosmosDBVectorStore)
|
||||
|
||||
|
||||
@ -67,7 +79,12 @@ def test_register_and_create_custom_vector_store():
|
||||
VectorStoreFactory.register(
|
||||
"custom", lambda **kwargs: custom_vector_store_class(**kwargs)
|
||||
)
|
||||
vector_store = VectorStoreFactory.create_vector_store("custom", {})
|
||||
|
||||
vector_store = VectorStoreFactory.create_vector_store(
|
||||
vector_store_type="custom",
|
||||
vector_store_schema_config=VectorStoreSchemaConfig(),
|
||||
kwargs={},
|
||||
)
|
||||
|
||||
assert custom_vector_store_class.called
|
||||
assert vector_store is instance
|
||||
@ -89,7 +106,11 @@ def test_get_vector_store_types():
|
||||
|
||||
def test_create_unknown_vector_store():
|
||||
with pytest.raises(ValueError, match="Unknown vector store type: unknown"):
|
||||
VectorStoreFactory.create_vector_store("unknown", {})
|
||||
VectorStoreFactory.create_vector_store(
|
||||
vector_store_type="unknown",
|
||||
vector_store_schema_config=VectorStoreSchemaConfig(),
|
||||
kwargs={},
|
||||
)
|
||||
|
||||
|
||||
def test_is_supported_type():
|
||||
@ -139,6 +160,9 @@ def test_register_class_directly_works():
|
||||
|
||||
# Test creating an instance
|
||||
vector_store = VectorStoreFactory.create_vector_store(
|
||||
"custom_class", {"collection_name": "test"}
|
||||
vector_store_type="custom_class",
|
||||
vector_store_schema_config=VectorStoreSchemaConfig(),
|
||||
kwargs={"collection_name": "test"},
|
||||
)
|
||||
|
||||
assert isinstance(vector_store, CustomVectorStore)
|
||||
|
||||
@ -7,20 +7,20 @@ import shutil
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
from graphrag.vector_stores.base import VectorStoreDocument
|
||||
from graphrag.vector_stores.lancedb import LanceDBVectorStore
|
||||
|
||||
|
||||
def test_vector_store_operations():
|
||||
"""Test basic vector store operations with LanceDB."""
|
||||
# Create a temporary directory for the test database
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
vector_store = LanceDBVectorStore(collection_name="test_collection")
|
||||
vector_store.connect(db_uri=temp_dir)
|
||||
class TestLanceDBVectorStore:
|
||||
"""Test class for TestLanceDBVectorStore."""
|
||||
|
||||
docs = [
|
||||
@pytest.fixture
|
||||
def sample_documents(self):
|
||||
"""Create sample documents for testing."""
|
||||
return [
|
||||
VectorStoreDocument(
|
||||
id="1",
|
||||
text="This is document 1",
|
||||
@ -40,102 +40,11 @@ def test_vector_store_operations():
|
||||
attributes={"title": "Doc 3", "category": "test"},
|
||||
),
|
||||
]
|
||||
vector_store.load_documents(docs[:2])
|
||||
|
||||
assert vector_store.collection_name in vector_store.db_connection.table_names()
|
||||
|
||||
doc = vector_store.search_by_id("1")
|
||||
assert doc.id == "1"
|
||||
assert doc.text == "This is document 1"
|
||||
|
||||
assert doc.vector is not None
|
||||
assert np.allclose(doc.vector, [0.1, 0.2, 0.3, 0.4, 0.5])
|
||||
assert doc.attributes["title"] == "Doc 1"
|
||||
|
||||
filter_query = vector_store.filter_by_id(["1"])
|
||||
assert filter_query == "id in ('1')"
|
||||
|
||||
results = vector_store.similarity_search_by_vector(
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
|
||||
)
|
||||
assert 1 <= len(results) <= 2
|
||||
assert isinstance(results[0].score, float)
|
||||
|
||||
# Test append mode
|
||||
vector_store.load_documents([docs[2]], overwrite=False)
|
||||
result = vector_store.search_by_id("3")
|
||||
assert result.id == "3"
|
||||
assert result.text == "This is document 3"
|
||||
|
||||
# Define a simple text embedder function for testing
|
||||
def mock_embedder(text: str) -> list[float]:
|
||||
return [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
|
||||
text_results = vector_store.similarity_search_by_text(
|
||||
"test query", mock_embedder, k=2
|
||||
)
|
||||
assert 1 <= len(text_results) <= 2
|
||||
assert isinstance(text_results[0].score, float)
|
||||
|
||||
# Test non-existent document
|
||||
non_existent = vector_store.search_by_id("nonexistent")
|
||||
assert non_existent.id == "nonexistent"
|
||||
assert non_existent.text is None
|
||||
assert non_existent.vector is None
|
||||
finally:
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
def test_empty_collection():
|
||||
"""Test creating an empty collection."""
|
||||
# Create a temporary directory for the test database
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
vector_store = LanceDBVectorStore(collection_name="empty_collection")
|
||||
vector_store.connect(db_uri=temp_dir)
|
||||
|
||||
# Load the vector store with a document, then delete it
|
||||
sample_doc = VectorStoreDocument(
|
||||
id="tmp",
|
||||
text="Temporary document to create schema",
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
attributes={"title": "Tmp"},
|
||||
)
|
||||
vector_store.load_documents([sample_doc])
|
||||
vector_store.db_connection.open_table(vector_store.collection_name).delete(
|
||||
"id = 'tmp'"
|
||||
)
|
||||
|
||||
# Should still have the collection
|
||||
assert vector_store.collection_name in vector_store.db_connection.table_names()
|
||||
|
||||
# Add a document after creating an empty collection
|
||||
doc = VectorStoreDocument(
|
||||
id="1",
|
||||
text="This is document 1",
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
attributes={"title": "Doc 1"},
|
||||
)
|
||||
vector_store.load_documents([doc], overwrite=False)
|
||||
|
||||
result = vector_store.search_by_id("1")
|
||||
assert result.id == "1"
|
||||
assert result.text == "This is document 1"
|
||||
finally:
|
||||
# Clean up - remove the temporary directory
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
def test_filter_search():
|
||||
"""Test filtered search with LanceDB."""
|
||||
# Create a temporary directory for the test database
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
vector_store = LanceDBVectorStore(collection_name="filter_collection")
|
||||
vector_store.connect(db_uri=temp_dir)
|
||||
|
||||
# Create test documents with different categories
|
||||
docs = [
|
||||
@pytest.fixture
|
||||
def sample_documents_categories(self):
|
||||
"""Create sample documents with different categories for testing."""
|
||||
return [
|
||||
VectorStoreDocument(
|
||||
id="1",
|
||||
text="Document about cats",
|
||||
@ -155,18 +64,201 @@ def test_filter_search():
|
||||
attributes={"category": "vehicles"},
|
||||
),
|
||||
]
|
||||
vector_store.load_documents(docs)
|
||||
|
||||
# Filter to include only documents about animals
|
||||
vector_store.filter_by_id(["1", "2"])
|
||||
results = vector_store.similarity_search_by_vector(
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5], k=3
|
||||
)
|
||||
def test_vector_store_operations(self, sample_documents):
|
||||
"""Test basic vector store operations with LanceDB."""
|
||||
# Create a temporary directory for the test database
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
vector_store = LanceDBVectorStore(
|
||||
vector_store_schema_config=VectorStoreSchemaConfig(
|
||||
index_name="test_collection", vector_size=5
|
||||
)
|
||||
)
|
||||
vector_store.connect(db_uri=temp_dir)
|
||||
vector_store.load_documents(sample_documents[:2])
|
||||
|
||||
# Should return at most 2 documents (the filtered ones)
|
||||
assert len(results) <= 2
|
||||
ids = [result.document.id for result in results]
|
||||
assert "3" not in ids
|
||||
assert set(ids).issubset({"1", "2"})
|
||||
finally:
|
||||
shutil.rmtree(temp_dir)
|
||||
if vector_store.index_name:
|
||||
assert (
|
||||
vector_store.index_name in vector_store.db_connection.table_names()
|
||||
)
|
||||
|
||||
doc = vector_store.search_by_id("1")
|
||||
assert doc.id == "1"
|
||||
assert doc.text == "This is document 1"
|
||||
|
||||
assert doc.vector is not None
|
||||
assert np.allclose(doc.vector, [0.1, 0.2, 0.3, 0.4, 0.5])
|
||||
assert doc.attributes["title"] == "Doc 1"
|
||||
|
||||
filter_query = vector_store.filter_by_id(["1"])
|
||||
assert filter_query == "id in ('1')"
|
||||
|
||||
results = vector_store.similarity_search_by_vector(
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
|
||||
)
|
||||
assert 1 <= len(results) <= 2
|
||||
assert isinstance(results[0].score, float)
|
||||
|
||||
# Test append mode
|
||||
vector_store.load_documents([sample_documents[2]], overwrite=False)
|
||||
result = vector_store.search_by_id("3")
|
||||
assert result.id == "3"
|
||||
assert result.text == "This is document 3"
|
||||
|
||||
# Define a simple text embedder function for testing
|
||||
def mock_embedder(text: str) -> list[float]:
|
||||
return [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
|
||||
text_results = vector_store.similarity_search_by_text(
|
||||
"test query", mock_embedder, k=2
|
||||
)
|
||||
assert 1 <= len(text_results) <= 2
|
||||
assert isinstance(text_results[0].score, float)
|
||||
|
||||
# Test non-existent document
|
||||
non_existent = vector_store.search_by_id("nonexistent")
|
||||
assert non_existent.id == "nonexistent"
|
||||
assert non_existent.text is None
|
||||
assert non_existent.vector is None
|
||||
finally:
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
def test_empty_collection(self):
|
||||
"""Test creating an empty collection."""
|
||||
# Create a temporary directory for the test database
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
vector_store = LanceDBVectorStore(
|
||||
vector_store_schema_config=VectorStoreSchemaConfig(
|
||||
index_name="empty_collection", vector_size=5
|
||||
)
|
||||
)
|
||||
vector_store.connect(db_uri=temp_dir)
|
||||
|
||||
# Load the vector store with a document, then delete it
|
||||
sample_doc = VectorStoreDocument(
|
||||
id="tmp",
|
||||
text="Temporary document to create schema",
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
attributes={"title": "Tmp"},
|
||||
)
|
||||
vector_store.load_documents([sample_doc])
|
||||
vector_store.db_connection.open_table(
|
||||
vector_store.index_name if vector_store.index_name else ""
|
||||
).delete("id = 'tmp'")
|
||||
|
||||
# Should still have the collection
|
||||
if vector_store.index_name:
|
||||
assert (
|
||||
vector_store.index_name in vector_store.db_connection.table_names()
|
||||
)
|
||||
|
||||
# Add a document after creating an empty collection
|
||||
doc = VectorStoreDocument(
|
||||
id="1",
|
||||
text="This is document 1",
|
||||
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
attributes={"title": "Doc 1"},
|
||||
)
|
||||
vector_store.load_documents([doc], overwrite=False)
|
||||
|
||||
result = vector_store.search_by_id("1")
|
||||
assert result.id == "1"
|
||||
assert result.text == "This is document 1"
|
||||
finally:
|
||||
# Clean up - remove the temporary directory
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
def test_filter_search(self, sample_documents_categories):
|
||||
"""Test filtered search with LanceDB."""
|
||||
# Create a temporary directory for the test database
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
vector_store = LanceDBVectorStore(
|
||||
vector_store_schema_config=VectorStoreSchemaConfig(
|
||||
index_name="filter_collection", vector_size=5
|
||||
)
|
||||
)
|
||||
|
||||
vector_store.connect(db_uri=temp_dir)
|
||||
|
||||
vector_store.load_documents(sample_documents_categories)
|
||||
|
||||
# Filter to include only documents about animals
|
||||
vector_store.filter_by_id(["1", "2"])
|
||||
results = vector_store.similarity_search_by_vector(
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5], k=3
|
||||
)
|
||||
|
||||
# Should return at most 2 documents (the filtered ones)
|
||||
assert len(results) <= 2
|
||||
ids = [result.document.id for result in results]
|
||||
assert "3" not in ids
|
||||
assert set(ids).issubset({"1", "2"})
|
||||
finally:
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
def test_vector_store_customization(self, sample_documents):
|
||||
"""Test vector store customization with LanceDB."""
|
||||
# Create a temporary directory for the test database
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
vector_store = LanceDBVectorStore(
|
||||
vector_store_schema_config=VectorStoreSchemaConfig(
|
||||
index_name="text-embeddings",
|
||||
id_field="id_custom",
|
||||
text_field="text_custom",
|
||||
vector_field="vector_custom",
|
||||
attributes_field="attributes_custom",
|
||||
vector_size=5,
|
||||
),
|
||||
)
|
||||
vector_store.connect(db_uri=temp_dir)
|
||||
vector_store.load_documents(sample_documents[:2])
|
||||
|
||||
if vector_store.index_name:
|
||||
assert (
|
||||
vector_store.index_name in vector_store.db_connection.table_names()
|
||||
)
|
||||
|
||||
doc = vector_store.search_by_id("1")
|
||||
assert doc.id == "1"
|
||||
assert doc.text == "This is document 1"
|
||||
|
||||
assert doc.vector is not None
|
||||
assert np.allclose(doc.vector, [0.1, 0.2, 0.3, 0.4, 0.5])
|
||||
assert doc.attributes["title"] == "Doc 1"
|
||||
|
||||
filter_query = vector_store.filter_by_id(["1"])
|
||||
assert filter_query == f"{vector_store.id_field} in ('1')"
|
||||
|
||||
results = vector_store.similarity_search_by_vector(
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
|
||||
)
|
||||
assert 1 <= len(results) <= 2
|
||||
assert isinstance(results[0].score, float)
|
||||
|
||||
# Test append mode
|
||||
vector_store.load_documents([sample_documents[2]], overwrite=False)
|
||||
result = vector_store.search_by_id("3")
|
||||
assert result.id == "3"
|
||||
assert result.text == "This is document 3"
|
||||
|
||||
# Define a simple text embedder function for testing
|
||||
def mock_embedder(text: str) -> list[float]:
|
||||
return [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
|
||||
text_results = vector_store.similarity_search_by_text(
|
||||
"test query", mock_embedder, k=2
|
||||
)
|
||||
assert 1 <= len(text_results) <= 2
|
||||
assert isinstance(text_results[0].score, float)
|
||||
|
||||
# Test non-existent document
|
||||
non_existent = vector_store.search_by_id("nonexistent")
|
||||
assert non_existent.id == "nonexistent"
|
||||
assert non_existent.text is None
|
||||
assert non_existent.vector is None
|
||||
finally:
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
from graphrag.data_model.entity import Entity
|
||||
from graphrag.data_model.types import TextEmbedder
|
||||
from graphrag.language_model.manager import ModelManager
|
||||
@ -19,7 +20,9 @@ from graphrag.vector_stores.base import (
|
||||
|
||||
class MockBaseVectorStore(BaseVectorStore):
|
||||
def __init__(self, documents: list[VectorStoreDocument]) -> None:
|
||||
super().__init__("mock")
|
||||
super().__init__(
|
||||
vector_store_schema_config=VectorStoreSchemaConfig(index_name="mock")
|
||||
)
|
||||
self.documents = documents
|
||||
|
||||
def connect(self, **kwargs: Any) -> None:
|
||||
|
||||
@ -3,19 +3,19 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from graphrag.config.embeddings import create_collection_name
|
||||
from graphrag.config.embeddings import create_index_name
|
||||
|
||||
|
||||
def test_create_collection_name():
|
||||
collection = create_collection_name("default", "entity.title")
|
||||
def test_create_index_name():
|
||||
collection = create_index_name("default", "entity.title")
|
||||
assert collection == "default-entity-title"
|
||||
|
||||
|
||||
def test_create_collection_name_invalid_embedding_throws():
|
||||
def test_create_index_name_invalid_embedding_throws():
|
||||
with pytest.raises(KeyError):
|
||||
create_collection_name("default", "invalid.name")
|
||||
create_index_name("default", "invalid.name")
|
||||
|
||||
|
||||
def test_create_collection_name_invalid_embedding_does_not_throw():
|
||||
collection = create_collection_name("default", "invalid.name", validate=False)
|
||||
def test_create_index_name_invalid_embedding_does_not_throw():
|
||||
collection = create_index_name("default", "invalid.name", validate=False)
|
||||
assert collection == "default-invalid-name"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user