fix for lancedb vectors

This commit is contained in:
Gaudy Blanco 2025-09-13 00:27:55 -06:00
parent bfaa7ef016
commit 9b14b9da36
2 changed files with 45 additions and 30 deletions

17
.vscode/launch.json vendored
View File

@ -6,21 +6,24 @@
"name": "Indexer",
"type": "debugpy",
"request": "launch",
"module": "uv",
"module": "graphrag",
"args": [
"poe", "index",
"--root", "<path_to_ragtest_root_demo>"
"index",
"--root",
"<path_to_index_folder>"
],
"console": "integratedTerminal"
},
{
"name": "Query",
"type": "debugpy",
"request": "launch",
"module": "uv",
"module": "graphrag",
"args": [
"poe", "query",
"--root", "<path_to_ragtest_root_demo>",
"--method", "global",
"query",
"--root",
"<path_to_index_folder>",
"--method", "basic",
"--query", "What are the top themes in this story",
]
},

View File

@ -5,9 +5,8 @@
import json # noqa: I001
from typing import Any
import pyarrow as pa
import numpy as np
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.data_model.types import TextEmbedder
@ -37,42 +36,54 @@ class LanceDBVectorStore(BaseVectorStore):
self.index_name
)
def load_documents(
    self, documents: list[VectorStoreDocument], overwrite: bool = True
) -> None:
    """Load documents into vector storage.

    Only documents that carry a vector of exactly ``self.vector_size``
    elements are stored; documents with a missing or differently sized
    vector are skipped.

    Args:
        documents: Documents to store. Each one contributes its id, text,
            vector and JSON-serialised attributes.
        overwrite: When True, (re)create the table named ``self.index_name``
            from the given documents. When False, open the existing table
            and append to it.
    """
    usable = [
        document
        for document in documents
        if document.vector is not None
        and len(document.vector) == self.vector_size
    ]

    if usable:
        # Build the vector column as a fixed-size list of float32 by hand.
        # Letting pyarrow infer the type from Python lists would yield a
        # variable-length list of float64 instead.
        flat_vector = np.asarray(
            [document.vector for document in usable], dtype=np.float32
        ).reshape(-1)
        flat_array = pa.array(flat_vector, type=pa.float32())
        vector_column = pa.FixedSizeListArray.from_arrays(
            flat_array, self.vector_size
        )
        data = pa.table({
            self.id_field: pa.array(
                [document.id for document in usable], type=pa.string()
            ),
            self.text_field: pa.array(
                [document.text for document in usable], type=pa.string()
            ),
            self.vector_field: vector_column,
            self.attributes_field: pa.array(
                [json.dumps(document.attributes) for document in usable],
                type=pa.string(),
            ),
        })
    else:
        data = None

    # NOTE: If modifying the next section of code, ensure that the schema remains the same.
    #       The pyarrow format of the 'vector' field may change if the order of operations is changed
    #       and will break vector search.
    if overwrite:
        if data is not None:
            self.document_collection = self.db_connection.create_table(
                self.index_name, data=data, mode="overwrite", schema=data.schema
            )
            # Index only when there are rows to build it from.
            self.document_collection.create_index(
                vector_column_name=self.vector_field, index_type="IVF_FLAT"
            )
        else:
            # Nothing to load: create an empty table with an explicit schema
            # (no data to infer one from). The column types mirror the
            # non-empty branch so later appends are compatible.
            empty_schema = pa.schema([
                pa.field(self.id_field, pa.string()),
                pa.field(self.text_field, pa.string()),
                pa.field(
                    self.vector_field,
                    pa.list_(pa.float32(), self.vector_size),
                ),
                pa.field(self.attributes_field, pa.string()),
            ])
            self.document_collection = self.db_connection.create_table(
                self.index_name, schema=empty_schema, mode="overwrite"
            )
    else:
        # add data to existing table
        self.document_collection = self.db_connection.open_table(
            self.index_name
        )
        if data is not None:
            self.document_collection.add(data)
def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
"""Build a query filter to filter documents by id."""
@ -97,7 +107,7 @@ class LanceDBVectorStore(BaseVectorStore):
return self.query_filter
def similarity_search_by_vector(
self, query_embedding: list[float], k: int = 10, **kwargs: Any
self, query_embedding: list[float] | np.ndarray, k: int = 10, **kwargs: Any
) -> list[VectorStoreSearchResult]:
"""Perform a vector-based similarity search."""
if self.query_filter:
@ -110,6 +120,8 @@ class LanceDBVectorStore(BaseVectorStore):
.to_list()
)
else:
query_embedding = np.array(query_embedding, dtype=np.float32)
docs = (
self.document_collection.search(
query=query_embedding, vector_column_name=self.vector_field