mirror of https://github.com/microsoft/graphrag.git (synced 2026-01-14 09:07:20 +08:00)
uv run poe format
This commit is contained in:
parent cffb03a776
commit 15f2144c5e
@@ -11,17 +11,16 @@ DEFAULT_VECTOR_SIZE: int = 1536
 VALID_IDENTIFIER_REGEX = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


 def is_valid_field_name(field: str) -> bool:
     """Check if a field name is valid for CosmosDB."""
     return bool(VALID_IDENTIFIER_REGEX.match(field))


 class VectorStoreSchemaConfig(BaseModel):
     """The default configuration section for Vector Store Schema."""

-    index_name: str = Field(
-        description="The index name to use.",
-        default=""
-    )
+    index_name: str = Field(description="The index name to use.", default="")

     id_field: str = Field(
         description="The ID field to use.",
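As a quick reference for the validation helper in the hunk above, here is a self-contained sketch of how the check behaves; the regex and function body are copied from the diff, and the sample field names are illustrative.

import re

VALID_IDENTIFIER_REGEX = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


def is_valid_field_name(field: str) -> bool:
    """Check if a field name is valid for CosmosDB."""
    return bool(VALID_IDENTIFIER_REGEX.match(field))


# Names must start with a letter or underscore and use only [A-Za-z0-9_].
assert is_valid_field_name("text_unit_ids")
assert not is_valid_field_name("2nd_field")  # starts with a digit
assert not is_valid_field_name("bad-field")  # hyphen is not allowed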
@@ -188,10 +188,16 @@ def _create_vector_store(
 ) -> BaseVectorStore:
     vector_store_type: str = str(vector_store_config.get("type"))

-    embeddings_schema: dict[str, VectorStoreSchemaConfig] = vector_store_config.get("embeddings_schema", {})
+    embeddings_schema: dict[str, VectorStoreSchemaConfig] = vector_store_config.get(
+        "embeddings_schema", {}
+    )
     single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()

-    if embeddings_schema is not None and embedding_name is not None and embedding_name in embeddings_schema:
+    if (
+        embeddings_schema is not None
+        and embedding_name is not None
+        and embedding_name in embeddings_schema
+    ):
         raw_config = embeddings_schema[embedding_name]
         if isinstance(raw_config, dict):
             single_embedding_config = VectorStoreSchemaConfig(**raw_config)
@@ -202,7 +208,9 @@ def _create_vector_store(
         single_embedding_config.index_name = index_name

     vector_store = VectorStoreFactory().create_vector_store(
-        vector_store_schema_config=single_embedding_config, vector_store_type=vector_store_type, kwargs=vector_store_config
+        vector_store_schema_config=single_embedding_config,
+        vector_store_type=vector_store_type,
+        kwargs=vector_store_config,
     )

     vector_store.connect(**vector_store_config)
@@ -108,10 +108,16 @@ def get_embedding_store(
         store.get("container_name", "default"), embedding_name
     )

-    embeddings_schema: dict[str, VectorStoreSchemaConfig] = store.get("embeddings_schema", {})
+    embeddings_schema: dict[str, VectorStoreSchemaConfig] = store.get(
+        "embeddings_schema", {}
+    )
     single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()

-    if embeddings_schema is not None and embedding_name is not None and embedding_name in embeddings_schema:
+    if (
+        embeddings_schema is not None
+        and embedding_name is not None
+        and embedding_name in embeddings_schema
+    ):
         raw_config = embeddings_schema[embedding_name]
         if isinstance(raw_config, dict):
             single_embedding_config = VectorStoreSchemaConfig(**raw_config)
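Both _create_vector_store and get_embedding_store above repeat the same lookup: take the optional embeddings_schema mapping and build a VectorStoreSchemaConfig for the requested embedding, falling back to defaults when nothing is configured. A condensed sketch of that pattern under simplifying assumptions (SchemaConfig and resolve_schema are illustrative stand-ins, not graphrag APIs):

from pydantic import BaseModel, Field


class SchemaConfig(BaseModel):
    """Simplified stand-in for VectorStoreSchemaConfig."""

    index_name: str = Field(description="The index name to use.", default="")
    id_field: str = Field(description="The ID field to use.", default="id")


def resolve_schema(
    embeddings_schema: dict | None, embedding_name: str | None
) -> SchemaConfig:
    """Use the per-embedding override when configured, otherwise fall back to defaults."""
    config = SchemaConfig()
    if embeddings_schema and embedding_name and embedding_name in embeddings_schema:
        raw_config = embeddings_schema[embedding_name]
        # Values loaded from YAML arrive as plain dicts; already-parsed models pass through.
        if isinstance(raw_config, dict):
            config = SchemaConfig(**raw_config)
        else:
            config = raw_config
    return config


print(resolve_schema({"entity.description": {"index_name": "entities"}}, "entity.description"))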
@@ -38,8 +38,12 @@ class AzureAISearchVectorStore(BaseVectorStore):

     index_client: SearchIndexClient

-    def __init__(self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any) -> None:
-        super().__init__(vector_store_schema_config=vector_store_schema_config, **kwargs)
+    def __init__(
+        self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any
+    ) -> None:
+        super().__init__(
+            vector_store_schema_config=vector_store_schema_config, **kwargs
+        )

     def connect(self, **kwargs: Any) -> Any:
         """Connect to AI search vector storage."""
@@ -77,8 +81,11 @@ class AzureAISearchVectorStore(BaseVectorStore):
     ) -> None:
         """Load documents into an Azure AI Search index."""
         if overwrite:
-            if self.index_name != "" and self.index_name in self.index_client.list_index_names():
-                self.index_client.delete_index(self.index_name)
+            if (
+                self.index_name != ""
+                and self.index_name in self.index_client.list_index_names()
+            ):
+                self.index_client.delete_index(self.index_name)

             # Configure vector search profile
             vector_search = VectorSearch(
@@ -114,8 +121,12 @@ class AzureAISearchVectorStore(BaseVectorStore):
                     vector_search_dimensions=self.vector_size,
                     vector_search_profile_name=self.vector_search_profile_name,
                 ),
-                SearchableField(name=self.text_field, type=SearchFieldDataType.String),
-                SimpleField(name=self.attributes_field, type=SearchFieldDataType.String,
+                SearchableField(
+                    name=self.text_field, type=SearchFieldDataType.String
+                ),
+                SimpleField(
+                    name=self.attributes_field,
+                    type=SearchFieldDataType.String,
+                ),
             ],
             vector_search=vector_search,
@@ -51,7 +51,7 @@ class BaseVectorStore(ABC):
         self.document_collection = document_collection
         self.query_filter = query_filter
         self.kwargs = kwargs

         self.index_name = vector_store_schema_config.index_name
         self.id_field = vector_store_schema_config.id_field
         self.text_field = vector_store_schema_config.text_field
@@ -59,7 +59,6 @@ class BaseVectorStore(ABC):
         self.attributes_field = vector_store_schema_config.attributes_field
         self.vector_size = vector_store_schema_config.vector_size

-
     @abstractmethod
     def connect(self, **kwargs: Any) -> None:
         """Connect to vector storage."""
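The BaseVectorStore hunks above show the schema config being copied onto instance attributes that every concrete store then reads (index_name, id_field, and so on). A minimal sketch of that shape; the default field values and class names here are assumptions, not the real graphrag defaults:

from abc import ABC, abstractmethod
from typing import Any


class SchemaConfig:
    """Stand-in config exposing the fields referenced in the diff (defaults are illustrative)."""

    def __init__(self) -> None:
        self.index_name = ""
        self.id_field = "id"
        self.text_field = "text"
        self.vector_field = "vector"
        self.attributes_field = "attributes"
        self.vector_size = 1536


class BaseStore(ABC):
    """Sketch of a base vector store that flattens the schema config into attributes."""

    def __init__(self, vector_store_schema_config: SchemaConfig, **kwargs: Any) -> None:
        self.kwargs = kwargs
        self.index_name = vector_store_schema_config.index_name
        self.id_field = vector_store_schema_config.id_field
        self.text_field = vector_store_schema_config.text_field
        self.vector_field = vector_store_schema_config.vector_field
        self.attributes_field = vector_store_schema_config.attributes_field
        self.vector_size = vector_store_schema_config.vector_size

    @abstractmethod
    def connect(self, **kwargs: Any) -> None:
        """Connect to vector storage."""


class InMemoryStore(BaseStore):
    def connect(self, **kwargs: Any) -> None:
        """No-op connect, just for the sketch."""


store = InMemoryStore(SchemaConfig())
print(store.id_field, store.vector_size)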
@@ -27,8 +27,12 @@ class CosmosDBVectorStore(BaseVectorStore):
     _database_client: DatabaseProxy
     _container_client: ContainerProxy

-    def __init__(self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any) -> None:
-        super().__init__(vector_store_schema_config=vector_store_schema_config, **kwargs)
+    def __init__(
+        self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any
+    ) -> None:
+        super().__init__(
+            vector_store_schema_config=vector_store_schema_config, **kwargs
+        )

     def connect(self, **kwargs: Any) -> Any:
         """Connect to CosmosDB vector storage."""
@@ -98,13 +102,18 @@ class CosmosDBVectorStore(BaseVectorStore):
             "indexingMode": "consistent",
             "automatic": True,
             "includedPaths": [{"path": "/*"}],
-            "excludedPaths": [{"path": "/_etag/?"}, {"path": f"/{self.vector_field}/*"}],
+            "excludedPaths": [
+                {"path": "/_etag/?"},
+                {"path": f"/{self.vector_field}/*"},
+            ],
         }

         # Currently, the CosmosDB emulator does not support the diskANN policy.
         try:
             # First try with the standard diskANN policy
-            indexing_policy["vectorIndexes"] = [{"path": f"/{self.vector_field}", "type": "diskANN"}]
+            indexing_policy["vectorIndexes"] = [
+                {"path": f"/{self.vector_field}", "type": "diskANN"}
+            ]

             # Create the container and container client
             self._database_client.create_container_if_not_exists(
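For readability, here is the CosmosDB indexing policy assembled in the hunk above, written out as a standalone dict; the vector field name is illustrative, and the emulator fallback that replaces diskANN is not shown in this hunk:

vector_field = "vector"  # illustrative; the real value comes from the schema config

indexing_policy: dict = {
    "indexingMode": "consistent",
    "automatic": True,
    "includedPaths": [{"path": "/*"}],
    # Exclude the embedding path from the default index; it gets a vector index instead.
    "excludedPaths": [
        {"path": "/_etag/?"},
        {"path": f"/{vector_field}/*"},
    ],
}

# The CosmosDB emulator does not support diskANN, so the real code wraps this in a try block.
indexing_policy["vectorIndexes"] = [{"path": f"/{vector_field}", "type": "diskANN"}]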
@@ -247,7 +256,9 @@ class CosmosDBVectorStore(BaseVectorStore):
             id_filter = ", ".join([f"'{id}'" for id in include_ids])
         else:
             id_filter = ", ".join([str(id) for id in include_ids])
-        self.query_filter = f"SELECT * FROM c WHERE c.{self.id_field} IN ({id_filter})"  # noqa: S608
+        self.query_filter = (
+            f"SELECT * FROM c WHERE c.{self.id_field} IN ({id_filter})"  # noqa: S608
+        )
         return self.query_filter

     def search_by_id(self, id: str) -> VectorStoreDocument:
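A small sketch of the ID filter built in the hunk above: string IDs are quoted, other IDs are stringified, and the result is spliced into a CosmosDB SQL query. The isinstance check and the default id_field are assumptions; the hunk only shows the two join branches:

def build_query_filter(include_ids: list, id_field: str = "id") -> str:
    """Build a SELECT filter restricted to the given document IDs (sketch)."""
    # Assumption: decide quoting from the first ID's type; the diff only shows the branches.
    if include_ids and isinstance(include_ids[0], str):
        id_filter = ", ".join([f"'{id}'" for id in include_ids])
    else:
        id_filter = ", ".join([str(id) for id in include_ids])
    return f"SELECT * FROM c WHERE c.{id_field} IN ({id_filter})"  # noqa: S608


print(build_query_filter(["a1", "b2"]))  # SELECT * FROM c WHERE c.id IN ('a1', 'b2')
print(build_query_filter([1, 2, 3]))     # SELECT * FROM c WHERE c.id IN (1, 2, 3)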
@@ -50,7 +50,10 @@ class VectorStoreFactory:

     @classmethod
     def create_vector_store(
-        cls, vector_store_type: str, vector_store_schema_config: VectorStoreSchemaConfig, kwargs: dict
+        cls,
+        vector_store_type: str,
+        vector_store_schema_config: VectorStoreSchemaConfig,
+        kwargs: dict,
     ) -> BaseVectorStore:
         """Create a vector store object from the provided type.

@@ -71,10 +74,9 @@ class VectorStoreFactory:
             raise ValueError(msg)

         return cls._registry[vector_store_type](
-            vector_store_schema_config=vector_store_schema_config,
-            **kwargs
+            vector_store_schema_config=vector_store_schema_config, **kwargs
         )

     @classmethod
     def get_vector_store_types(cls) -> list[str]:
         """Get the registered vector store implementations."""
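The factory change above only rewraps arguments, but the surrounding pattern is worth spelling out: a class-level registry maps a type name to a constructor, and create_vector_store looks the type up and instantiates it with the schema config plus connection kwargs. A minimal sketch of that pattern; the class names and the toy EchoStore are illustrative, not the graphrag implementation:

from typing import Any, Callable, ClassVar


class EchoStore:
    """Toy store used only to exercise the factory below."""

    def __init__(self, vector_store_schema_config: dict, **kwargs: Any) -> None:
        self.schema = vector_store_schema_config
        self.kwargs = kwargs


class StoreFactory:
    """Registry-based factory: look up the class by type name, then instantiate it."""

    _registry: ClassVar[dict[str, Callable[..., Any]]] = {}

    @classmethod
    def register(cls, vector_store_type: str, creator: Callable[..., Any]) -> None:
        cls._registry[vector_store_type] = creator

    @classmethod
    def create_vector_store(
        cls,
        vector_store_type: str,
        vector_store_schema_config: dict,
        kwargs: dict,
    ) -> Any:
        if vector_store_type not in cls._registry:
            msg = f"Unknown vector store type: {vector_store_type}"
            raise ValueError(msg)
        return cls._registry[vector_store_type](
            vector_store_schema_config=vector_store_schema_config, **kwargs
        )


StoreFactory.register("echo", EchoStore)
store = StoreFactory.create_vector_store("echo", {"index_name": "demo"}, {"db_uri": "./lancedb"})
print(type(store).__name__, store.schema)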
@@ -21,21 +21,19 @@ import lancedb
 class LanceDBVectorStore(BaseVectorStore):
     """LanceDB vector storage implementation."""

-    def __init__(self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any) -> None:
-        super().__init__(vector_store_schema_config=vector_store_schema_config, **kwargs)
+    def __init__(
+        self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any
+    ) -> None:
+        super().__init__(
+            vector_store_schema_config=vector_store_schema_config, **kwargs
+        )

     def connect(self, **kwargs: Any) -> Any:
         """Connect to the vector storage."""
         self.db_connection = lancedb.connect(kwargs["db_uri"])

-        if (
-            self.index_name
-            and self.index_name in self.db_connection.table_names()
-        ):
-            self.document_collection = self.db_connection.open_table(
-                self.index_name
-            )
+        if self.index_name and self.index_name in self.db_connection.table_names():
+            self.document_collection = self.db_connection.open_table(self.index_name)

     def load_documents(
         self, documents: list[VectorStoreDocument], overwrite: bool = True
@@ -61,14 +59,16 @@ class LanceDBVectorStore(BaseVectorStore):
         # Step 3: Flatten the vectors and build FixedSizeListArray manually
         flat_vector = np.concatenate(vectors).astype(np.float32)
         flat_array = pa.array(flat_vector, type=pa.float32())
-        vector_column = pa.FixedSizeListArray.from_arrays(flat_array, self.vector_size)
+        vector_column = pa.FixedSizeListArray.from_arrays(
+            flat_array, self.vector_size
+        )

         # Step 4: Create PyArrow table (let schema be inferred)
         data = pa.table({
             self.id_field: pa.array(ids, type=pa.string()),
             self.text_field: pa.array(texts, type=pa.string()),
             self.vector_field: vector_column,
-            self.attributes_field: pa.array(attributes, type=pa.string())
+            self.attributes_field: pa.array(attributes, type=pa.string()),
         })

         # NOTE: If modifying the next section of code, ensure that the schema remains the same.
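The hunk above rewraps the PyArrow column construction. A self-contained sketch of the same technique, flattening equal-length vectors into a FixedSizeListArray and assembling a table with an inferred schema; the column names and 4-dimensional vectors are illustrative:

import numpy as np
import pyarrow as pa

vector_size = 4  # illustrative; graphrag defaults to 1536
ids = ["doc-1", "doc-2"]
texts = ["first chunk", "second chunk"]
vectors = [np.random.rand(vector_size), np.random.rand(vector_size)]
attributes = ['{"title": "a"}', '{"title": "b"}']

# Flatten all vectors into one float32 buffer, then view it as rows of length vector_size.
flat_vector = np.concatenate(vectors).astype(np.float32)
flat_array = pa.array(flat_vector, type=pa.float32())
vector_column = pa.FixedSizeListArray.from_arrays(flat_array, vector_size)

# Let PyArrow infer the table schema from the arrays.
data = pa.table({
    "id": pa.array(ids, type=pa.string()),
    "text": pa.array(texts, type=pa.string()),
    "vector": vector_column,
    "attributes": pa.array(attributes, type=pa.string()),
})
print(data.schema)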
@@ -83,12 +83,12 @@ class LanceDBVectorStore(BaseVectorStore):
             self.document_collection = self.db_connection.create_table(
                 self.index_name, mode="overwrite"
             )
-            self.document_collection.create_index(vector_column_name=self.vector_field, index_type="IVF_FLAT")
+            self.document_collection.create_index(
+                vector_column_name=self.vector_field, index_type="IVF_FLAT"
+            )
         else:
             # add data to existing table
-            self.document_collection = self.db_connection.open_table(
-                self.index_name
-            )
+            self.document_collection = self.db_connection.open_table(self.index_name)
         if data:
             self.document_collection.add(data)