uv run poe format

This commit is contained in:
Gaudy Blanco 2025-09-18 13:10:18 -06:00
parent cffb03a776
commit 15f2144c5e
8 changed files with 78 additions and 42 deletions

View File

@ -11,17 +11,16 @@ DEFAULT_VECTOR_SIZE: int = 1536
VALID_IDENTIFIER_REGEX = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
def is_valid_field_name(field: str) -> bool:
"""Check if a field name is valid for CosmosDB."""
return bool(VALID_IDENTIFIER_REGEX.match(field))
class VectorStoreSchemaConfig(BaseModel):
"""The default configuration section for Vector Store Schema."""
index_name: str = Field(
description="The index name to use.",
default=""
)
index_name: str = Field(description="The index name to use.", default="")
id_field: str = Field(
description="The ID field to use.",

View File

@ -188,10 +188,16 @@ def _create_vector_store(
) -> BaseVectorStore:
vector_store_type: str = str(vector_store_config.get("type"))
embeddings_schema: dict[str, VectorStoreSchemaConfig] = vector_store_config.get("embeddings_schema", {})
embeddings_schema: dict[str, VectorStoreSchemaConfig] = vector_store_config.get(
"embeddings_schema", {}
)
single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
if embeddings_schema is not None and embedding_name is not None and embedding_name in embeddings_schema:
if (
embeddings_schema is not None
and embedding_name is not None
and embedding_name in embeddings_schema
):
raw_config = embeddings_schema[embedding_name]
if isinstance(raw_config, dict):
single_embedding_config = VectorStoreSchemaConfig(**raw_config)
@ -202,7 +208,9 @@ def _create_vector_store(
single_embedding_config.index_name = index_name
vector_store = VectorStoreFactory().create_vector_store(
vector_store_schema_config=single_embedding_config, vector_store_type=vector_store_type, kwargs=vector_store_config
vector_store_schema_config=single_embedding_config,
vector_store_type=vector_store_type,
kwargs=vector_store_config,
)
vector_store.connect(**vector_store_config)

View File

@ -108,10 +108,16 @@ def get_embedding_store(
store.get("container_name", "default"), embedding_name
)
embeddings_schema: dict[str, VectorStoreSchemaConfig] = store.get("embeddings_schema", {})
embeddings_schema: dict[str, VectorStoreSchemaConfig] = store.get(
"embeddings_schema", {}
)
single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
if embeddings_schema is not None and embedding_name is not None and embedding_name in embeddings_schema:
if (
embeddings_schema is not None
and embedding_name is not None
and embedding_name in embeddings_schema
):
raw_config = embeddings_schema[embedding_name]
if isinstance(raw_config, dict):
single_embedding_config = VectorStoreSchemaConfig(**raw_config)

View File

@ -38,8 +38,12 @@ class AzureAISearchVectorStore(BaseVectorStore):
index_client: SearchIndexClient
def __init__(self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any) -> None:
super().__init__(vector_store_schema_config=vector_store_schema_config, **kwargs)
def __init__(
self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any
) -> None:
super().__init__(
vector_store_schema_config=vector_store_schema_config, **kwargs
)
def connect(self, **kwargs: Any) -> Any:
"""Connect to AI search vector storage."""
@ -77,8 +81,11 @@ class AzureAISearchVectorStore(BaseVectorStore):
) -> None:
"""Load documents into an Azure AI Search index."""
if overwrite:
if self.index_name != "" and self.index_name in self.index_client.list_index_names():
self.index_client.delete_index(self.index_name)
if (
self.index_name != ""
and self.index_name in self.index_client.list_index_names()
):
self.index_client.delete_index(self.index_name)
# Configure vector search profile
vector_search = VectorSearch(
@ -114,8 +121,12 @@ class AzureAISearchVectorStore(BaseVectorStore):
vector_search_dimensions=self.vector_size,
vector_search_profile_name=self.vector_search_profile_name,
),
SearchableField(name=self.text_field, type=SearchFieldDataType.String),
SimpleField(name=self.attributes_field, type=SearchFieldDataType.String,
SearchableField(
name=self.text_field, type=SearchFieldDataType.String
),
SimpleField(
name=self.attributes_field,
type=SearchFieldDataType.String,
),
],
vector_search=vector_search,

View File

@ -51,7 +51,7 @@ class BaseVectorStore(ABC):
self.document_collection = document_collection
self.query_filter = query_filter
self.kwargs = kwargs
self.index_name = vector_store_schema_config.index_name
self.id_field = vector_store_schema_config.id_field
self.text_field = vector_store_schema_config.text_field
@ -59,7 +59,6 @@ class BaseVectorStore(ABC):
self.attributes_field = vector_store_schema_config.attributes_field
self.vector_size = vector_store_schema_config.vector_size
@abstractmethod
def connect(self, **kwargs: Any) -> None:
"""Connect to vector storage."""

View File

@ -27,8 +27,12 @@ class CosmosDBVectorStore(BaseVectorStore):
_database_client: DatabaseProxy
_container_client: ContainerProxy
def __init__(self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any) -> None:
super().__init__(vector_store_schema_config=vector_store_schema_config, **kwargs)
def __init__(
self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any
) -> None:
super().__init__(
vector_store_schema_config=vector_store_schema_config, **kwargs
)
def connect(self, **kwargs: Any) -> Any:
"""Connect to CosmosDB vector storage."""
@ -98,13 +102,18 @@ class CosmosDBVectorStore(BaseVectorStore):
"indexingMode": "consistent",
"automatic": True,
"includedPaths": [{"path": "/*"}],
"excludedPaths": [{"path": "/_etag/?"}, {"path": f"/{self.vector_field}/*"}],
"excludedPaths": [
{"path": "/_etag/?"},
{"path": f"/{self.vector_field}/*"},
],
}
# Currently, the CosmosDB emulator does not support the diskANN policy.
try:
# First try with the standard diskANN policy
indexing_policy["vectorIndexes"] = [{"path": f"/{self.vector_field}", "type": "diskANN"}]
indexing_policy["vectorIndexes"] = [
{"path": f"/{self.vector_field}", "type": "diskANN"}
]
# Create the container and container client
self._database_client.create_container_if_not_exists(
@ -247,7 +256,9 @@ class CosmosDBVectorStore(BaseVectorStore):
id_filter = ", ".join([f"'{id}'" for id in include_ids])
else:
id_filter = ", ".join([str(id) for id in include_ids])
self.query_filter = f"SELECT * FROM c WHERE c.{self.id_field} IN ({id_filter})" # noqa: S608
self.query_filter = (
f"SELECT * FROM c WHERE c.{self.id_field} IN ({id_filter})" # noqa: S608
)
return self.query_filter
def search_by_id(self, id: str) -> VectorStoreDocument:

View File

@ -50,7 +50,10 @@ class VectorStoreFactory:
@classmethod
def create_vector_store(
cls, vector_store_type: str, vector_store_schema_config: VectorStoreSchemaConfig, kwargs: dict
cls,
vector_store_type: str,
vector_store_schema_config: VectorStoreSchemaConfig,
kwargs: dict,
) -> BaseVectorStore:
"""Create a vector store object from the provided type.
@ -71,10 +74,9 @@ class VectorStoreFactory:
raise ValueError(msg)
return cls._registry[vector_store_type](
vector_store_schema_config=vector_store_schema_config,
**kwargs
vector_store_schema_config=vector_store_schema_config, **kwargs
)
@classmethod
def get_vector_store_types(cls) -> list[str]:
"""Get the registered vector store implementations."""

View File

@ -21,21 +21,19 @@ import lancedb
class LanceDBVectorStore(BaseVectorStore):
"""LanceDB vector storage implementation."""
def __init__(self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any) -> None:
super().__init__(vector_store_schema_config=vector_store_schema_config, **kwargs)
def __init__(
self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any
) -> None:
super().__init__(
vector_store_schema_config=vector_store_schema_config, **kwargs
)
def connect(self, **kwargs: Any) -> Any:
"""Connect to the vector storage."""
self.db_connection = lancedb.connect(kwargs["db_uri"])
if (
self.index_name
and self.index_name in self.db_connection.table_names()
):
self.document_collection = self.db_connection.open_table(
self.index_name
)
if self.index_name and self.index_name in self.db_connection.table_names():
self.document_collection = self.db_connection.open_table(self.index_name)
def load_documents(
self, documents: list[VectorStoreDocument], overwrite: bool = True
@ -61,14 +59,16 @@ class LanceDBVectorStore(BaseVectorStore):
# Step 3: Flatten the vectors and build FixedSizeListArray manually
flat_vector = np.concatenate(vectors).astype(np.float32)
flat_array = pa.array(flat_vector, type=pa.float32())
vector_column = pa.FixedSizeListArray.from_arrays(flat_array, self.vector_size)
vector_column = pa.FixedSizeListArray.from_arrays(
flat_array, self.vector_size
)
# Step 4: Create PyArrow table (let schema be inferred)
data = pa.table({
self.id_field: pa.array(ids, type=pa.string()),
self.text_field: pa.array(texts, type=pa.string()),
self.vector_field: vector_column,
self.attributes_field: pa.array(attributes, type=pa.string())
self.attributes_field: pa.array(attributes, type=pa.string()),
})
# NOTE: If modifying the next section of code, ensure that the schema remains the same.
@ -83,12 +83,12 @@ class LanceDBVectorStore(BaseVectorStore):
self.document_collection = self.db_connection.create_table(
self.index_name, mode="overwrite"
)
self.document_collection.create_index(vector_column_name=self.vector_field, index_type="IVF_FLAT")
self.document_collection.create_index(
vector_column_name=self.vector_field, index_type="IVF_FLAT"
)
else:
# add data to existing table
self.document_collection = self.db_connection.open_table(
self.index_name
)
self.document_collection = self.db_connection.open_table(self.index_name)
if data:
self.document_collection.add(data)