mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 09:07:20 +08:00
cosmosdb implementation
This commit is contained in:
parent
9b14b9da36
commit
cffb03a776
@ -3,10 +3,18 @@
|
||||
|
||||
"""Parameterization settings for the default configuration."""
|
||||
|
||||
import re
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
DEFAULT_VECTOR_SIZE: int = 1536
|
||||
|
||||
VALID_IDENTIFIER_REGEX = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||
|
||||
def is_valid_field_name(field: str) -> bool:
|
||||
"""Check if a field name is valid for CosmosDB."""
|
||||
return bool(VALID_IDENTIFIER_REGEX.match(field))
|
||||
|
||||
class VectorStoreSchemaConfig(BaseModel):
|
||||
"""The default configuration section for Vector Store Schema."""
|
||||
|
||||
@ -40,9 +48,17 @@ class VectorStoreSchemaConfig(BaseModel):
|
||||
default=DEFAULT_VECTOR_SIZE,
|
||||
)
|
||||
|
||||
#TODO GAUDY
|
||||
def _validate_schema(self) -> None:
|
||||
"""Validate the schema."""
|
||||
for field in [
|
||||
self.id_field,
|
||||
self.vector_field,
|
||||
self.text_field,
|
||||
self.attributes_field,
|
||||
]:
|
||||
if not is_valid_field_name(field):
|
||||
msg = f"Unsafe or invalid field name: {field}"
|
||||
raise ValueError(msg)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_model(self):
|
||||
|
||||
@ -11,6 +11,7 @@ from azure.cosmos.exceptions import CosmosHttpResponseError
|
||||
from azure.cosmos.partition_key import PartitionKey
|
||||
from azure.identity import DefaultAzureCredential
|
||||
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
from graphrag.data_model.types import TextEmbedder
|
||||
from graphrag.vector_stores.base import (
|
||||
BaseVectorStore,
|
||||
@ -26,8 +27,8 @@ class CosmosDBVectorStore(BaseVectorStore):
|
||||
_database_client: DatabaseProxy
|
||||
_container_client: ContainerProxy
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
def __init__(self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any) -> None:
|
||||
super().__init__(vector_store_schema_config=vector_store_schema_config, **kwargs)
|
||||
|
||||
def connect(self, **kwargs: Any) -> Any:
|
||||
"""Connect to CosmosDB vector storage."""
|
||||
@ -48,13 +49,12 @@ class CosmosDBVectorStore(BaseVectorStore):
|
||||
msg = "Database name must be provided."
|
||||
raise ValueError(msg)
|
||||
self._database_name = database_name
|
||||
collection_name = self.index_name
|
||||
if collection_name is None:
|
||||
msg = "Collection name is empty or not provided."
|
||||
if self.index_name is None:
|
||||
msg = "Index name is empty or not provided."
|
||||
raise ValueError(msg)
|
||||
self._container_name = collection_name
|
||||
self._container_name = self.index_name
|
||||
|
||||
self.vector_size = kwargs.get("vector_size", 1024) #TODO GAUDY fix it
|
||||
self.vector_size = self.vector_size
|
||||
self._create_database()
|
||||
self._create_container()
|
||||
|
||||
@ -85,7 +85,7 @@ class CosmosDBVectorStore(BaseVectorStore):
|
||||
vector_embedding_policy = {
|
||||
"vectorEmbeddings": [
|
||||
{
|
||||
"path": "/vector",
|
||||
"path": f"/{self.vector_field}",
|
||||
"dataType": "float32",
|
||||
"distanceFunction": "cosine",
|
||||
"dimensions": self.vector_size,
|
||||
@ -98,13 +98,13 @@ class CosmosDBVectorStore(BaseVectorStore):
|
||||
"indexingMode": "consistent",
|
||||
"automatic": True,
|
||||
"includedPaths": [{"path": "/*"}],
|
||||
"excludedPaths": [{"path": "/_etag/?"}, {"path": "/vector/*"}],
|
||||
"excludedPaths": [{"path": "/_etag/?"}, {"path": f"/{self.vector_field}/*"}],
|
||||
}
|
||||
|
||||
# Currently, the CosmosDB emulator does not support the diskANN policy.
|
||||
try:
|
||||
# First try with the standard diskANN policy
|
||||
indexing_policy["vectorIndexes"] = [{"path": "/vector", "type": "diskANN"}]
|
||||
indexing_policy["vectorIndexes"] = [{"path": f"/{self.vector_field}", "type": "diskANN"}]
|
||||
|
||||
# Create the container and container client
|
||||
self._database_client.create_container_if_not_exists(
|
||||
@ -158,10 +158,10 @@ class CosmosDBVectorStore(BaseVectorStore):
|
||||
for doc in documents:
|
||||
if doc.vector is not None:
|
||||
doc_json = {
|
||||
"id": doc.id,
|
||||
"vector": doc.vector,
|
||||
"text": doc.text,
|
||||
"attributes": json.dumps(doc.attributes),
|
||||
self.id_field: doc.id,
|
||||
self.vector_field: doc.vector,
|
||||
self.text_field: doc.text,
|
||||
self.attributes_field: json.dumps(doc.attributes),
|
||||
}
|
||||
self._container_client.upsert_item(doc_json)
|
||||
|
||||
@ -174,7 +174,7 @@ class CosmosDBVectorStore(BaseVectorStore):
|
||||
raise ValueError(msg)
|
||||
|
||||
try:
|
||||
query = f"SELECT TOP {k} c.id, c.text, c.vector, c.attributes, VectorDistance(c.vector, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.vector, @embedding)" # noqa: S608
|
||||
query = f"SELECT TOP {k} c.{self.id_field}, c.{self.text_field}, c.{self.vector_field}, c.{self.attributes_field}, VectorDistance(c.{self.vector_field}, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.{self.vector_field}, @embedding)" # noqa: S608
|
||||
query_params = [{"name": "@embedding", "value": query_embedding}]
|
||||
items = list(
|
||||
self._container_client.query_items(
|
||||
@ -186,7 +186,7 @@ class CosmosDBVectorStore(BaseVectorStore):
|
||||
except (CosmosHttpResponseError, ValueError):
|
||||
# Currently, the CosmosDB emulator does not support the VectorDistance function.
|
||||
# For emulator or test environments - fetch all items and calculate distance locally
|
||||
query = "SELECT c.id, c.text, c.vector, c.attributes FROM c"
|
||||
query = f"SELECT c.{self.id_field}, c.{self.text_field}, c.{self.vector_field}, c.{self.attributes_field} FROM c" # noqa: S608
|
||||
items = list(
|
||||
self._container_client.query_items(
|
||||
query=query,
|
||||
@ -205,7 +205,7 @@ class CosmosDBVectorStore(BaseVectorStore):
|
||||
|
||||
# Calculate scores for all items
|
||||
for item in items:
|
||||
item_vector = item.get("vector", [])
|
||||
item_vector = item.get(self.vector_field, [])
|
||||
similarity = cosine_similarity(query_embedding, item_vector)
|
||||
item["SimilarityScore"] = similarity
|
||||
|
||||
@ -217,10 +217,10 @@ class CosmosDBVectorStore(BaseVectorStore):
|
||||
return [
|
||||
VectorStoreSearchResult(
|
||||
document=VectorStoreDocument(
|
||||
id=item.get("id", ""),
|
||||
text=item.get("text", ""),
|
||||
vector=item.get("vector", []),
|
||||
attributes=(json.loads(item.get("attributes", "{}"))),
|
||||
id=item.get(self.id_field, ""),
|
||||
text=item.get(self.text_field, ""),
|
||||
vector=item.get(self.vector_field, []),
|
||||
attributes=(json.loads(item.get(self.attributes_field, "{}"))),
|
||||
),
|
||||
score=item.get("SimilarityScore", 0.0),
|
||||
)
|
||||
@ -247,7 +247,7 @@ class CosmosDBVectorStore(BaseVectorStore):
|
||||
id_filter = ", ".join([f"'{id}'" for id in include_ids])
|
||||
else:
|
||||
id_filter = ", ".join([str(id) for id in include_ids])
|
||||
self.query_filter = f"SELECT * FROM c WHERE c.id IN ({id_filter})" # noqa: S608
|
||||
self.query_filter = f"SELECT * FROM c WHERE c.{self.id_field} IN ({id_filter})" # noqa: S608
|
||||
return self.query_filter
|
||||
|
||||
def search_by_id(self, id: str) -> VectorStoreDocument:
|
||||
@ -258,10 +258,10 @@ class CosmosDBVectorStore(BaseVectorStore):
|
||||
|
||||
item = self._container_client.read_item(item=id, partition_key=id)
|
||||
return VectorStoreDocument(
|
||||
id=item.get("id", ""),
|
||||
vector=item.get("vector", []),
|
||||
text=item.get("text", ""),
|
||||
attributes=(json.loads(item.get("attributes", "{}"))),
|
||||
id=item.get(self.id_field, ""),
|
||||
vector=item.get(self.vector_field, []),
|
||||
text=item.get(self.text_field, ""),
|
||||
attributes=(json.loads(item.get(self.attributes_field, "{}"))),
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user