progress on vector customization

This commit is contained in:
Gaudy Blanco 2025-09-11 13:42:49 -06:00
parent 1cb20b66f5
commit bfaa7ef016
12 changed files with 190 additions and 87 deletions

View File

@ -393,6 +393,7 @@ class VectorStoreDefaults:
api_key: None = None
audience: None = None
database_name: None = None
schema: None = None
@dataclass

View File

@ -29,14 +29,14 @@ default_embeddings: list[str] = [
]
def create_collection_name(
def create_index_name(
container_name: str, embedding_name: str, validate: bool = True
) -> str:
"""
Create a collection name for the embedding store.
Create a index name for the embedding store.
Within any given vector store, we can have multiple sets of embeddings organized into projects.
The `container` param is used for this partitioning, and is added as a prefix to the collection name for differentiation.
The `container` param is used for this partitioning, and is added as a prefix to the index name for differentiation.
The embedding name is fixed, with the available list defined in graphrag.index.config.embeddings

View File

@ -6,7 +6,9 @@
from pydantic import BaseModel, Field, model_validator
from graphrag.config.defaults import vector_store_defaults
from graphrag.config.embeddings import all_embeddings
from graphrag.config.enums import VectorStoreType
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
class VectorStoreConfig(BaseModel):
@ -85,9 +87,19 @@ class VectorStoreConfig(BaseModel):
default=vector_store_defaults.overwrite,
)
embeddings_schema: dict[str, VectorStoreSchemaConfig] = {}
def _validate_embeddings_schema(self) -> None:
"""Validate the embeddings schema."""
for name in self.embeddings_schema:
if name not in all_embeddings:
msg = f"vector_store.embeddings_schema contains an invalid embedding schema name: {name}. Please rerun `graphrag init` and select the correct embedding schema names."
raise ValueError(msg)
@model_validator(mode="after")
def _validate_model(self):
"""Validate the model."""
self._validate_db_uri()
self._validate_url()
self._validate_embeddings_schema()
return self

View File

@ -0,0 +1,51 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Parameterization settings for the default configuration."""
from pydantic import BaseModel, Field, model_validator
DEFAULT_VECTOR_SIZE: int = 1536
class VectorStoreSchemaConfig(BaseModel):
"""The default configuration section for Vector Store Schema."""
index_name: str = Field(
description="The index name to use.",
default=""
)
id_field: str = Field(
description="The ID field to use.",
default="id",
)
vector_field: str = Field(
description="The vector field to use.",
default="vector",
)
text_field: str = Field(
description="The text field to use.",
default="text",
)
attributes_field: str = Field(
description="The attributes field to use.",
default="attributes",
)
vector_size: int = Field(
description="The vector size to use.",
default=DEFAULT_VECTOR_SIZE,
)
#TODO GAUDY
def _validate_schema(self) -> None:
"""Validate the schema."""
@model_validator(mode="after")
def _validate_model(self):
"""Validate the model."""
self._validate_schema()
return self

View File

@ -12,7 +12,8 @@ import pandas as pd
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.embeddings import create_collection_name
from graphrag.config.embeddings import create_index_name
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy
from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument
from graphrag.vector_stores.factory import VectorStoreFactory
@ -49,9 +50,9 @@ async def embed_text(
vector_store_config = strategy.get("vector_store")
if vector_store_config:
collection_name = _get_collection_name(vector_store_config, embedding_name)
index_name = _get_index_name(vector_store_config, embedding_name)
vector_store: BaseVectorStore = _create_vector_store(
vector_store_config, collection_name
vector_store_config, index_name, embedding_name
)
vector_store_workflow_config = vector_store_config.get(
embedding_name, vector_store_config
@ -183,27 +184,38 @@ async def _text_embed_with_vector_store(
def _create_vector_store(
vector_store_config: dict, collection_name: str
vector_store_config: dict, index_name: str, embedding_name: str | None = None
) -> BaseVectorStore:
vector_store_type: str = str(vector_store_config.get("type"))
if collection_name:
vector_store_config.update({"collection_name": collection_name})
embeddings_schema: dict[str, VectorStoreSchemaConfig] = vector_store_config.get("embeddings_schema", {})
single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
if embeddings_schema is not None and embedding_name is not None and embedding_name in embeddings_schema:
raw_config = embeddings_schema[embedding_name]
if isinstance(raw_config, dict):
single_embedding_config = VectorStoreSchemaConfig(**raw_config)
else:
single_embedding_config = raw_config
if single_embedding_config.index_name == "":
single_embedding_config.index_name = index_name
vector_store = VectorStoreFactory().create_vector_store(
vector_store_type, kwargs=vector_store_config
vector_store_schema_config=single_embedding_config, vector_store_type=vector_store_type, kwargs=vector_store_config
)
vector_store.connect(**vector_store_config)
return vector_store
def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str:
def _get_index_name(vector_store_config: dict, embedding_name: str) -> str:
container_name = vector_store_config.get("container_name", "default")
collection_name = create_collection_name(container_name, embedding_name)
index_name = create_index_name(container_name, embedding_name)
msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {collection_name}"
msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {index_name}"
logger.info(msg)
return collection_name
return index_name
def load_strategy(strategy: TextEmbedStrategyType) -> TextEmbeddingStrategy:

View File

@ -8,9 +8,10 @@ from typing import Any
from graphrag.cache.factory import CacheFactory
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.config.embeddings import create_collection_name
from graphrag.config.embeddings import create_index_name
from graphrag.config.models.cache_config import CacheConfig
from graphrag.config.models.storage_config import StorageConfig
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.data_model.types import TextEmbedder
from graphrag.storage.factory import StorageFactory
from graphrag.storage.pipeline_storage import PipelineStorage
@ -103,12 +104,27 @@ def get_embedding_store(
index_names = []
for index, store in config_args.items():
vector_store_type = store["type"]
collection_name = create_collection_name(
index_name = create_index_name(
store.get("container_name", "default"), embedding_name
)
embeddings_schema: dict[str, VectorStoreSchemaConfig] = store.get("embeddings_schema", {})
single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
if embeddings_schema is not None and embedding_name is not None and embedding_name in embeddings_schema:
raw_config = embeddings_schema[embedding_name]
if isinstance(raw_config, dict):
single_embedding_config = VectorStoreSchemaConfig(**raw_config)
else:
single_embedding_config = raw_config
if single_embedding_config.index_name == "":
single_embedding_config.index_name = index_name
embedding_store = VectorStoreFactory().create_vector_store(
vector_store_type=vector_store_type,
kwargs={**store, "collection_name": collection_name},
vector_store_schema_config=single_embedding_config,
kwargs={**store},
)
embedding_store.connect(**store)
# If there is only a single index, return the embedding store directly

View File

@ -24,9 +24,9 @@ from azure.search.documents.indexes.models import (
)
from azure.search.documents.models import VectorizedQuery
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.data_model.types import TextEmbedder
from graphrag.vector_stores.base import (
DEFAULT_VECTOR_SIZE,
BaseVectorStore,
VectorStoreDocument,
VectorStoreSearchResult,
@ -38,15 +38,14 @@ class AzureAISearchVectorStore(BaseVectorStore):
index_client: SearchIndexClient
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
def __init__(self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any) -> None:
super().__init__(vector_store_schema_config=vector_store_schema_config, **kwargs)
def connect(self, **kwargs: Any) -> Any:
"""Connect to AI search vector storage."""
url = kwargs["url"]
api_key = kwargs.get("api_key")
audience = kwargs.get("audience")
self.vector_size = kwargs.get("vector_size", DEFAULT_VECTOR_SIZE)
self.vector_search_profile_name = kwargs.get(
"vector_search_profile_name", "vectorSearchProfile"
@ -56,7 +55,7 @@ class AzureAISearchVectorStore(BaseVectorStore):
audience_arg = {"audience": audience} if audience and not api_key else {}
self.db_connection = SearchClient(
endpoint=url,
index_name=self.collection_name,
index_name=self.index_name,
credential=(
AzureKeyCredential(api_key) if api_key else DefaultAzureCredential()
),
@ -78,8 +77,8 @@ class AzureAISearchVectorStore(BaseVectorStore):
) -> None:
"""Load documents into an Azure AI Search index."""
if overwrite:
if self.collection_name in self.index_client.list_index_names():
self.index_client.delete_index(self.collection_name)
if self.index_name != "" and self.index_name in self.index_client.list_index_names():
self.index_client.delete_index(self.index_name)
# Configure vector search profile
vector_search = VectorSearch(
@ -100,25 +99,23 @@ class AzureAISearchVectorStore(BaseVectorStore):
)
# Configure the index
index = SearchIndex(
name=self.collection_name,
name=self.index_name,
fields=[
SimpleField(
name="id",
name=self.id_field,
type=SearchFieldDataType.String,
key=True,
),
SearchField(
name="vector",
name=self.vector_field,
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
searchable=True,
hidden=False, # DRIFT needs to return the vector for client-side similarity
vector_search_dimensions=self.vector_size,
vector_search_profile_name=self.vector_search_profile_name,
),
SearchableField(name="text", type=SearchFieldDataType.String),
SimpleField(
name="attributes",
type=SearchFieldDataType.String,
SearchableField(name=self.text_field, type=SearchFieldDataType.String),
SimpleField(name=self.attributes_field, type=SearchFieldDataType.String,
),
],
vector_search=vector_search,
@ -129,10 +126,10 @@ class AzureAISearchVectorStore(BaseVectorStore):
batch = [
{
"id": doc.id,
"vector": doc.vector,
"text": doc.text,
"attributes": json.dumps(doc.attributes),
self.id_field: doc.id,
self.vector_field: doc.vector,
self.text_field: doc.text,
self.attributes_field: json.dumps(doc.attributes),
}
for doc in documents
if doc.vector is not None
@ -162,7 +159,7 @@ class AzureAISearchVectorStore(BaseVectorStore):
) -> list[VectorStoreSearchResult]:
"""Perform a vector-based similarity search."""
vectorized_query = VectorizedQuery(
vector=query_embedding, k_nearest_neighbors=k, fields="vector"
vector=query_embedding, k_nearest_neighbors=k, fields=self.vector_field
)
response = self.db_connection.search(
@ -172,10 +169,10 @@ class AzureAISearchVectorStore(BaseVectorStore):
return [
VectorStoreSearchResult(
document=VectorStoreDocument(
id=doc.get("id", ""),
text=doc.get("text", ""),
vector=doc.get("vector", []),
attributes=(json.loads(doc.get("attributes", "{}"))),
id=doc.get(self.id_field, ""),
text=doc.get(self.text_field, ""),
vector=doc.get(self.vector_field, []),
attributes=(json.loads(doc.get(self.attributes_field, "{}"))),
),
# Cosine similarity between 0.333 and 1.000
# https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking#scores-in-a-hybrid-search-results
@ -199,8 +196,8 @@ class AzureAISearchVectorStore(BaseVectorStore):
"""Search for a document by id."""
response = self.db_connection.get_document(id)
return VectorStoreDocument(
id=response.get("id", ""),
text=response.get("text", ""),
vector=response.get("vector", []),
attributes=(json.loads(response.get("attributes", "{}"))),
id=response.get(self.id_field, ""),
text=response.get(self.text_field, ""),
vector=response.get(self.vector_field, []),
attributes=(json.loads(response.get(self.attributes_field, "{}"))),
)

View File

@ -7,10 +7,9 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.data_model.types import TextEmbedder
DEFAULT_VECTOR_SIZE: int = 1536
@dataclass
class VectorStoreDocument:
@ -42,17 +41,24 @@ class BaseVectorStore(ABC):
def __init__(
self,
collection_name: str,
vector_store_schema_config: VectorStoreSchemaConfig,
db_connection: Any | None = None,
document_collection: Any | None = None,
query_filter: Any | None = None,
**kwargs: Any,
):
self.collection_name = collection_name
self.db_connection = db_connection
self.document_collection = document_collection
self.query_filter = query_filter
self.kwargs = kwargs
self.index_name = vector_store_schema_config.index_name
self.id_field = vector_store_schema_config.id_field
self.text_field = vector_store_schema_config.text_field
self.vector_field = vector_store_schema_config.vector_field
self.attributes_field = vector_store_schema_config.attributes_field
self.vector_size = vector_store_schema_config.vector_size
@abstractmethod
def connect(self, **kwargs: Any) -> None:

View File

@ -13,7 +13,6 @@ from azure.identity import DefaultAzureCredential
from graphrag.data_model.types import TextEmbedder
from graphrag.vector_stores.base import (
DEFAULT_VECTOR_SIZE,
BaseVectorStore,
VectorStoreDocument,
VectorStoreSearchResult,
@ -49,13 +48,13 @@ class CosmosDBVectorStore(BaseVectorStore):
msg = "Database name must be provided."
raise ValueError(msg)
self._database_name = database_name
collection_name = self.collection_name
collection_name = self.index_name
if collection_name is None:
msg = "Collection name is empty or not provided."
raise ValueError(msg)
self._container_name = collection_name
self.vector_size = kwargs.get("vector_size", DEFAULT_VECTOR_SIZE)
self.vector_size = kwargs.get("vector_size", 1024) #TODO GAUDY fix it
self._create_database()
self._create_container()

View File

@ -15,6 +15,9 @@ from graphrag.vector_stores.lancedb import LanceDBVectorStore
if TYPE_CHECKING:
from collections.abc import Callable
from graphrag.config.models.vector_store_schema_config import (
VectorStoreSchemaConfig,
)
from graphrag.vector_stores.base import BaseVectorStore
@ -47,7 +50,7 @@ class VectorStoreFactory:
@classmethod
def create_vector_store(
cls, vector_store_type: str, kwargs: dict
cls, vector_store_type: str, vector_store_schema_config: VectorStoreSchemaConfig, kwargs: dict
) -> BaseVectorStore:
"""Create a vector store object from the provided type.
@ -67,8 +70,11 @@ class VectorStoreFactory:
msg = f"Unknown vector store type: {vector_store_type}"
raise ValueError(msg)
return cls._registry[vector_store_type](**kwargs)
return cls._registry[vector_store_type](
vector_store_schema_config=vector_store_schema_config,
**kwargs
)
@classmethod
def get_vector_store_types(cls) -> list[str]:
"""Get the registered vector store implementations."""

View File

@ -8,6 +8,7 @@ from typing import Any
import pyarrow as pa
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.data_model.types import TextEmbedder
from graphrag.vector_stores.base import (
@ -21,18 +22,19 @@ import lancedb
class LanceDBVectorStore(BaseVectorStore):
"""LanceDB vector storage implementation."""
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
def __init__(self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any) -> None:
super().__init__(vector_store_schema_config=vector_store_schema_config, **kwargs)
def connect(self, **kwargs: Any) -> Any:
"""Connect to the vector storage."""
self.db_connection = lancedb.connect(kwargs["db_uri"])
if (
self.collection_name
and self.collection_name in self.db_connection.table_names()
self.index_name
and self.index_name in self.db_connection.table_names()
):
self.document_collection = self.db_connection.open_table(
self.collection_name
self.index_name
)
def load_documents(
@ -41,10 +43,10 @@ class LanceDBVectorStore(BaseVectorStore):
"""Load documents into vector storage."""
data = [
{
"id": document.id,
"text": document.text,
"vector": document.vector,
"attributes": json.dumps(document.attributes),
self.id_field: document.id,
self.text_field: document.text,
self.vector_field: document.vector,
self.attributes_field: json.dumps(document.attributes),
}
for document in documents
if document.vector is not None
@ -54,10 +56,10 @@ class LanceDBVectorStore(BaseVectorStore):
data = None
schema = pa.schema([
pa.field("id", pa.string()),
pa.field("text", pa.string()),
pa.field("vector", pa.list_(pa.float64())),
pa.field("attributes", pa.string()),
pa.field(self.id_field, pa.string()),
pa.field(self.text_field, pa.string()),
pa.field(self.vector_field, pa.list_(pa.float64())),
pa.field(self.attributes_field, pa.string()),
])
# NOTE: If modifying the next section of code, ensure that the schema remains the same.
# The pyarrow format of the 'vector' field may change if the order of operations is changed
@ -65,19 +67,20 @@ class LanceDBVectorStore(BaseVectorStore):
if overwrite:
if data:
self.document_collection = self.db_connection.create_table(
self.collection_name, data=data, mode="overwrite"
self.index_name, data=data, mode="overwrite"
)
else:
self.document_collection = self.db_connection.create_table(
self.collection_name, schema=schema, mode="overwrite"
self.index_name, schema=schema, mode="overwrite"
)
else:
# add data to existing table
self.document_collection = self.db_connection.open_table(
self.collection_name
self.index_name
)
if data:
self.document_collection.add(data)
self.document_collection.create_index(vector_column_name=self.vector_field)
def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
"""Build a query filter to filter documents by id."""
@ -100,7 +103,7 @@ class LanceDBVectorStore(BaseVectorStore):
if self.query_filter:
docs = (
self.document_collection.search(
query=query_embedding, vector_column_name="vector"
query=query_embedding, vector_column_name=self.vector_field
)
.where(self.query_filter, prefilter=True)
.limit(k)
@ -109,7 +112,7 @@ class LanceDBVectorStore(BaseVectorStore):
else:
docs = (
self.document_collection.search(
query=query_embedding, vector_column_name="vector"
query=query_embedding, vector_column_name=self.vector_field
)
.limit(k)
.to_list()
@ -117,10 +120,10 @@ class LanceDBVectorStore(BaseVectorStore):
return [
VectorStoreSearchResult(
document=VectorStoreDocument(
id=doc["id"],
text=doc["text"],
vector=doc["vector"],
attributes=json.loads(doc["attributes"]),
id=doc[self.id_field],
text=doc[self.text_field],
vector=doc[self.vector_field],
attributes=json.loads(doc[self.attributes_field]),
),
score=1 - abs(float(doc["_distance"])),
)
@ -145,9 +148,9 @@ class LanceDBVectorStore(BaseVectorStore):
)
if doc:
return VectorStoreDocument(
id=doc[0]["id"],
text=doc[0]["text"],
vector=doc[0]["vector"],
attributes=json.loads(doc[0]["attributes"]),
id=doc[0][self.id_field],
text=doc[0][self.text_field],
vector=doc[0][self.vector_field],
attributes=json.loads(doc[0][self.attributes_field]),
)
return VectorStoreDocument(id=id, text=None, vector=None)

View File

@ -3,19 +3,19 @@
import pytest
from graphrag.config.embeddings import create_collection_name
from graphrag.config.embeddings import create_index_name
def test_create_collection_name():
collection = create_collection_name("default", "entity.title")
def test_create_index_name():
collection = create_index_name("default", "entity.title")
assert collection == "default-entity-title"
def test_create_collection_name_invalid_embedding_throws():
def test_create_index_name_invalid_embedding_throws():
with pytest.raises(KeyError):
create_collection_name("default", "invalid.name")
create_index_name("default", "invalid.name")
def test_create_collection_name_invalid_embedding_does_not_throw():
collection = create_collection_name("default", "invalid.name", validate=False)
def test_create_index_name_invalid_embedding_does_not_throw():
collection = create_index_name("default", "invalid.name", validate=False)
assert collection == "default-invalid-name"