graphrag/graphrag/vector_stores/cosmosdb.py
2025-09-11 13:42:49 -06:00

271 lines
10 KiB
Python

# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A package containing the CosmosDB vector store implementation."""
import json
from typing import Any
from azure.cosmos import ContainerProxy, CosmosClient, DatabaseProxy
from azure.cosmos.exceptions import CosmosHttpResponseError
from azure.cosmos.partition_key import PartitionKey
from azure.identity import DefaultAzureCredential
from graphrag.data_model.types import TextEmbedder
from graphrag.vector_stores.base import (
BaseVectorStore,
VectorStoreDocument,
VectorStoreSearchResult,
)
class CosmosDBVectorStore(BaseVectorStore):
"""Azure CosmosDB vector storage implementation."""
_cosmos_client: CosmosClient
_database_client: DatabaseProxy
_container_client: ContainerProxy
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
def connect(self, **kwargs: Any) -> Any:
"""Connect to CosmosDB vector storage."""
connection_string = kwargs.get("connection_string")
if connection_string:
self._cosmos_client = CosmosClient.from_connection_string(connection_string)
else:
url = kwargs.get("url")
if not url:
msg = "Either connection_string or url must be provided."
raise ValueError(msg)
self._cosmos_client = CosmosClient(
url=url, credential=DefaultAzureCredential()
)
database_name = kwargs.get("database_name")
if database_name is None:
msg = "Database name must be provided."
raise ValueError(msg)
self._database_name = database_name
collection_name = self.index_name
if collection_name is None:
msg = "Collection name is empty or not provided."
raise ValueError(msg)
self._container_name = collection_name
self.vector_size = kwargs.get("vector_size", 1024) #TODO GAUDY fix it
self._create_database()
self._create_container()
def _create_database(self) -> None:
"""Create the database if it doesn't exist."""
self._cosmos_client.create_database_if_not_exists(id=self._database_name)
self._database_client = self._cosmos_client.get_database_client(
self._database_name
)
def _delete_database(self) -> None:
"""Delete the database if it exists."""
if self._database_exists():
self._cosmos_client.delete_database(self._database_name)
def _database_exists(self) -> bool:
"""Check if the database exists."""
existing_database_names = [
database["id"] for database in self._cosmos_client.list_databases()
]
return self._database_name in existing_database_names
def _create_container(self) -> None:
"""Create the container if it doesn't exist."""
partition_key = PartitionKey(path="/id", kind="Hash")
# Define the container vector policy
vector_embedding_policy = {
"vectorEmbeddings": [
{
"path": "/vector",
"dataType": "float32",
"distanceFunction": "cosine",
"dimensions": self.vector_size,
}
]
}
# Define the vector indexing policy
indexing_policy = {
"indexingMode": "consistent",
"automatic": True,
"includedPaths": [{"path": "/*"}],
"excludedPaths": [{"path": "/_etag/?"}, {"path": "/vector/*"}],
}
# Currently, the CosmosDB emulator does not support the diskANN policy.
try:
# First try with the standard diskANN policy
indexing_policy["vectorIndexes"] = [{"path": "/vector", "type": "diskANN"}]
# Create the container and container client
self._database_client.create_container_if_not_exists(
id=self._container_name,
partition_key=partition_key,
indexing_policy=indexing_policy,
vector_embedding_policy=vector_embedding_policy,
)
except CosmosHttpResponseError:
# If diskANN fails (likely in emulator), retry without vector indexes
indexing_policy.pop("vectorIndexes", None)
# Create the container with compatible indexing policy
self._database_client.create_container_if_not_exists(
id=self._container_name,
partition_key=partition_key,
indexing_policy=indexing_policy,
vector_embedding_policy=vector_embedding_policy,
)
self._container_client = self._database_client.get_container_client(
self._container_name
)
def _delete_container(self) -> None:
"""Delete the vector store container in the database if it exists."""
if self._container_exists():
self._database_client.delete_container(self._container_name)
def _container_exists(self) -> bool:
"""Check if the container name exists in the database."""
existing_container_names = [
container["id"] for container in self._database_client.list_containers()
]
return self._container_name in existing_container_names
def load_documents(
self, documents: list[VectorStoreDocument], overwrite: bool = True
) -> None:
"""Load documents into CosmosDB."""
# Create a CosmosDB container on overwrite
if overwrite:
self._delete_container()
self._create_container()
if self._container_client is None:
msg = "Container client is not initialized."
raise ValueError(msg)
# Upload documents to CosmosDB
for doc in documents:
if doc.vector is not None:
doc_json = {
"id": doc.id,
"vector": doc.vector,
"text": doc.text,
"attributes": json.dumps(doc.attributes),
}
self._container_client.upsert_item(doc_json)
def similarity_search_by_vector(
self, query_embedding: list[float], k: int = 10, **kwargs: Any
) -> list[VectorStoreSearchResult]:
"""Perform a vector-based similarity search."""
if self._container_client is None:
msg = "Container client is not initialized."
raise ValueError(msg)
try:
query = f"SELECT TOP {k} c.id, c.text, c.vector, c.attributes, VectorDistance(c.vector, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.vector, @embedding)" # noqa: S608
query_params = [{"name": "@embedding", "value": query_embedding}]
items = list(
self._container_client.query_items(
query=query,
parameters=query_params,
enable_cross_partition_query=True,
)
)
except (CosmosHttpResponseError, ValueError):
# Currently, the CosmosDB emulator does not support the VectorDistance function.
# For emulator or test environments - fetch all items and calculate distance locally
query = "SELECT c.id, c.text, c.vector, c.attributes FROM c"
items = list(
self._container_client.query_items(
query=query,
enable_cross_partition_query=True,
)
)
# Calculate cosine similarity locally (1 - cosine distance)
from numpy import dot
from numpy.linalg import norm
def cosine_similarity(a, b):
if norm(a) * norm(b) == 0:
return 0.0
return dot(a, b) / (norm(a) * norm(b))
# Calculate scores for all items
for item in items:
item_vector = item.get("vector", [])
similarity = cosine_similarity(query_embedding, item_vector)
item["SimilarityScore"] = similarity
# Sort by similarity score (higher is better) and take top k
items = sorted(
items, key=lambda x: x.get("SimilarityScore", 0.0), reverse=True
)[:k]
return [
VectorStoreSearchResult(
document=VectorStoreDocument(
id=item.get("id", ""),
text=item.get("text", ""),
vector=item.get("vector", []),
attributes=(json.loads(item.get("attributes", "{}"))),
),
score=item.get("SimilarityScore", 0.0),
)
for item in items
]
def similarity_search_by_text(
self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
) -> list[VectorStoreSearchResult]:
"""Perform a text-based similarity search."""
query_embedding = text_embedder(text)
if query_embedding:
return self.similarity_search_by_vector(
query_embedding=query_embedding, k=k
)
return []
def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
"""Build a query filter to filter documents by a list of ids."""
if include_ids is None or len(include_ids) == 0:
self.query_filter = None
else:
if isinstance(include_ids[0], str):
id_filter = ", ".join([f"'{id}'" for id in include_ids])
else:
id_filter = ", ".join([str(id) for id in include_ids])
self.query_filter = f"SELECT * FROM c WHERE c.id IN ({id_filter})" # noqa: S608
return self.query_filter
def search_by_id(self, id: str) -> VectorStoreDocument:
"""Search for a document by id."""
if self._container_client is None:
msg = "Container client is not initialized."
raise ValueError(msg)
item = self._container_client.read_item(item=id, partition_key=id)
return VectorStoreDocument(
id=item.get("id", ""),
vector=item.get("vector", []),
text=item.get("text", ""),
attributes=(json.loads(item.get("attributes", "{}"))),
)
def clear(self) -> None:
"""Clear the vector store."""
self._delete_container()
self._delete_database()