diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py
index dc84fc45..2e108c27 100644
--- a/graphrag/config/defaults.py
+++ b/graphrag/config/defaults.py
@@ -393,6 +393,7 @@ class VectorStoreDefaults:
     api_key: None = None
     audience: None = None
     database_name: None = None
+    schema: None = None


 @dataclass
diff --git a/graphrag/config/embeddings.py b/graphrag/config/embeddings.py
index 4865da55..f1502385 100644
--- a/graphrag/config/embeddings.py
+++ b/graphrag/config/embeddings.py
@@ -29,14 +29,14 @@ default_embeddings: list[str] = [
 ]


-def create_collection_name(
+def create_index_name(
     container_name: str, embedding_name: str, validate: bool = True
 ) -> str:
     """
-    Create a collection name for the embedding store.
+    Create an index name for the embedding store.

     Within any given vector store, we can have multiple sets of embeddings organized into projects.
-    The `container` param is used for this partitioning, and is added as a prefix to the collection name for differentiation.
+    The `container` param is used for this partitioning, and is added as a prefix to the index name for differentiation.

     The embedding name is fixed, with the available list defined in graphrag.index.config.embeddings
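Aside (illustration, not part of the patch): the rename does not change how names are derived; the expected values below are taken from the updated unit tests at the end of this diff.

    from graphrag.config.embeddings import create_index_name

    # The container name is prefixed and dots become hyphens.
    assert create_index_name("default", "entity.title") == "default-entity-title"

    # Unknown embedding names only pass when validation is disabled.
    assert create_index_name("default", "invalid.name", validate=False) == "default-invalid-name"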
diff --git a/graphrag/config/models/vector_store_config.py b/graphrag/config/models/vector_store_config.py
index d327baad..96f0bbf3 100644
--- a/graphrag/config/models/vector_store_config.py
+++ b/graphrag/config/models/vector_store_config.py
@@ -6,7 +6,9 @@
 from pydantic import BaseModel, Field, model_validator

 from graphrag.config.defaults import vector_store_defaults
+from graphrag.config.embeddings import all_embeddings
 from graphrag.config.enums import VectorStoreType
+from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig


 class VectorStoreConfig(BaseModel):
@@ -85,9 +87,19 @@ class VectorStoreConfig(BaseModel):
         default=vector_store_defaults.overwrite,
     )

+    embeddings_schema: dict[str, VectorStoreSchemaConfig] = {}
+
+    def _validate_embeddings_schema(self) -> None:
+        """Validate the embeddings schema."""
+        for name in self.embeddings_schema:
+            if name not in all_embeddings:
+                msg = f"vector_store.embeddings_schema contains an unknown embedding name: {name}. Valid names are listed in graphrag.config.embeddings.all_embeddings."
+                raise ValueError(msg)
+
     @model_validator(mode="after")
     def _validate_model(self):
         """Validate the model."""
         self._validate_db_uri()
         self._validate_url()
+        self._validate_embeddings_schema()
         return self
diff --git a/graphrag/config/models/vector_store_schema_config.py b/graphrag/config/models/vector_store_schema_config.py
new file mode 100644
index 00000000..7a95e5ef
--- /dev/null
+++ b/graphrag/config/models/vector_store_schema_config.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field, model_validator
+
+DEFAULT_VECTOR_SIZE: int = 1536
+
+
+class VectorStoreSchemaConfig(BaseModel):
+    """The default configuration section for Vector Store Schema."""
+
+    index_name: str = Field(
+        description="The index name to use.",
+        default="",
+    )
+
+    id_field: str = Field(
+        description="The ID field to use.",
+        default="id",
+    )
+
+    vector_field: str = Field(
+        description="The vector field to use.",
+        default="vector",
+    )
+
+    text_field: str = Field(
+        description="The text field to use.",
+        default="text",
+    )
+
+    attributes_field: str = Field(
+        description="The attributes field to use.",
+        default="attributes",
+    )
+
+    vector_size: int = Field(
+        description="The vector size to use.",
+        default=DEFAULT_VECTOR_SIZE,
+    )
+
+    def _validate_schema(self) -> None:
+        """Validate the schema."""
+        if self.vector_size <= 0:
+            msg = "vector_store schema vector_size must be a positive integer."
+            raise ValueError(msg)
+
+    @model_validator(mode="after")
+    def _validate_model(self):
+        """Validate the model."""
+        self._validate_schema()
+        return self
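Aside (illustration, not part of the patch): a minimal sketch of how the new model is expected to behave, based only on the defaults declared above.

    from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig

    # With no arguments, every field falls back to its declared default.
    schema = VectorStoreSchemaConfig()
    assert schema.id_field == "id" and schema.vector_field == "vector"
    assert schema.vector_size == 1536  # DEFAULT_VECTOR_SIZE

    # A raw dict from configuration validates into the model; unset fields keep their defaults.
    custom = VectorStoreSchemaConfig(**{"index_name": "my-index", "vector_size": 3072})
    assert custom.text_field == "text"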
diff --git a/graphrag/index/operations/embed_text/embed_text.py b/graphrag/index/operations/embed_text/embed_text.py
index 96e6111e..21154a00 100644
--- a/graphrag/index/operations/embed_text/embed_text.py
+++ b/graphrag/index/operations/embed_text/embed_text.py
@@ -12,7 +12,8 @@
 import pandas as pd

 from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
-from graphrag.config.embeddings import create_collection_name
+from graphrag.config.embeddings import create_index_name
+from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
 from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy
 from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument
 from graphrag.vector_stores.factory import VectorStoreFactory
@@ -49,9 +50,9 @@ async def embed_text(
     vector_store_config = strategy.get("vector_store")

     if vector_store_config:
-        collection_name = _get_collection_name(vector_store_config, embedding_name)
+        index_name = _get_index_name(vector_store_config, embedding_name)
         vector_store: BaseVectorStore = _create_vector_store(
-            vector_store_config, collection_name
+            vector_store_config, index_name, embedding_name
         )
         vector_store_workflow_config = vector_store_config.get(
             embedding_name, vector_store_config
@@ -183,27 +184,38 @@ async def _text_embed_with_vector_store(


 def _create_vector_store(
-    vector_store_config: dict, collection_name: str
+    vector_store_config: dict, index_name: str, embedding_name: str | None = None
 ) -> BaseVectorStore:
     vector_store_type: str = str(vector_store_config.get("type"))
-    if collection_name:
-        vector_store_config.update({"collection_name": collection_name})
+
+    embeddings_schema: dict[str, VectorStoreSchemaConfig] = vector_store_config.get("embeddings_schema", {})
+    single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
+
+    if embeddings_schema is not None and embedding_name is not None and embedding_name in embeddings_schema:
+        raw_config = embeddings_schema[embedding_name]
+        if isinstance(raw_config, dict):
+            single_embedding_config = VectorStoreSchemaConfig(**raw_config)
+        else:
+            single_embedding_config = raw_config
+
+    if single_embedding_config.index_name == "":
+        single_embedding_config.index_name = index_name

     vector_store = VectorStoreFactory().create_vector_store(
-        vector_store_type, kwargs=vector_store_config
+        vector_store_schema_config=single_embedding_config,
+        vector_store_type=vector_store_type,
+        kwargs=vector_store_config,
     )

     vector_store.connect(**vector_store_config)
     return vector_store


-def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str:
+def _get_index_name(vector_store_config: dict, embedding_name: str) -> str:
     container_name = vector_store_config.get("container_name", "default")
-    collection_name = create_collection_name(container_name, embedding_name)
+    index_name = create_index_name(container_name, embedding_name)

-    msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {collection_name}"
+    msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {index_name}"
     logger.info(msg)
-    return collection_name
+    return index_name


 def load_strategy(strategy: TextEmbedStrategyType) -> TextEmbeddingStrategy:
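Aside (illustration, not part of the patch): the resolution order implemented in _create_vector_store above, restated as a standalone sketch — the per-embedding entry in embeddings_schema wins when present, and an empty index_name falls back to the container-derived name.

    from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig

    def resolve_schema(vector_store_config: dict, embedding_name: str | None, fallback: str) -> VectorStoreSchemaConfig:
        # Mirrors the lookup performed in _create_vector_store, for illustration only.
        raw = (vector_store_config.get("embeddings_schema") or {}).get(embedding_name)
        schema = VectorStoreSchemaConfig(**raw) if isinstance(raw, dict) else (raw or VectorStoreSchemaConfig())
        if schema.index_name == "":
            schema.index_name = fallback  # e.g. "default-entity-title"
        return schema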
diff --git a/graphrag/utils/api.py b/graphrag/utils/api.py
index 96ca15a6..f333543c 100644
--- a/graphrag/utils/api.py
+++ b/graphrag/utils/api.py
@@ -8,9 +8,10 @@
 from typing import Any

 from graphrag.cache.factory import CacheFactory
 from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.config.embeddings import create_collection_name
+from graphrag.config.embeddings import create_index_name
 from graphrag.config.models.cache_config import CacheConfig
 from graphrag.config.models.storage_config import StorageConfig
+from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
 from graphrag.data_model.types import TextEmbedder
 from graphrag.storage.factory import StorageFactory
 from graphrag.storage.pipeline_storage import PipelineStorage
@@ -103,12 +104,27 @@ def get_embedding_store(
     index_names = []
     for index, store in config_args.items():
         vector_store_type = store["type"]
-        collection_name = create_collection_name(
+        index_name = create_index_name(
             store.get("container_name", "default"), embedding_name
         )
+
+        embeddings_schema: dict[str, VectorStoreSchemaConfig] = store.get("embeddings_schema", {})
+        single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
+
+        if embeddings_schema is not None and embedding_name is not None and embedding_name in embeddings_schema:
+            raw_config = embeddings_schema[embedding_name]
+            if isinstance(raw_config, dict):
+                single_embedding_config = VectorStoreSchemaConfig(**raw_config)
+            else:
+                single_embedding_config = raw_config
+
+        if single_embedding_config.index_name == "":
+            single_embedding_config.index_name = index_name
+
         embedding_store = VectorStoreFactory().create_vector_store(
             vector_store_type=vector_store_type,
-            kwargs={**store, "collection_name": collection_name},
+            vector_store_schema_config=single_embedding_config,
+            kwargs={**store},
         )
         embedding_store.connect(**store)
         # If there is only a single index, return the embedding store directly
diff --git a/graphrag/vector_stores/azure_ai_search.py b/graphrag/vector_stores/azure_ai_search.py
index 89c7f9d4..fc138a96 100644
--- a/graphrag/vector_stores/azure_ai_search.py
+++ b/graphrag/vector_stores/azure_ai_search.py
@@ -24,9 +24,9 @@
 from azure.search.documents.indexes.models import (
 )
 from azure.search.documents.models import VectorizedQuery

+from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
 from graphrag.data_model.types import TextEmbedder
 from graphrag.vector_stores.base import (
-    DEFAULT_VECTOR_SIZE,
     BaseVectorStore,
     VectorStoreDocument,
     VectorStoreSearchResult,
@@ -38,15 +38,14 @@ class AzureAISearchVectorStore(BaseVectorStore):

     index_client: SearchIndexClient

-    def __init__(self, **kwargs: Any) -> None:
-        super().__init__(**kwargs)
+    def __init__(self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any) -> None:
+        super().__init__(vector_store_schema_config=vector_store_schema_config, **kwargs)

     def connect(self, **kwargs: Any) -> Any:
         """Connect to AI search vector storage."""
         url = kwargs["url"]
         api_key = kwargs.get("api_key")
         audience = kwargs.get("audience")
-        self.vector_size = kwargs.get("vector_size", DEFAULT_VECTOR_SIZE)

         self.vector_search_profile_name = kwargs.get(
             "vector_search_profile_name", "vectorSearchProfile"
@@ -56,7 +55,7 @@ class AzureAISearchVectorStore(BaseVectorStore):
             audience_arg = {"audience": audience} if audience and not api_key else {}
             self.db_connection = SearchClient(
                 endpoint=url,
-                index_name=self.collection_name,
+                index_name=self.index_name,
                 credential=(
                     AzureKeyCredential(api_key) if api_key else DefaultAzureCredential()
                 ),
@@ -78,8 +77,8 @@ class AzureAISearchVectorStore(BaseVectorStore):
     ) -> None:
         """Load documents into an Azure AI Search index."""
         if overwrite:
-            if self.collection_name in self.index_client.list_index_names():
-                self.index_client.delete_index(self.collection_name)
+            if self.index_name != "" and self.index_name in self.index_client.list_index_names():
+                self.index_client.delete_index(self.index_name)

             # Configure vector search profile
             vector_search = VectorSearch(
@@ -100,25 +99,23 @@
             )
             # Configure the index
             index = SearchIndex(
-                name=self.collection_name,
+                name=self.index_name,
                 fields=[
                     SimpleField(
-                        name="id",
+                        name=self.id_field,
                         type=SearchFieldDataType.String,
                         key=True,
                     ),
                     SearchField(
-                        name="vector",
+                        name=self.vector_field,
                         type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
                         searchable=True,
                         hidden=False,  # DRIFT needs to return the vector for client-side similarity
                         vector_search_dimensions=self.vector_size,
                         vector_search_profile_name=self.vector_search_profile_name,
                     ),
-                    SearchableField(name="text", type=SearchFieldDataType.String),
-                    SimpleField(
-                        name="attributes",
-                        type=SearchFieldDataType.String,
+                    SearchableField(name=self.text_field, type=SearchFieldDataType.String),
+                    SimpleField(
+                        name=self.attributes_field,
+                        type=SearchFieldDataType.String,
                     ),
                 ],
                 vector_search=vector_search,
@@ -129,10 +126,10 @@

         batch = [
             {
-                "id": doc.id,
-                "vector": doc.vector,
-                "text": doc.text,
-                "attributes": json.dumps(doc.attributes),
+                self.id_field: doc.id,
+                self.vector_field: doc.vector,
+                self.text_field: doc.text,
+                self.attributes_field: json.dumps(doc.attributes),
             }
             for doc in documents
             if doc.vector is not None
@@ -162,7 +159,7 @@
     ) -> list[VectorStoreSearchResult]:
         """Perform a vector-based similarity search."""
         vectorized_query = VectorizedQuery(
-            vector=query_embedding, k_nearest_neighbors=k, fields="vector"
+            vector=query_embedding, k_nearest_neighbors=k, fields=self.vector_field
         )

         response = self.db_connection.search(
@@ -172,10 +169,10 @@
         return [
             VectorStoreSearchResult(
                 document=VectorStoreDocument(
-                    id=doc.get("id", ""),
-                    text=doc.get("text", ""),
-                    vector=doc.get("vector", []),
-                    attributes=(json.loads(doc.get("attributes", "{}"))),
+                    id=doc.get(self.id_field, ""),
+                    text=doc.get(self.text_field, ""),
+                    vector=doc.get(self.vector_field, []),
+                    attributes=(json.loads(doc.get(self.attributes_field, "{}"))),
                 ),
                 # Cosine similarity between 0.333 and 1.000
                 # https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking#scores-in-a-hybrid-search-results
@@ -199,8 +196,8 @@
         """Search for a document by id."""
         response = self.db_connection.get_document(id)
         return VectorStoreDocument(
-            id=response.get("id", ""),
-            text=response.get("text", ""),
-            vector=response.get("vector", []),
-            attributes=(json.loads(response.get("attributes", "{}"))),
+            id=response.get(self.id_field, ""),
+            text=response.get(self.text_field, ""),
+            vector=response.get(self.vector_field, []),
+            attributes=(json.loads(response.get(self.attributes_field, "{}"))),
         )
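Aside (illustration, not part of the patch): with the field names now sourced from the schema config, the document shape uploaded to Azure AI Search follows whatever names are configured. A sketch of the mapping built in load_documents under a custom schema:

    import json

    from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig

    schema = VectorStoreSchemaConfig(
        index_name="default-entity-description",
        id_field="doc_id",
        text_field="content",
        vector_field="embedding",
        attributes_field="metadata",
    )

    # One document in the upload batch, keyed by the configured names.
    doc = {
        schema.id_field: "1",
        schema.vector_field: [0.1, 0.2, 0.3],
        schema.text_field: "some text",
        schema.attributes_field: json.dumps({"title": "example"}),
    }
    assert set(doc) == {"doc_id", "embedding", "content", "metadata"}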
diff --git a/graphrag/vector_stores/base.py b/graphrag/vector_stores/base.py
index c4b5e40c..331b727f 100644
--- a/graphrag/vector_stores/base.py
+++ b/graphrag/vector_stores/base.py
@@ -7,10 +7,9 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from typing import Any

+from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
 from graphrag.data_model.types import TextEmbedder

-DEFAULT_VECTOR_SIZE: int = 1536
-

 @dataclass
 class VectorStoreDocument:
@@ -42,17 +41,24 @@ class BaseVectorStore(ABC):

     def __init__(
         self,
-        collection_name: str,
+        vector_store_schema_config: VectorStoreSchemaConfig,
         db_connection: Any | None = None,
         document_collection: Any | None = None,
         query_filter: Any | None = None,
         **kwargs: Any,
     ):
-        self.collection_name = collection_name
         self.db_connection = db_connection
         self.document_collection = document_collection
         self.query_filter = query_filter
         self.kwargs = kwargs
+
+        self.index_name = vector_store_schema_config.index_name
+        self.id_field = vector_store_schema_config.id_field
+        self.text_field = vector_store_schema_config.text_field
+        self.vector_field = vector_store_schema_config.vector_field
+        self.attributes_field = vector_store_schema_config.attributes_field
+        self.vector_size = vector_store_schema_config.vector_size

     @abstractmethod
     def connect(self, **kwargs: Any) -> None:
diff --git a/graphrag/vector_stores/cosmosdb.py b/graphrag/vector_stores/cosmosdb.py
index 9c736076..1a9f95ef 100644
--- a/graphrag/vector_stores/cosmosdb.py
+++ b/graphrag/vector_stores/cosmosdb.py
@@ -13,7 +13,6 @@
 from azure.identity import DefaultAzureCredential

 from graphrag.data_model.types import TextEmbedder
 from graphrag.vector_stores.base import (
-    DEFAULT_VECTOR_SIZE,
     BaseVectorStore,
     VectorStoreDocument,
     VectorStoreSearchResult,
@@ -49,13 +48,13 @@ class CosmosDBVectorStore(BaseVectorStore):
             msg = "Database name must be provided."
             raise ValueError(msg)
         self._database_name = database_name

-        collection_name = self.collection_name
+        collection_name = self.index_name
         if collection_name is None:
             msg = "Collection name is empty or not provided."
             raise ValueError(msg)
         self._container_name = collection_name

-        self.vector_size = kwargs.get("vector_size", DEFAULT_VECTOR_SIZE)
+        # Fall back to the schema-configured size set in BaseVectorStore.
+        self.vector_size = kwargs.get("vector_size", self.vector_size)

         self._create_database()
         self._create_container()
diff --git a/graphrag/vector_stores/factory.py b/graphrag/vector_stores/factory.py
index 5c81ce9c..6dc76092 100644
--- a/graphrag/vector_stores/factory.py
+++ b/graphrag/vector_stores/factory.py
@@ -15,6 +15,9 @@
 from graphrag.vector_stores.lancedb import LanceDBVectorStore

 if TYPE_CHECKING:
     from collections.abc import Callable

+    from graphrag.config.models.vector_store_schema_config import (
+        VectorStoreSchemaConfig,
+    )
     from graphrag.vector_stores.base import BaseVectorStore


@@ -47,7 +50,7 @@

     @classmethod
     def create_vector_store(
-        cls, vector_store_type: str, kwargs: dict
+        cls, vector_store_type: str, vector_store_schema_config: VectorStoreSchemaConfig, kwargs: dict
     ) -> BaseVectorStore:
         """Create a vector store object from the provided type.

@@ -67,8 +70,11 @@
             msg = f"Unknown vector store type: {vector_store_type}"
             raise ValueError(msg)

-        return cls._registry[vector_store_type](**kwargs)
-
+        return cls._registry[vector_store_type](
+            vector_store_schema_config=vector_store_schema_config,
+            **kwargs,
+        )
+
     @classmethod
     def get_vector_store_types(cls) -> list[str]:
         """Get the registered vector store implementations."""
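Aside (illustration, not part of the patch): a usage sketch of the new factory signature. The type string and connect kwargs are assumptions based on the LanceDB implementation in this diff (connect reads kwargs["db_uri"]); the path is hypothetical.

    from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
    from graphrag.vector_stores.factory import VectorStoreFactory

    schema = VectorStoreSchemaConfig(index_name="default-entity-title")
    store_kwargs = {"db_uri": "./output/lancedb"}  # hypothetical location

    store = VectorStoreFactory().create_vector_store(
        vector_store_type="lancedb",  # assumed registered type string
        vector_store_schema_config=schema,
        kwargs=store_kwargs,
    )
    store.connect(**store_kwargs)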
diff --git a/graphrag/vector_stores/lancedb.py b/graphrag/vector_stores/lancedb.py
index 07ab3490..0da2df50 100644
--- a/graphrag/vector_stores/lancedb.py
+++ b/graphrag/vector_stores/lancedb.py
@@ -8,6 +8,7 @@
 from typing import Any

 import pyarrow as pa

+from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
 from graphrag.data_model.types import TextEmbedder
 from graphrag.vector_stores.base import (
@@ -21,18 +22,19 @@
 import lancedb


 class LanceDBVectorStore(BaseVectorStore):
     """LanceDB vector storage implementation."""

-    def __init__(self, **kwargs: Any) -> None:
-        super().__init__(**kwargs)
+    def __init__(self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any) -> None:
+        super().__init__(vector_store_schema_config=vector_store_schema_config, **kwargs)

     def connect(self, **kwargs: Any) -> Any:
         """Connect to the vector storage."""
         self.db_connection = lancedb.connect(kwargs["db_uri"])
+
         if (
-            self.collection_name
-            and self.collection_name in self.db_connection.table_names()
+            self.index_name
+            and self.index_name in self.db_connection.table_names()
         ):
             self.document_collection = self.db_connection.open_table(
-                self.collection_name
+                self.index_name
             )

     def load_documents(
@@ -41,10 +43,10 @@
         """Load documents into vector storage."""
         data = [
             {
-                "id": document.id,
-                "text": document.text,
-                "vector": document.vector,
-                "attributes": json.dumps(document.attributes),
+                self.id_field: document.id,
+                self.text_field: document.text,
+                self.vector_field: document.vector,
+                self.attributes_field: json.dumps(document.attributes),
             }
             for document in documents
             if document.vector is not None
@@ -54,10 +56,10 @@
             data = None

         schema = pa.schema([
-            pa.field("id", pa.string()),
-            pa.field("text", pa.string()),
-            pa.field("vector", pa.list_(pa.float64())),
-            pa.field("attributes", pa.string()),
+            pa.field(self.id_field, pa.string()),
+            pa.field(self.text_field, pa.string()),
+            pa.field(self.vector_field, pa.list_(pa.float64())),
+            pa.field(self.attributes_field, pa.string()),
         ])
         # NOTE: If modifying the next section of code, ensure that the schema remains the same.
         # The pyarrow format of the 'vector' field may change if the order of operations is changed
@@ -65,19 +67,20 @@
         if overwrite:
             if data:
                 self.document_collection = self.db_connection.create_table(
-                    self.collection_name, data=data, mode="overwrite"
+                    self.index_name, data=data, mode="overwrite"
                 )
             else:
                 self.document_collection = self.db_connection.create_table(
-                    self.collection_name, schema=schema, mode="overwrite"
+                    self.index_name, schema=schema, mode="overwrite"
                 )
         else:
             # add data to existing table
             self.document_collection = self.db_connection.open_table(
-                self.collection_name
+                self.index_name
             )
             if data:
                 self.document_collection.add(data)
+        self.document_collection.create_index(vector_column_name=self.vector_field)

     def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
         """Build a query filter to filter documents by id."""
@@ -100,7 +103,7 @@
         if self.query_filter:
             docs = (
                 self.document_collection.search(
-                    query=query_embedding, vector_column_name="vector"
+                    query=query_embedding, vector_column_name=self.vector_field
                 )
                 .where(self.query_filter, prefilter=True)
                 .limit(k)
@@ -109,7 +112,7 @@
         else:
             docs = (
                 self.document_collection.search(
-                    query=query_embedding, vector_column_name="vector"
+                    query=query_embedding, vector_column_name=self.vector_field
                 )
                 .limit(k)
                 .to_list()
@@ -117,10 +120,10 @@
         return [
             VectorStoreSearchResult(
                 document=VectorStoreDocument(
-                    id=doc["id"],
-                    text=doc["text"],
-                    vector=doc["vector"],
-                    attributes=json.loads(doc["attributes"]),
+                    id=doc[self.id_field],
+                    text=doc[self.text_field],
+                    vector=doc[self.vector_field],
+                    attributes=json.loads(doc[self.attributes_field]),
                 ),
                 score=1 - abs(float(doc["_distance"])),
             )
@@ -145,9 +148,9 @@
         )
         if doc:
             return VectorStoreDocument(
-                id=doc[0]["id"],
-                text=doc[0]["text"],
-                vector=doc[0]["vector"],
-                attributes=json.loads(doc[0]["attributes"]),
+                id=doc[0][self.id_field],
+                text=doc[0][self.text_field],
+                vector=doc[0][self.vector_field],
+                attributes=json.loads(doc[0][self.attributes_field]),
             )
         return VectorStoreDocument(id=id, text=None, vector=None)
diff --git a/tests/unit/utils/test_embeddings.py b/tests/unit/utils/test_embeddings.py
index 854bf081..9349f0c8 100644
--- a/tests/unit/utils/test_embeddings.py
+++ b/tests/unit/utils/test_embeddings.py
@@ -3,19 +3,19 @@

 import pytest

-from graphrag.config.embeddings import create_collection_name
+from graphrag.config.embeddings import create_index_name


-def test_create_collection_name():
-    collection = create_collection_name("default", "entity.title")
+def test_create_index_name():
+    collection = create_index_name("default", "entity.title")
     assert collection == "default-entity-title"


-def test_create_collection_name_invalid_embedding_throws():
+def test_create_index_name_invalid_embedding_throws():
     with pytest.raises(KeyError):
-        create_collection_name("default", "invalid.name")
+        create_index_name("default", "invalid.name")


-def test_create_collection_name_invalid_embedding_does_not_throw():
-    collection = create_collection_name("default", "invalid.name", validate=False)
+def test_create_index_name_invalid_embedding_does_not_throw():
+    collection = create_index_name("default", "invalid.name", validate=False)
     assert collection == "default-invalid-name"
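Aside (illustration, not part of the patch): the shape of configuration these changes consume. The exact settings file layout is an assumption — the dict below simply mirrors what _create_vector_store and get_embedding_store read via store.get("embeddings_schema", {}), with keys that must be valid embedding names (see all_embeddings).

    # Hypothetical vector store configuration dict.
    vector_store_config = {
        "type": "lancedb",
        "db_uri": "./output/lancedb",  # hypothetical path
        "container_name": "default",
        "embeddings_schema": {
            "entity.title": {
                "index_name": "",  # empty -> falls back to "default-entity-title"
                "vector_field": "embedding",
                "vector_size": 3072,
            },
        },
    }

Unknown keys under embeddings_schema are rejected by VectorStoreConfig._validate_embeddings_schema.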