cosmosdb implementation

This commit is contained in:
Gaudy Blanco 2025-09-18 13:07:59 -06:00
parent 9b14b9da36
commit cffb03a776
2 changed files with 43 additions and 27 deletions

View File

@ -3,10 +3,18 @@
"""Parameterization settings for the default configuration."""
import re
from pydantic import BaseModel, Field, model_validator
DEFAULT_VECTOR_SIZE: int = 1536
VALID_IDENTIFIER_REGEX = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
def is_valid_field_name(field: str) -> bool:
"""Check if a field name is valid for CosmosDB."""
return bool(VALID_IDENTIFIER_REGEX.match(field))
class VectorStoreSchemaConfig(BaseModel):
"""The default configuration section for Vector Store Schema."""
@ -40,9 +48,17 @@ class VectorStoreSchemaConfig(BaseModel):
default=DEFAULT_VECTOR_SIZE,
)
#TODO GAUDY
def _validate_schema(self) -> None:
"""Validate the schema."""
for field in [
self.id_field,
self.vector_field,
self.text_field,
self.attributes_field,
]:
if not is_valid_field_name(field):
msg = f"Unsafe or invalid field name: {field}"
raise ValueError(msg)
@model_validator(mode="after")
def _validate_model(self):

View File

@ -11,6 +11,7 @@ from azure.cosmos.exceptions import CosmosHttpResponseError
from azure.cosmos.partition_key import PartitionKey
from azure.identity import DefaultAzureCredential
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.data_model.types import TextEmbedder
from graphrag.vector_stores.base import (
BaseVectorStore,
@ -26,8 +27,8 @@ class CosmosDBVectorStore(BaseVectorStore):
_database_client: DatabaseProxy
_container_client: ContainerProxy
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
def __init__(self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any) -> None:
super().__init__(vector_store_schema_config=vector_store_schema_config, **kwargs)
def connect(self, **kwargs: Any) -> Any:
"""Connect to CosmosDB vector storage."""
@ -48,13 +49,12 @@ class CosmosDBVectorStore(BaseVectorStore):
msg = "Database name must be provided."
raise ValueError(msg)
self._database_name = database_name
collection_name = self.index_name
if collection_name is None:
msg = "Collection name is empty or not provided."
if self.index_name is None:
msg = "Index name is empty or not provided."
raise ValueError(msg)
self._container_name = collection_name
self._container_name = self.index_name
self.vector_size = kwargs.get("vector_size", 1024) #TODO GAUDY fix it
self.vector_size = self.vector_size
self._create_database()
self._create_container()
@ -85,7 +85,7 @@ class CosmosDBVectorStore(BaseVectorStore):
vector_embedding_policy = {
"vectorEmbeddings": [
{
"path": "/vector",
"path": f"/{self.vector_field}",
"dataType": "float32",
"distanceFunction": "cosine",
"dimensions": self.vector_size,
@ -98,13 +98,13 @@ class CosmosDBVectorStore(BaseVectorStore):
"indexingMode": "consistent",
"automatic": True,
"includedPaths": [{"path": "/*"}],
"excludedPaths": [{"path": "/_etag/?"}, {"path": "/vector/*"}],
"excludedPaths": [{"path": "/_etag/?"}, {"path": f"/{self.vector_field}/*"}],
}
# Currently, the CosmosDB emulator does not support the diskANN policy.
try:
# First try with the standard diskANN policy
indexing_policy["vectorIndexes"] = [{"path": "/vector", "type": "diskANN"}]
indexing_policy["vectorIndexes"] = [{"path": f"/{self.vector_field}", "type": "diskANN"}]
# Create the container and container client
self._database_client.create_container_if_not_exists(
@ -158,10 +158,10 @@ class CosmosDBVectorStore(BaseVectorStore):
for doc in documents:
if doc.vector is not None:
doc_json = {
"id": doc.id,
"vector": doc.vector,
"text": doc.text,
"attributes": json.dumps(doc.attributes),
self.id_field: doc.id,
self.vector_field: doc.vector,
self.text_field: doc.text,
self.attributes_field: json.dumps(doc.attributes),
}
self._container_client.upsert_item(doc_json)
@ -174,7 +174,7 @@ class CosmosDBVectorStore(BaseVectorStore):
raise ValueError(msg)
try:
query = f"SELECT TOP {k} c.id, c.text, c.vector, c.attributes, VectorDistance(c.vector, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.vector, @embedding)" # noqa: S608
query = f"SELECT TOP {k} c.{self.id_field}, c.{self.text_field}, c.{self.vector_field}, c.{self.attributes_field}, VectorDistance(c.{self.vector_field}, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.{self.vector_field}, @embedding)" # noqa: S608
query_params = [{"name": "@embedding", "value": query_embedding}]
items = list(
self._container_client.query_items(
@ -186,7 +186,7 @@ class CosmosDBVectorStore(BaseVectorStore):
except (CosmosHttpResponseError, ValueError):
# Currently, the CosmosDB emulator does not support the VectorDistance function.
# For emulator or test environments - fetch all items and calculate distance locally
query = "SELECT c.id, c.text, c.vector, c.attributes FROM c"
query = f"SELECT c.{self.id_field}, c.{self.text_field}, c.{self.vector_field}, c.{self.attributes_field} FROM c" # noqa: S608
items = list(
self._container_client.query_items(
query=query,
@ -205,7 +205,7 @@ class CosmosDBVectorStore(BaseVectorStore):
# Calculate scores for all items
for item in items:
item_vector = item.get("vector", [])
item_vector = item.get(self.vector_field, [])
similarity = cosine_similarity(query_embedding, item_vector)
item["SimilarityScore"] = similarity
@ -217,10 +217,10 @@ class CosmosDBVectorStore(BaseVectorStore):
return [
VectorStoreSearchResult(
document=VectorStoreDocument(
id=item.get("id", ""),
text=item.get("text", ""),
vector=item.get("vector", []),
attributes=(json.loads(item.get("attributes", "{}"))),
id=item.get(self.id_field, ""),
text=item.get(self.text_field, ""),
vector=item.get(self.vector_field, []),
attributes=(json.loads(item.get(self.attributes_field, "{}"))),
),
score=item.get("SimilarityScore", 0.0),
)
@ -247,7 +247,7 @@ class CosmosDBVectorStore(BaseVectorStore):
id_filter = ", ".join([f"'{id}'" for id in include_ids])
else:
id_filter = ", ".join([str(id) for id in include_ids])
self.query_filter = f"SELECT * FROM c WHERE c.id IN ({id_filter})" # noqa: S608
self.query_filter = f"SELECT * FROM c WHERE c.{self.id_field} IN ({id_filter})" # noqa: S608
return self.query_filter
def search_by_id(self, id: str) -> VectorStoreDocument:
@ -258,10 +258,10 @@ class CosmosDBVectorStore(BaseVectorStore):
item = self._container_client.read_item(item=id, partition_key=id)
return VectorStoreDocument(
id=item.get("id", ""),
vector=item.get("vector", []),
text=item.get("text", ""),
attributes=(json.loads(item.get("attributes", "{}"))),
id=item.get(self.id_field, ""),
vector=item.get(self.vector_field, []),
text=item.get(self.text_field, ""),
attributes=(json.loads(item.get(self.attributes_field, "{}"))),
)
def clear(self) -> None: