mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 09:07:20 +08:00
fix for lancedb vectors
This commit is contained in:
parent
bfaa7ef016
commit
9b14b9da36
17
.vscode/launch.json
vendored
17
.vscode/launch.json
vendored
@ -6,21 +6,24 @@
|
||||
"name": "Indexer",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uv",
|
||||
"module": "graphrag",
|
||||
"args": [
|
||||
"poe", "index",
|
||||
"--root", "<path_to_ragtest_root_demo>"
|
||||
"index",
|
||||
"--root",
|
||||
"<path_to_index_folder>"
|
||||
],
|
||||
"console": "integratedTerminal"
|
||||
},
|
||||
{
|
||||
"name": "Query",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uv",
|
||||
"module": "graphrag",
|
||||
"args": [
|
||||
"poe", "query",
|
||||
"--root", "<path_to_ragtest_root_demo>",
|
||||
"--method", "global",
|
||||
"query",
|
||||
"--root",
|
||||
"<path_to_index_folder>",
|
||||
"--method", "basic",
|
||||
"--query", "What are the top themes in this story",
|
||||
]
|
||||
},
|
||||
|
||||
@ -5,9 +5,8 @@
|
||||
|
||||
import json # noqa: I001
|
||||
from typing import Any
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
import numpy as np
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
from graphrag.data_model.types import TextEmbedder
|
||||
|
||||
@ -37,42 +36,54 @@ class LanceDBVectorStore(BaseVectorStore):
|
||||
self.index_name
|
||||
)
|
||||
|
||||
|
||||
def load_documents(
|
||||
self, documents: list[VectorStoreDocument], overwrite: bool = True
|
||||
) -> None:
|
||||
"""Load documents into vector storage."""
|
||||
data = [
|
||||
{
|
||||
self.id_field: document.id,
|
||||
self.text_field: document.text,
|
||||
self.vector_field: document.vector,
|
||||
self.attributes_field: json.dumps(document.attributes),
|
||||
}
|
||||
for document in documents
|
||||
if document.vector is not None
|
||||
]
|
||||
# TODO GAUDY Step 1: Prepare data columns manually
|
||||
ids = []
|
||||
texts = []
|
||||
vectors = []
|
||||
attributes = []
|
||||
|
||||
if len(data) == 0:
|
||||
for document in documents:
|
||||
if document.vector is not None and len(document.vector) == self.vector_size:
|
||||
ids.append(document.id)
|
||||
texts.append(document.text)
|
||||
vectors.append(np.array(document.vector, dtype=np.float32))
|
||||
attributes.append(json.dumps(document.attributes))
|
||||
|
||||
# Step 2: Handle empty case
|
||||
if len(ids) == 0:
|
||||
data = None
|
||||
else:
|
||||
# Step 3: Flatten the vectors and build FixedSizeListArray manually
|
||||
flat_vector = np.concatenate(vectors).astype(np.float32)
|
||||
flat_array = pa.array(flat_vector, type=pa.float32())
|
||||
vector_column = pa.FixedSizeListArray.from_arrays(flat_array, self.vector_size)
|
||||
|
||||
# Step 4: Create PyArrow table (let schema be inferred)
|
||||
data = pa.table({
|
||||
self.id_field: pa.array(ids, type=pa.string()),
|
||||
self.text_field: pa.array(texts, type=pa.string()),
|
||||
self.vector_field: vector_column,
|
||||
self.attributes_field: pa.array(attributes, type=pa.string())
|
||||
})
|
||||
|
||||
schema = pa.schema([
|
||||
pa.field(self.id_field, pa.string()),
|
||||
pa.field(self.text_field, pa.string()),
|
||||
pa.field(self.vector_field, pa.list_(pa.float64())),
|
||||
pa.field(self.attributes_field, pa.string()),
|
||||
])
|
||||
# NOTE: If modifying the next section of code, ensure that the schema remains the same.
|
||||
# The pyarrow format of the 'vector' field may change if the order of operations is changed
|
||||
# and will break vector search.
|
||||
if overwrite:
|
||||
if data:
|
||||
self.document_collection = self.db_connection.create_table(
|
||||
self.index_name, data=data, mode="overwrite"
|
||||
self.index_name, data=data, mode="overwrite", schema=data.schema
|
||||
)
|
||||
else:
|
||||
self.document_collection = self.db_connection.create_table(
|
||||
self.index_name, schema=schema, mode="overwrite"
|
||||
self.index_name, mode="overwrite"
|
||||
)
|
||||
self.document_collection.create_index(vector_column_name=self.vector_field, index_type="IVF_FLAT")
|
||||
else:
|
||||
# add data to existing table
|
||||
self.document_collection = self.db_connection.open_table(
|
||||
@ -80,7 +91,6 @@ class LanceDBVectorStore(BaseVectorStore):
|
||||
)
|
||||
if data:
|
||||
self.document_collection.add(data)
|
||||
self.document_collection.create_index(vector_column_name=self.vector_field)
|
||||
|
||||
def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
|
||||
"""Build a query filter to filter documents by id."""
|
||||
@ -97,7 +107,7 @@ class LanceDBVectorStore(BaseVectorStore):
|
||||
return self.query_filter
|
||||
|
||||
def similarity_search_by_vector(
|
||||
self, query_embedding: list[float], k: int = 10, **kwargs: Any
|
||||
self, query_embedding: list[float] | np.ndarray, k: int = 10, **kwargs: Any
|
||||
) -> list[VectorStoreSearchResult]:
|
||||
"""Perform a vector-based similarity search."""
|
||||
if self.query_filter:
|
||||
@ -110,6 +120,8 @@ class LanceDBVectorStore(BaseVectorStore):
|
||||
.to_list()
|
||||
)
|
||||
else:
|
||||
query_embedding = np.array(query_embedding, dtype=np.float32)
|
||||
|
||||
docs = (
|
||||
self.document_collection.search(
|
||||
query=query_embedding, vector_column_name=self.vector_field
|
||||
|
||||
Loading…
Reference in New Issue
Block a user