mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 09:07:20 +08:00
progress on vector customization
This commit is contained in:
parent
1cb20b66f5
commit
bfaa7ef016
@ -393,6 +393,7 @@ class VectorStoreDefaults:
|
||||
api_key: None = None
|
||||
audience: None = None
|
||||
database_name: None = None
|
||||
schema: None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -29,14 +29,14 @@ default_embeddings: list[str] = [
|
||||
]
|
||||
|
||||
|
||||
def create_collection_name(
|
||||
def create_index_name(
|
||||
container_name: str, embedding_name: str, validate: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Create a collection name for the embedding store.
|
||||
Create a index name for the embedding store.
|
||||
|
||||
Within any given vector store, we can have multiple sets of embeddings organized into projects.
|
||||
The `container` param is used for this partitioning, and is added as a prefix to the collection name for differentiation.
|
||||
The `container` param is used for this partitioning, and is added as a prefix to the index name for differentiation.
|
||||
|
||||
The embedding name is fixed, with the available list defined in graphrag.index.config.embeddings
|
||||
|
||||
|
||||
@ -6,7 +6,9 @@
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from graphrag.config.defaults import vector_store_defaults
|
||||
from graphrag.config.embeddings import all_embeddings
|
||||
from graphrag.config.enums import VectorStoreType
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
|
||||
|
||||
class VectorStoreConfig(BaseModel):
|
||||
@ -85,9 +87,19 @@ class VectorStoreConfig(BaseModel):
|
||||
default=vector_store_defaults.overwrite,
|
||||
)
|
||||
|
||||
embeddings_schema: dict[str, VectorStoreSchemaConfig] = {}
|
||||
|
||||
def _validate_embeddings_schema(self) -> None:
    """Check that every key of `embeddings_schema` is a known embedding name.

    Keys are compared against `all_embeddings`; the first key that is not
    found there aborts validation with a `ValueError`.
    """
    unknown_names = (
        candidate
        for candidate in self.embeddings_schema
        if candidate not in all_embeddings
    )
    # Only the first offending key is reported, in dict iteration order.
    name = next(unknown_names, None)
    if name is None:
        return

    msg = f"vector_store.embeddings_schema contains an invalid embedding schema name: {name}. Please rerun `graphrag init` and select the correct embedding schema names."
    raise ValueError(msg)
|
||||
|
||||
@model_validator(mode="after")
def _validate_model(self):
    """Run the model-level checks after field validation.

    The db_uri, url and embeddings-schema checks run in that order; the
    instance is returned so that it remains the validated model.
    """
    checks = (
        self._validate_db_uri,
        self._validate_url,
        self._validate_embeddings_schema,
    )
    for check in checks:
        check()
    return self
|
||||
|
||||
51
graphrag/config/models/vector_store_schema_config.py
Normal file
51
graphrag/config/models/vector_store_schema_config.py
Normal file
@ -0,0 +1,51 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Parameterization settings for the default configuration."""
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
DEFAULT_VECTOR_SIZE: int = 1536
|
||||
|
||||
class VectorStoreSchemaConfig(BaseModel):
    """The default configuration section for Vector Store Schema.

    Describes how a single embedding is laid out inside a vector store: the
    index it is written to, the names of the stored fields, and the vector
    dimensionality.
    """

    # An empty index name is allowed on purpose: it means "not set".
    index_name: str = Field(
        description="The index name to use.",
        default="",
    )

    id_field: str = Field(
        description="The ID field to use.",
        default="id",
    )

    vector_field: str = Field(
        description="The vector field to use.",
        default="vector",
    )

    text_field: str = Field(
        description="The text field to use.",
        default="text",
    )

    attributes_field: str = Field(
        description="The attributes field to use.",
        default="attributes",
    )

    vector_size: int = Field(
        description="The vector size to use.",
        default=DEFAULT_VECTOR_SIZE,
    )

    def _validate_schema(self) -> None:
        """Validate the schema.

        Raises
        ------
        ValueError
            If a field name is empty or whitespace-only, if two field names
            are identical, or if `vector_size` is not a positive integer.
        """
        field_names = {
            "id_field": self.id_field,
            "vector_field": self.vector_field,
            "text_field": self.text_field,
            "attributes_field": self.attributes_field,
        }

        for setting, value in field_names.items():
            if not value.strip():
                msg = f"vector store schema '{setting}' must not be empty."
                raise ValueError(msg)

        # Records keyed by these names would silently lose a value if two
        # of the names were the same, so reject duplicates up front.
        if len(set(field_names.values())) != len(field_names):
            msg = (
                "vector store schema field names must be unique, got: "
                f"{', '.join(f'{k}={v}' for k, v in field_names.items())}"
            )
            raise ValueError(msg)

        if self.vector_size <= 0:
            msg = f"vector store schema 'vector_size' must be a positive integer, got {self.vector_size}."
            raise ValueError(msg)

    @model_validator(mode="after")
    def _validate_model(self):
        """Validate the model."""
        self._validate_schema()
        return self
|
||||
@ -12,7 +12,8 @@ import pandas as pd
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.embeddings import create_collection_name
|
||||
from graphrag.config.embeddings import create_index_name
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy
|
||||
from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument
|
||||
from graphrag.vector_stores.factory import VectorStoreFactory
|
||||
@ -49,9 +50,9 @@ async def embed_text(
|
||||
vector_store_config = strategy.get("vector_store")
|
||||
|
||||
if vector_store_config:
|
||||
collection_name = _get_collection_name(vector_store_config, embedding_name)
|
||||
index_name = _get_index_name(vector_store_config, embedding_name)
|
||||
vector_store: BaseVectorStore = _create_vector_store(
|
||||
vector_store_config, collection_name
|
||||
vector_store_config, index_name, embedding_name
|
||||
)
|
||||
vector_store_workflow_config = vector_store_config.get(
|
||||
embedding_name, vector_store_config
|
||||
@ -183,27 +184,38 @@ async def _text_embed_with_vector_store(
|
||||
|
||||
|
||||
def _create_vector_store(
    vector_store_config: dict, index_name: str, embedding_name: str | None = None
) -> BaseVectorStore:
    """Create and connect the vector store for one embedding.

    Parameters
    ----------
    vector_store_config:
        Vector store settings. `type` selects the implementation and the
        optional `embeddings_schema` mapping holds per-embedding schema
        overrides. The whole dict is also forwarded to the factory and to
        `connect`.
    index_name:
        Index name used when the schema for this embedding does not set one.
    embedding_name:
        Key looked up in `embeddings_schema`. When it is None or has no
        entry, a default `VectorStoreSchemaConfig` is used.

    Returns
    -------
    The vector store instance, after `connect` has been called on it.
    """
    vector_store_type: str = str(vector_store_config.get("type"))

    # `or {}` also covers a config that carries an explicit None here.
    embeddings_schema = vector_store_config.get("embeddings_schema") or {}
    raw_config = (
        embeddings_schema.get(embedding_name) if embedding_name is not None else None
    )

    if raw_config is None:
        single_embedding_config = VectorStoreSchemaConfig()
    elif isinstance(raw_config, dict):
        single_embedding_config = VectorStoreSchemaConfig(**raw_config)
    else:
        # Work on a copy: filling in index_name below must not write back
        # into the schema object owned by the caller's configuration.
        single_embedding_config = raw_config.model_copy()

    if single_embedding_config.index_name == "":
        single_embedding_config.index_name = index_name

    vector_store = VectorStoreFactory().create_vector_store(
        vector_store_schema_config=single_embedding_config,
        vector_store_type=vector_store_type,
        kwargs=vector_store_config,
    )

    vector_store.connect(**vector_store_config)
    return vector_store
|
||||
|
||||
|
||||
def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str:
|
||||
def _get_index_name(vector_store_config: dict, embedding_name: str) -> str:
|
||||
container_name = vector_store_config.get("container_name", "default")
|
||||
collection_name = create_collection_name(container_name, embedding_name)
|
||||
index_name = create_index_name(container_name, embedding_name)
|
||||
|
||||
msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {collection_name}"
|
||||
msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {index_name}"
|
||||
logger.info(msg)
|
||||
return collection_name
|
||||
return index_name
|
||||
|
||||
|
||||
def load_strategy(strategy: TextEmbedStrategyType) -> TextEmbeddingStrategy:
|
||||
|
||||
@ -8,9 +8,10 @@ from typing import Any
|
||||
|
||||
from graphrag.cache.factory import CacheFactory
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.config.embeddings import create_collection_name
|
||||
from graphrag.config.embeddings import create_index_name
|
||||
from graphrag.config.models.cache_config import CacheConfig
|
||||
from graphrag.config.models.storage_config import StorageConfig
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
from graphrag.data_model.types import TextEmbedder
|
||||
from graphrag.storage.factory import StorageFactory
|
||||
from graphrag.storage.pipeline_storage import PipelineStorage
|
||||
@ -103,12 +104,27 @@ def get_embedding_store(
|
||||
index_names = []
|
||||
for index, store in config_args.items():
|
||||
vector_store_type = store["type"]
|
||||
collection_name = create_collection_name(
|
||||
index_name = create_index_name(
|
||||
store.get("container_name", "default"), embedding_name
|
||||
)
|
||||
|
||||
embeddings_schema: dict[str, VectorStoreSchemaConfig] = store.get("embeddings_schema", {})
|
||||
single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
|
||||
|
||||
if embeddings_schema is not None and embedding_name is not None and embedding_name in embeddings_schema:
|
||||
raw_config = embeddings_schema[embedding_name]
|
||||
if isinstance(raw_config, dict):
|
||||
single_embedding_config = VectorStoreSchemaConfig(**raw_config)
|
||||
else:
|
||||
single_embedding_config = raw_config
|
||||
|
||||
if single_embedding_config.index_name == "":
|
||||
single_embedding_config.index_name = index_name
|
||||
|
||||
embedding_store = VectorStoreFactory().create_vector_store(
|
||||
vector_store_type=vector_store_type,
|
||||
kwargs={**store, "collection_name": collection_name},
|
||||
vector_store_schema_config=single_embedding_config,
|
||||
kwargs={**store},
|
||||
)
|
||||
embedding_store.connect(**store)
|
||||
# If there is only a single index, return the embedding store directly
|
||||
|
||||
@ -24,9 +24,9 @@ from azure.search.documents.indexes.models import (
|
||||
)
|
||||
from azure.search.documents.models import VectorizedQuery
|
||||
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
from graphrag.data_model.types import TextEmbedder
|
||||
from graphrag.vector_stores.base import (
|
||||
DEFAULT_VECTOR_SIZE,
|
||||
BaseVectorStore,
|
||||
VectorStoreDocument,
|
||||
VectorStoreSearchResult,
|
||||
@ -38,15 +38,14 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
||||
|
||||
index_client: SearchIndexClient
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
def __init__(self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any) -> None:
|
||||
super().__init__(vector_store_schema_config=vector_store_schema_config, **kwargs)
|
||||
|
||||
def connect(self, **kwargs: Any) -> Any:
|
||||
"""Connect to AI search vector storage."""
|
||||
url = kwargs["url"]
|
||||
api_key = kwargs.get("api_key")
|
||||
audience = kwargs.get("audience")
|
||||
self.vector_size = kwargs.get("vector_size", DEFAULT_VECTOR_SIZE)
|
||||
|
||||
self.vector_search_profile_name = kwargs.get(
|
||||
"vector_search_profile_name", "vectorSearchProfile"
|
||||
@ -56,7 +55,7 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
||||
audience_arg = {"audience": audience} if audience and not api_key else {}
|
||||
self.db_connection = SearchClient(
|
||||
endpoint=url,
|
||||
index_name=self.collection_name,
|
||||
index_name=self.index_name,
|
||||
credential=(
|
||||
AzureKeyCredential(api_key) if api_key else DefaultAzureCredential()
|
||||
),
|
||||
@ -78,8 +77,8 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
||||
) -> None:
|
||||
"""Load documents into an Azure AI Search index."""
|
||||
if overwrite:
|
||||
if self.collection_name in self.index_client.list_index_names():
|
||||
self.index_client.delete_index(self.collection_name)
|
||||
if self.index_name != "" and self.index_name in self.index_client.list_index_names():
|
||||
self.index_client.delete_index(self.index_name)
|
||||
|
||||
# Configure vector search profile
|
||||
vector_search = VectorSearch(
|
||||
@ -100,25 +99,23 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
||||
)
|
||||
# Configure the index
|
||||
index = SearchIndex(
|
||||
name=self.collection_name,
|
||||
name=self.index_name,
|
||||
fields=[
|
||||
SimpleField(
|
||||
name="id",
|
||||
name=self.id_field,
|
||||
type=SearchFieldDataType.String,
|
||||
key=True,
|
||||
),
|
||||
SearchField(
|
||||
name="vector",
|
||||
name=self.vector_field,
|
||||
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
|
||||
searchable=True,
|
||||
hidden=False, # DRIFT needs to return the vector for client-side similarity
|
||||
vector_search_dimensions=self.vector_size,
|
||||
vector_search_profile_name=self.vector_search_profile_name,
|
||||
),
|
||||
SearchableField(name="text", type=SearchFieldDataType.String),
|
||||
SimpleField(
|
||||
name="attributes",
|
||||
type=SearchFieldDataType.String,
|
||||
SearchableField(name=self.text_field, type=SearchFieldDataType.String),
|
||||
SimpleField(name=self.attributes_field, type=SearchFieldDataType.String,
|
||||
),
|
||||
],
|
||||
vector_search=vector_search,
|
||||
@ -129,10 +126,10 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
||||
|
||||
batch = [
|
||||
{
|
||||
"id": doc.id,
|
||||
"vector": doc.vector,
|
||||
"text": doc.text,
|
||||
"attributes": json.dumps(doc.attributes),
|
||||
self.id_field: doc.id,
|
||||
self.vector_field: doc.vector,
|
||||
self.text_field: doc.text,
|
||||
self.attributes_field: json.dumps(doc.attributes),
|
||||
}
|
||||
for doc in documents
|
||||
if doc.vector is not None
|
||||
@ -162,7 +159,7 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
||||
) -> list[VectorStoreSearchResult]:
|
||||
"""Perform a vector-based similarity search."""
|
||||
vectorized_query = VectorizedQuery(
|
||||
vector=query_embedding, k_nearest_neighbors=k, fields="vector"
|
||||
vector=query_embedding, k_nearest_neighbors=k, fields=self.vector_field
|
||||
)
|
||||
|
||||
response = self.db_connection.search(
|
||||
@ -172,10 +169,10 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
||||
return [
|
||||
VectorStoreSearchResult(
|
||||
document=VectorStoreDocument(
|
||||
id=doc.get("id", ""),
|
||||
text=doc.get("text", ""),
|
||||
vector=doc.get("vector", []),
|
||||
attributes=(json.loads(doc.get("attributes", "{}"))),
|
||||
id=doc.get(self.id_field, ""),
|
||||
text=doc.get(self.text_field, ""),
|
||||
vector=doc.get(self.vector_field, []),
|
||||
attributes=(json.loads(doc.get(self.attributes_field, "{}"))),
|
||||
),
|
||||
# Cosine similarity between 0.333 and 1.000
|
||||
# https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking#scores-in-a-hybrid-search-results
|
||||
@ -199,8 +196,8 @@ class AzureAISearchVectorStore(BaseVectorStore):
|
||||
"""Search for a document by id."""
|
||||
response = self.db_connection.get_document(id)
|
||||
return VectorStoreDocument(
|
||||
id=response.get("id", ""),
|
||||
text=response.get("text", ""),
|
||||
vector=response.get("vector", []),
|
||||
attributes=(json.loads(response.get("attributes", "{}"))),
|
||||
id=response.get(self.id_field, ""),
|
||||
text=response.get(self.text_field, ""),
|
||||
vector=response.get(self.vector_field, []),
|
||||
attributes=(json.loads(response.get(self.attributes_field, "{}"))),
|
||||
)
|
||||
|
||||
@ -7,10 +7,9 @@ from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
from graphrag.data_model.types import TextEmbedder
|
||||
|
||||
DEFAULT_VECTOR_SIZE: int = 1536
|
||||
|
||||
|
||||
@dataclass
|
||||
class VectorStoreDocument:
|
||||
@ -42,17 +41,24 @@ class BaseVectorStore(ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
vector_store_schema_config: VectorStoreSchemaConfig,
|
||||
db_connection: Any | None = None,
|
||||
document_collection: Any | None = None,
|
||||
query_filter: Any | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self.collection_name = collection_name
|
||||
self.db_connection = db_connection
|
||||
self.document_collection = document_collection
|
||||
self.query_filter = query_filter
|
||||
self.kwargs = kwargs
|
||||
|
||||
self.index_name = vector_store_schema_config.index_name
|
||||
self.id_field = vector_store_schema_config.id_field
|
||||
self.text_field = vector_store_schema_config.text_field
|
||||
self.vector_field = vector_store_schema_config.vector_field
|
||||
self.attributes_field = vector_store_schema_config.attributes_field
|
||||
self.vector_size = vector_store_schema_config.vector_size
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def connect(self, **kwargs: Any) -> None:
|
||||
|
||||
@ -13,7 +13,6 @@ from azure.identity import DefaultAzureCredential
|
||||
|
||||
from graphrag.data_model.types import TextEmbedder
|
||||
from graphrag.vector_stores.base import (
|
||||
DEFAULT_VECTOR_SIZE,
|
||||
BaseVectorStore,
|
||||
VectorStoreDocument,
|
||||
VectorStoreSearchResult,
|
||||
@ -49,13 +48,13 @@ class CosmosDBVectorStore(BaseVectorStore):
|
||||
msg = "Database name must be provided."
|
||||
raise ValueError(msg)
|
||||
self._database_name = database_name
|
||||
collection_name = self.collection_name
|
||||
collection_name = self.index_name
|
||||
if collection_name is None:
|
||||
msg = "Collection name is empty or not provided."
|
||||
raise ValueError(msg)
|
||||
self._container_name = collection_name
|
||||
|
||||
self.vector_size = kwargs.get("vector_size", DEFAULT_VECTOR_SIZE)
|
||||
self.vector_size = kwargs.get("vector_size", 1024) #TODO GAUDY fix it
|
||||
self._create_database()
|
||||
self._create_container()
|
||||
|
||||
|
||||
@ -15,6 +15,9 @@ from graphrag.vector_stores.lancedb import LanceDBVectorStore
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from graphrag.config.models.vector_store_schema_config import (
|
||||
VectorStoreSchemaConfig,
|
||||
)
|
||||
from graphrag.vector_stores.base import BaseVectorStore
|
||||
|
||||
|
||||
@ -47,7 +50,7 @@ class VectorStoreFactory:
|
||||
|
||||
@classmethod
|
||||
def create_vector_store(
|
||||
cls, vector_store_type: str, kwargs: dict
|
||||
cls, vector_store_type: str, vector_store_schema_config: VectorStoreSchemaConfig, kwargs: dict
|
||||
) -> BaseVectorStore:
|
||||
"""Create a vector store object from the provided type.
|
||||
|
||||
@ -67,8 +70,11 @@ class VectorStoreFactory:
|
||||
msg = f"Unknown vector store type: {vector_store_type}"
|
||||
raise ValueError(msg)
|
||||
|
||||
return cls._registry[vector_store_type](**kwargs)
|
||||
|
||||
return cls._registry[vector_store_type](
|
||||
vector_store_schema_config=vector_store_schema_config,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_vector_store_types(cls) -> list[str]:
|
||||
"""Get the registered vector store implementations."""
|
||||
|
||||
@ -8,6 +8,7 @@ from typing import Any
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
from graphrag.data_model.types import TextEmbedder
|
||||
|
||||
from graphrag.vector_stores.base import (
|
||||
@ -21,18 +22,19 @@ import lancedb
|
||||
class LanceDBVectorStore(BaseVectorStore):
|
||||
"""LanceDB vector storage implementation."""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
def __init__(self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any) -> None:
|
||||
super().__init__(vector_store_schema_config=vector_store_schema_config, **kwargs)
|
||||
|
||||
def connect(self, **kwargs: Any) -> Any:
|
||||
"""Connect to the vector storage."""
|
||||
self.db_connection = lancedb.connect(kwargs["db_uri"])
|
||||
|
||||
if (
|
||||
self.collection_name
|
||||
and self.collection_name in self.db_connection.table_names()
|
||||
self.index_name
|
||||
and self.index_name in self.db_connection.table_names()
|
||||
):
|
||||
self.document_collection = self.db_connection.open_table(
|
||||
self.collection_name
|
||||
self.index_name
|
||||
)
|
||||
|
||||
def load_documents(
|
||||
@ -41,10 +43,10 @@ class LanceDBVectorStore(BaseVectorStore):
|
||||
"""Load documents into vector storage."""
|
||||
data = [
|
||||
{
|
||||
"id": document.id,
|
||||
"text": document.text,
|
||||
"vector": document.vector,
|
||||
"attributes": json.dumps(document.attributes),
|
||||
self.id_field: document.id,
|
||||
self.text_field: document.text,
|
||||
self.vector_field: document.vector,
|
||||
self.attributes_field: json.dumps(document.attributes),
|
||||
}
|
||||
for document in documents
|
||||
if document.vector is not None
|
||||
@ -54,10 +56,10 @@ class LanceDBVectorStore(BaseVectorStore):
|
||||
data = None
|
||||
|
||||
schema = pa.schema([
|
||||
pa.field("id", pa.string()),
|
||||
pa.field("text", pa.string()),
|
||||
pa.field("vector", pa.list_(pa.float64())),
|
||||
pa.field("attributes", pa.string()),
|
||||
pa.field(self.id_field, pa.string()),
|
||||
pa.field(self.text_field, pa.string()),
|
||||
pa.field(self.vector_field, pa.list_(pa.float64())),
|
||||
pa.field(self.attributes_field, pa.string()),
|
||||
])
|
||||
# NOTE: If modifying the next section of code, ensure that the schema remains the same.
|
||||
# The pyarrow format of the 'vector' field may change if the order of operations is changed
|
||||
@ -65,19 +67,20 @@ class LanceDBVectorStore(BaseVectorStore):
|
||||
if overwrite:
|
||||
if data:
|
||||
self.document_collection = self.db_connection.create_table(
|
||||
self.collection_name, data=data, mode="overwrite"
|
||||
self.index_name, data=data, mode="overwrite"
|
||||
)
|
||||
else:
|
||||
self.document_collection = self.db_connection.create_table(
|
||||
self.collection_name, schema=schema, mode="overwrite"
|
||||
self.index_name, schema=schema, mode="overwrite"
|
||||
)
|
||||
else:
|
||||
# add data to existing table
|
||||
self.document_collection = self.db_connection.open_table(
|
||||
self.collection_name
|
||||
self.index_name
|
||||
)
|
||||
if data:
|
||||
self.document_collection.add(data)
|
||||
self.document_collection.create_index(vector_column_name=self.vector_field)
|
||||
|
||||
def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
|
||||
"""Build a query filter to filter documents by id."""
|
||||
@ -100,7 +103,7 @@ class LanceDBVectorStore(BaseVectorStore):
|
||||
if self.query_filter:
|
||||
docs = (
|
||||
self.document_collection.search(
|
||||
query=query_embedding, vector_column_name="vector"
|
||||
query=query_embedding, vector_column_name=self.vector_field
|
||||
)
|
||||
.where(self.query_filter, prefilter=True)
|
||||
.limit(k)
|
||||
@ -109,7 +112,7 @@ class LanceDBVectorStore(BaseVectorStore):
|
||||
else:
|
||||
docs = (
|
||||
self.document_collection.search(
|
||||
query=query_embedding, vector_column_name="vector"
|
||||
query=query_embedding, vector_column_name=self.vector_field
|
||||
)
|
||||
.limit(k)
|
||||
.to_list()
|
||||
@ -117,10 +120,10 @@ class LanceDBVectorStore(BaseVectorStore):
|
||||
return [
|
||||
VectorStoreSearchResult(
|
||||
document=VectorStoreDocument(
|
||||
id=doc["id"],
|
||||
text=doc["text"],
|
||||
vector=doc["vector"],
|
||||
attributes=json.loads(doc["attributes"]),
|
||||
id=doc[self.id_field],
|
||||
text=doc[self.text_field],
|
||||
vector=doc[self.vector_field],
|
||||
attributes=json.loads(doc[self.attributes_field]),
|
||||
),
|
||||
score=1 - abs(float(doc["_distance"])),
|
||||
)
|
||||
@ -145,9 +148,9 @@ class LanceDBVectorStore(BaseVectorStore):
|
||||
)
|
||||
if doc:
|
||||
return VectorStoreDocument(
|
||||
id=doc[0]["id"],
|
||||
text=doc[0]["text"],
|
||||
vector=doc[0]["vector"],
|
||||
attributes=json.loads(doc[0]["attributes"]),
|
||||
id=doc[0][self.id_field],
|
||||
text=doc[0][self.text_field],
|
||||
vector=doc[0][self.vector_field],
|
||||
attributes=json.loads(doc[0][self.attributes_field]),
|
||||
)
|
||||
return VectorStoreDocument(id=id, text=None, vector=None)
|
||||
|
||||
@ -3,19 +3,19 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from graphrag.config.embeddings import create_collection_name
|
||||
from graphrag.config.embeddings import create_index_name
|
||||
|
||||
|
||||
def test_create_collection_name():
|
||||
collection = create_collection_name("default", "entity.title")
|
||||
def test_create_index_name():
|
||||
collection = create_index_name("default", "entity.title")
|
||||
assert collection == "default-entity-title"
|
||||
|
||||
|
||||
def test_create_collection_name_invalid_embedding_throws():
|
||||
def test_create_index_name_invalid_embedding_throws():
|
||||
with pytest.raises(KeyError):
|
||||
create_collection_name("default", "invalid.name")
|
||||
create_index_name("default", "invalid.name")
|
||||
|
||||
|
||||
def test_create_collection_name_invalid_embedding_does_not_throw():
|
||||
collection = create_collection_name("default", "invalid.name", validate=False)
|
||||
def test_create_index_name_invalid_embedding_does_not_throw():
|
||||
collection = create_index_name("default", "invalid.name", validate=False)
|
||||
assert collection == "default-invalid-name"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user