Mirror of https://github.com/microsoft/graphrag.git, synced 2026-01-14 00:57:23 +08:00
Refactor CacheFactory, StorageFactory, and VectorStoreFactory to use consistent registration patterns and add custom vector store documentation (#2006)
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
* Initial plan
* Refactor VectorStoreFactory to use registration functionality like StorageFactory
* Fix linting issues in VectorStoreFactory refactoring
* Remove backward compatibility support from VectorStoreFactory and StorageFactory
* Run ruff check --fix and ruff format, add semversioner file
* ruff formatting fixes
* Fix pytest errors in storage factory tests by updating PipelineStorage interface implementation
* ruff formatting fixes
* update storage factory design
* Refactor CacheFactory to use registration functionality like StorageFactory
* revert copilot changes
* fix copilot changes
* update comments
* Fix failing pytest compatibility for factory tests
* update class instantiation issue
* ruff fixes
* fix pytest
* add default value
* ruff formatting changes
* ruff fixes
* revert minor changes
* cleanup cache factory
* Update CacheFactory tests to match consistent factory pattern
* update pytest thresholds
* adjust threshold levels
* Add custom vector store implementation notebook

  Create a comprehensive notebook demonstrating how to implement and register custom vector stores with GraphRAG as a plug-and-play framework. Includes:
  - Complete implementation of SimpleInMemoryVectorStore
  - Registration with VectorStoreFactory
  - Testing and validation examples
  - Configuration examples for GraphRAG settings
  - Advanced features and best practices
  - Production considerations checklist

  The notebook provides a complete walkthrough for developers to understand and implement their own vector store backends.
* remove sample notebook for now
* update tests
* fix cache pytests
* add pandas-stub to dev dependencies
* disable warning check for well known key
* skip tests when running on ubuntu
* add documentation for custom vector store implementations
* ignore ruff findings in notebooks
* fix merge breakages
* speedup CLI import statements
* remove unnecessary import statements in init file
* Add str type option on storage/cache type
* Fix store name
* Add LoggerFactory
* Fix up logging setup across CLI/API
* Add LoggerFactory test
* Fix err message
* Semver
* Remove enums from factory methods

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>
Co-authored-by: Josh Bradley <joshbradley@microsoft.com>
Co-authored-by: Nathan Evans <github@talkswithnumbers.com>
This commit is contained in:
parent
69ad36e735
commit
2030f94eb4
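At its core, the refactor gives CacheFactory, StorageFactory, VectorStoreFactory, and the new LoggerFactory the same registration surface. A minimal sketch of the shared pattern, using only names that appear in the diffs below (`MyVectorStore` is a hypothetical placeholder):

```python
# Sketch of the registration pattern this PR standardizes. MyVectorStore is a
# placeholder for any class implementing graphrag's BaseVectorStore interface.
from graphrag.vector_stores.base import BaseVectorStore
from graphrag.vector_stores.factory import VectorStoreFactory


class MyVectorStore(BaseVectorStore):
    ...  # implement connect, load_documents, similarity_search_by_vector, etc.


# Register once under a string identifier, then create instances by that name.
VectorStoreFactory.register("my_store", MyVectorStore)
store = VectorStoreFactory.create_vector_store("my_store", {"collection_name": "demo"})
```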
@ -0,0 +1,4 @@
{
    "type": "minor",
    "description": "Add LoggerFactory and clean up related API."
}
@ -0,0 +1,4 @@
{
    "type": "patch",
    "description": "Make cache, storage, and vector_store factories consistent with similar registration support"
}
675
docs/examples_notebooks/custom_vector_store.ipynb
Normal file
@ -0,0 +1,675 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (c) 2024 Microsoft Corporation.\n",
    "# Licensed under the MIT License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Bring-Your-Own Vector Store\n",
    "\n",
    "This notebook demonstrates how to implement a custom vector store and register it for use with GraphRAG.\n",
    "\n",
    "## Overview\n",
    "\n",
    "GraphRAG uses a plug-and-play architecture that allows for easy integration of custom vector stores (beyond what is natively supported) by following a factory design pattern. This allows you to:\n",
    "\n",
    "- **Extend functionality**: Add support for new vector database backends\n",
    "- **Customize behavior**: Implement specialized search logic or data structures\n",
    "- **Integrate existing systems**: Connect GraphRAG to your existing vector database infrastructure\n",
    "\n",
    "### What You'll Learn\n",
    "\n",
    "1. Understanding the `BaseVectorStore` interface\n",
    "2. Implementing a custom vector store class\n",
    "3. Registering your vector store with the `VectorStoreFactory`\n",
    "4. Testing and validating your implementation\n",
    "5. Configuring GraphRAG to use your custom vector store\n",
    "\n",
    "Let's get started!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Import Required Dependencies\n",
    "\n",
    "First, let's import the necessary GraphRAG components and other dependencies we'll need.\n",
    "\n",
    "```bash\n",
    "pip install graphrag\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any\n",
    "\n",
    "import numpy as np\n",
    "import yaml\n",
    "\n",
    "from graphrag.data_model.types import TextEmbedder\n",
    "\n",
    "# GraphRAG vector store components\n",
    "from graphrag.vector_stores.base import (\n",
    "    BaseVectorStore,\n",
    "    VectorStoreDocument,\n",
    "    VectorStoreSearchResult,\n",
    ")\n",
    "from graphrag.vector_stores.factory import VectorStoreFactory"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Understand the BaseVectorStore Interface\n",
    "\n",
    "Before implementing a custom vector store, let's examine the `BaseVectorStore` interface to understand which methods need to be implemented."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's inspect the BaseVectorStore class to understand the required methods\n",
    "import inspect\n",
    "\n",
    "print(\"BaseVectorStore Abstract Methods:\")\n",
    "print(\"=\" * 40)\n",
    "\n",
    "abstract_methods = []\n",
    "for name, method in inspect.getmembers(BaseVectorStore, predicate=inspect.isfunction):\n",
    "    if getattr(method, \"__isabstractmethod__\", False):\n",
    "        signature = inspect.signature(method)\n",
    "        abstract_methods.append(f\"• {name}{signature}\")\n",
    "        print(f\"• {name}{signature}\")\n",
    "\n",
    "print(f\"\\nTotal abstract methods to implement: {len(abstract_methods)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Implement a Custom Vector Store\n",
    "\n",
    "Now let's implement a simple in-memory vector store as an example. This vector store will:\n",
    "\n",
    "- Store documents and vectors in memory using Python data structures\n",
    "- Support all required BaseVectorStore methods\n",
    "\n",
    "**Note**: This is a simplified example for demonstration. Production vector stores would typically use optimized libraries like FAISS, more sophisticated indexing, and persistent storage."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleInMemoryVectorStore(BaseVectorStore):\n",
    "    \"\"\"A simple in-memory vector store implementation for demonstration purposes.\n",
    "\n",
    "    This vector store stores documents and their embeddings in memory and provides\n",
    "    basic similarity search functionality using cosine similarity.\n",
    "\n",
    "    WARNING: This is for demonstration only - not suitable for production use.\n",
    "    For production, consider using optimized vector databases like LanceDB,\n",
    "    Azure AI Search, or other specialized vector stores.\n",
    "    \"\"\"\n",
    "\n",
    "    # Internal storage for documents and vectors\n",
    "    documents: dict[str, VectorStoreDocument]\n",
    "    vectors: dict[str, np.ndarray]\n",
    "    connected: bool\n",
    "\n",
    "    def __init__(self, **kwargs: Any):\n",
    "        \"\"\"Initialize the in-memory vector store.\"\"\"\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "        self.documents: dict[str, VectorStoreDocument] = {}\n",
    "        self.vectors: dict[str, np.ndarray] = {}\n",
    "        self.connected = False\n",
    "\n",
    "        print(\n",
    "            f\"🚀 SimpleInMemoryVectorStore initialized for collection: {self.collection_name}\"\n",
    "        )\n",
    "\n",
    "    def connect(self, **kwargs: Any) -> None:\n",
    "        \"\"\"Connect to the vector storage (no-op for in-memory store).\"\"\"\n",
    "        self.connected = True\n",
    "        print(f\"✅ Connected to in-memory vector store: {self.collection_name}\")\n",
    "\n",
    "    def load_documents(\n",
    "        self, documents: list[VectorStoreDocument], overwrite: bool = True\n",
    "    ) -> None:\n",
    "        \"\"\"Load documents into the vector store.\"\"\"\n",
    "        if not self.connected:\n",
    "            msg = \"Vector store not connected. Call connect() first.\"\n",
    "            raise RuntimeError(msg)\n",
    "\n",
    "        if overwrite:\n",
    "            self.documents.clear()\n",
    "            self.vectors.clear()\n",
    "\n",
    "        loaded_count = 0\n",
    "        for doc in documents:\n",
    "            if doc.vector is not None:\n",
    "                doc_id = str(doc.id)\n",
    "                self.documents[doc_id] = doc\n",
    "                self.vectors[doc_id] = np.array(doc.vector, dtype=np.float32)\n",
    "                loaded_count += 1\n",
    "\n",
    "        print(f\"📚 Loaded {loaded_count} documents into vector store\")\n",
    "\n",
    "    def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:\n",
    "        \"\"\"Calculate cosine similarity between two vectors.\"\"\"\n",
    "        # Normalize vectors\n",
    "        norm1 = np.linalg.norm(vec1)\n",
    "        norm2 = np.linalg.norm(vec2)\n",
    "\n",
    "        if norm1 == 0 or norm2 == 0:\n",
    "            return 0.0\n",
    "\n",
    "        return float(np.dot(vec1, vec2) / (norm1 * norm2))\n",
    "\n",
    "    def similarity_search_by_vector(\n",
    "        self, query_embedding: list[float], k: int = 10, **kwargs: Any\n",
    "    ) -> list[VectorStoreSearchResult]:\n",
    "        \"\"\"Perform similarity search using a query vector.\"\"\"\n",
    "        if not self.connected:\n",
    "            msg = \"Vector store not connected. Call connect() first.\"\n",
    "            raise RuntimeError(msg)\n",
    "\n",
    "        if not self.vectors:\n",
    "            return []\n",
    "\n",
    "        query_vec = np.array(query_embedding, dtype=np.float32)\n",
    "        similarities = []\n",
    "\n",
    "        # Calculate similarity with all stored vectors\n",
    "        for doc_id, stored_vec in self.vectors.items():\n",
    "            similarity = self._cosine_similarity(query_vec, stored_vec)\n",
    "            similarities.append((doc_id, similarity))\n",
    "\n",
    "        # Sort by similarity (descending) and take top k\n",
    "        similarities.sort(key=lambda x: x[1], reverse=True)\n",
    "        top_k = similarities[:k]\n",
    "\n",
    "        # Create search results\n",
    "        results = []\n",
    "        for doc_id, score in top_k:\n",
    "            document = self.documents[doc_id]\n",
    "            result = VectorStoreSearchResult(document=document, score=score)\n",
    "            results.append(result)\n",
    "\n",
    "        return results\n",
    "\n",
    "    def similarity_search_by_text(\n",
    "        self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any\n",
    "    ) -> list[VectorStoreSearchResult]:\n",
    "        \"\"\"Perform similarity search using text (which gets embedded first).\"\"\"\n",
    "        # Embed the text first\n",
    "        query_embedding = text_embedder(text)\n",
    "\n",
    "        # Use vector search with the embedding\n",
    "        return self.similarity_search_by_vector(query_embedding, k, **kwargs)\n",
    "\n",
    "    def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:\n",
    "        \"\"\"Build a query filter to filter documents by id.\n",
    "\n",
    "        For this simple implementation, we return the list of IDs as the filter.\n",
    "        \"\"\"\n",
    "        return [str(id_) for id_ in include_ids]\n",
    "\n",
    "    def search_by_id(self, id: str) -> VectorStoreDocument:\n",
    "        \"\"\"Search for a document by id.\"\"\"\n",
    "        doc_id = str(id)\n",
    "        if doc_id not in self.documents:\n",
    "            msg = f\"Document with id '{id}' not found\"\n",
    "            raise KeyError(msg)\n",
    "\n",
    "        return self.documents[doc_id]\n",
    "\n",
    "    def get_stats(self) -> dict[str, Any]:\n",
    "        \"\"\"Get statistics about the vector store (custom method).\"\"\"\n",
    "        return {\n",
    "            \"collection_name\": self.collection_name,\n",
    "            \"document_count\": len(self.documents),\n",
    "            \"vector_count\": len(self.vectors),\n",
    "            \"connected\": self.connected,\n",
    "            \"vector_dimension\": len(next(iter(self.vectors.values())))\n",
    "            if self.vectors\n",
    "            else 0,\n",
    "        }\n",
    "\n",
    "\n",
    "print(\"✅ SimpleInMemoryVectorStore class defined!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4: Register the Custom Vector Store\n",
    "\n",
    "Now let's register our custom vector store with the `VectorStoreFactory` so it can be used throughout GraphRAG."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register our custom vector store with a unique identifier\n",
    "CUSTOM_VECTOR_STORE_TYPE = \"simple_memory\"\n",
    "\n",
    "# Register the vector store class\n",
    "VectorStoreFactory.register(CUSTOM_VECTOR_STORE_TYPE, SimpleInMemoryVectorStore)\n",
    "\n",
    "print(f\"✅ Registered custom vector store with type: '{CUSTOM_VECTOR_STORE_TYPE}'\")\n",
    "\n",
    "# Verify registration\n",
    "available_types = VectorStoreFactory.get_vector_store_types()\n",
    "print(f\"\\n📋 Available vector store types: {available_types}\")\n",
    "print(\n",
    "    f\"🔍 Is our custom type supported? {VectorStoreFactory.is_supported_type(CUSTOM_VECTOR_STORE_TYPE)}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5: Test the Custom Vector Store\n",
    "\n",
    "Let's create some sample data and test our custom vector store implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create sample documents with mock embeddings\n",
    "def create_mock_embedding(dimension: int = 384) -> list[float]:\n",
    "    \"\"\"Create a random embedding vector for testing.\"\"\"\n",
    "    return np.random.normal(0, 1, dimension).tolist()\n",
    "\n",
    "\n",
    "# Sample documents\n",
    "sample_documents = [\n",
    "    VectorStoreDocument(\n",
    "        id=\"doc_1\",\n",
    "        text=\"GraphRAG is a powerful knowledge graph extraction and reasoning framework.\",\n",
    "        vector=create_mock_embedding(),\n",
    "        attributes={\"category\": \"technology\", \"source\": \"documentation\"},\n",
    "    ),\n",
    "    VectorStoreDocument(\n",
    "        id=\"doc_2\",\n",
    "        text=\"Vector stores enable efficient similarity search over high-dimensional data.\",\n",
    "        vector=create_mock_embedding(),\n",
    "        attributes={\"category\": \"technology\", \"source\": \"research\"},\n",
    "    ),\n",
    "    VectorStoreDocument(\n",
    "        id=\"doc_3\",\n",
    "        text=\"Machine learning models can process and understand natural language text.\",\n",
    "        vector=create_mock_embedding(),\n",
    "        attributes={\"category\": \"AI\", \"source\": \"article\"},\n",
    "    ),\n",
    "    VectorStoreDocument(\n",
    "        id=\"doc_4\",\n",
    "        text=\"Custom implementations allow for specialized behavior and integration.\",\n",
    "        vector=create_mock_embedding(),\n",
    "        attributes={\"category\": \"development\", \"source\": \"tutorial\"},\n",
    "    ),\n",
    "]\n",
    "\n",
    "print(f\"📝 Created {len(sample_documents)} sample documents\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test creating vector store using the factory\n",
    "vector_store_config = {\"collection_name\": \"test_collection\"}\n",
    "\n",
    "# Create vector store instance using factory\n",
    "vector_store = VectorStoreFactory.create_vector_store(\n",
    "    CUSTOM_VECTOR_STORE_TYPE, vector_store_config\n",
    ")\n",
    "\n",
    "print(f\"✅ Created vector store instance: {type(vector_store).__name__}\")\n",
    "print(f\"📊 Initial stats: {vector_store.get_stats()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Connect and load documents\n",
    "vector_store.connect()\n",
    "vector_store.load_documents(sample_documents)\n",
    "\n",
    "print(f\"📊 Updated stats: {vector_store.get_stats()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test similarity search\n",
    "query_vector = create_mock_embedding()  # Random query vector for testing\n",
    "\n",
    "search_results = vector_store.similarity_search_by_vector(\n",
    "    query_vector,\n",
    "    k=3,  # Get top 3 similar documents\n",
    ")\n",
    "\n",
    "print(f\"🔍 Found {len(search_results)} similar documents:\\n\")\n",
    "\n",
    "for i, result in enumerate(search_results, 1):\n",
    "    doc = result.document\n",
    "    print(f\"{i}. ID: {doc.id}\")\n",
    "    print(f\"   Text: {doc.text[:60]}...\")\n",
    "    print(f\"   Similarity Score: {result.score:.4f}\")\n",
    "    print(f\"   Category: {doc.attributes.get('category', 'N/A')}\")\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test search by ID\n",
    "try:\n",
    "    found_doc = vector_store.search_by_id(\"doc_2\")\n",
    "    print(\"✅ Found document by ID:\")\n",
    "    print(f\"   ID: {found_doc.id}\")\n",
    "    print(f\"   Text: {found_doc.text}\")\n",
    "    print(f\"   Attributes: {found_doc.attributes}\")\n",
    "except KeyError as e:\n",
    "    print(f\"❌ Error: {e}\")\n",
    "\n",
    "# Test filter by ID\n",
    "id_filter = vector_store.filter_by_id([\"doc_1\", \"doc_3\"])\n",
    "print(f\"\\n🔧 ID filter result: {id_filter}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 6: Configuration for GraphRAG\n",
    "\n",
    "Now let's see how you would configure GraphRAG to use your custom vector store in a settings file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example GraphRAG yaml settings\n",
    "example_settings = {\n",
    "    \"vector_store\": {\n",
    "        \"default_vector_store\": {\n",
    "            \"type\": CUSTOM_VECTOR_STORE_TYPE,  # \"simple_memory\"\n",
    "            \"collection_name\": \"graphrag_entities\",\n",
    "            # Add any custom parameters your vector store needs\n",
    "            \"custom_parameter\": \"custom_value\",\n",
    "        }\n",
    "    },\n",
    "    # Other GraphRAG configuration...\n",
    "    \"models\": {\n",
    "        \"default_embedding_model\": {\n",
    "            \"type\": \"openai_embedding\",\n",
    "            \"model\": \"text-embedding-3-small\",\n",
    "        }\n",
    "    },\n",
    "}\n",
    "\n",
    "# Convert to YAML format for settings.yml\n",
    "yaml_config = yaml.dump(example_settings, default_flow_style=False, indent=2)\n",
    "\n",
    "print(\"📄 Example settings.yml configuration:\")\n",
    "print(\"=\" * 40)\n",
    "print(yaml_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 7: Integration with GraphRAG Pipeline\n",
    "\n",
    "Here's how your custom vector store would be used in a typical GraphRAG pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example of how GraphRAG would use your custom vector store\n",
    "def simulate_graphrag_pipeline():\n",
    "    \"\"\"Simulate how GraphRAG would use the custom vector store.\"\"\"\n",
    "    print(\"🚀 Simulating GraphRAG pipeline with custom vector store...\\n\")\n",
    "\n",
    "    # 1. GraphRAG creates vector store using factory\n",
    "    config = {\"collection_name\": \"graphrag_entities\", \"similarity_threshold\": 0.3}\n",
    "\n",
    "    store = VectorStoreFactory.create_vector_store(CUSTOM_VECTOR_STORE_TYPE, config)\n",
    "    store.connect()\n",
    "\n",
    "    print(\"✅ Step 1: Vector store created and connected\")\n",
    "\n",
    "    # 2. During indexing, GraphRAG loads extracted entities\n",
    "    entity_documents = [\n",
    "        VectorStoreDocument(\n",
    "            id=f\"entity_{i}\",\n",
    "            text=f\"Entity {i} description: Important concept in the knowledge graph\",\n",
    "            vector=create_mock_embedding(),\n",
    "            attributes={\"type\": \"entity\", \"importance\": i % 3 + 1},\n",
    "        )\n",
    "        for i in range(10)\n",
    "    ]\n",
    "\n",
    "    store.load_documents(entity_documents)\n",
    "    print(f\"✅ Step 2: Loaded {len(entity_documents)} entity documents\")\n",
    "\n",
    "    # 3. During query time, GraphRAG searches for relevant entities\n",
    "    query_embedding = create_mock_embedding()\n",
    "    relevant_entities = store.similarity_search_by_vector(query_embedding, k=5)\n",
    "\n",
    "    print(f\"✅ Step 3: Found {len(relevant_entities)} relevant entities for query\")\n",
    "\n",
    "    # 4. GraphRAG uses these entities for context building\n",
    "    context_entities = [result.document for result in relevant_entities]\n",
    "\n",
    "    print(\"✅ Step 4: Context built using retrieved entities\")\n",
    "    print(f\"📊 Final stats: {store.get_stats()}\")\n",
    "\n",
    "    return context_entities\n",
    "\n",
    "\n",
    "# Run the simulation\n",
    "context = simulate_graphrag_pipeline()\n",
    "print(f\"\\n🎯 Retrieved {len(context)} entities for context building\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 8: Testing and Validation\n",
    "\n",
    "Let's create a comprehensive test suite to ensure our vector store works correctly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_custom_vector_store():\n",
    "    \"\"\"Comprehensive test suite for the custom vector store.\"\"\"\n",
    "    print(\"🧪 Running comprehensive vector store tests...\\n\")\n",
    "\n",
    "    # Test 1: Basic functionality\n",
    "    print(\"Test 1: Basic functionality\")\n",
    "    store = VectorStoreFactory.create_vector_store(\n",
    "        CUSTOM_VECTOR_STORE_TYPE, {\"collection_name\": \"test\"}\n",
    "    )\n",
    "    store.connect()\n",
    "\n",
    "    # Load test documents\n",
    "    test_docs = sample_documents[:2]\n",
    "    store.load_documents(test_docs)\n",
    "\n",
    "    assert len(store.documents) == 2, \"Should have 2 documents\"\n",
    "    assert len(store.vectors) == 2, \"Should have 2 vectors\"\n",
    "    print(\"✅ Basic functionality test passed\")\n",
    "\n",
    "    # Test 2: Search functionality\n",
    "    print(\"\\nTest 2: Search functionality\")\n",
    "    query_vec = create_mock_embedding()\n",
    "    results = store.similarity_search_by_vector(query_vec, k=5)\n",
    "\n",
    "    assert len(results) <= 2, \"Should not return more results than documents\"\n",
    "    assert all(isinstance(r, VectorStoreSearchResult) for r in results), (\n",
    "        \"Should return VectorStoreSearchResult objects\"\n",
    "    )\n",
    "    assert all(-1 <= r.score <= 1 for r in results), (\n",
    "        \"Similarity scores should be between -1 and 1\"\n",
    "    )\n",
    "    print(\"✅ Search functionality test passed\")\n",
    "\n",
    "    # Test 3: Search by ID\n",
    "    print(\"\\nTest 3: Search by ID\")\n",
    "    found_doc = store.search_by_id(\"doc_1\")\n",
    "    assert found_doc.id == \"doc_1\", \"Should find correct document\"\n",
    "\n",
    "    try:\n",
    "        store.search_by_id(\"nonexistent\")\n",
    "        assert False, \"Should raise KeyError for nonexistent ID\"\n",
    "    except KeyError:\n",
    "        pass  # Expected\n",
    "\n",
    "    print(\"✅ Search by ID test passed\")\n",
    "\n",
    "    # Test 4: Filter functionality\n",
    "    print(\"\\nTest 4: Filter functionality\")\n",
    "    filter_result = store.filter_by_id([\"doc_1\", \"doc_2\"])\n",
    "    assert filter_result == [\"doc_1\", \"doc_2\"], \"Should return filtered IDs\"\n",
    "    print(\"✅ Filter functionality test passed\")\n",
    "\n",
    "    # Test 5: Error handling\n",
    "    print(\"\\nTest 5: Error handling\")\n",
    "    disconnected_store = VectorStoreFactory.create_vector_store(\n",
    "        CUSTOM_VECTOR_STORE_TYPE, {\"collection_name\": \"test2\"}\n",
    "    )\n",
    "\n",
    "    try:\n",
    "        disconnected_store.load_documents(test_docs)\n",
    "        assert False, \"Should raise error when not connected\"\n",
    "    except RuntimeError:\n",
    "        pass  # Expected\n",
    "\n",
    "    try:\n",
    "        disconnected_store.similarity_search_by_vector(query_vec)\n",
    "        assert False, \"Should raise error when not connected\"\n",
    "    except RuntimeError:\n",
    "        pass  # Expected\n",
    "\n",
    "    print(\"✅ Error handling test passed\")\n",
    "\n",
    "    print(\"\\n🎉 All tests passed! Your custom vector store is working correctly.\")\n",
    "\n",
    "\n",
    "# Run the tests\n",
    "test_custom_vector_store()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary and Next Steps\n",
    "\n",
    "Congratulations! You've successfully learned how to implement and register a custom vector store with GraphRAG. Here's what you accomplished:\n",
    "\n",
    "### What You Built\n",
    "- ✅ **Custom Vector Store Class**: Implemented `SimpleInMemoryVectorStore` with all required methods\n",
    "- ✅ **Factory Integration**: Registered your vector store with `VectorStoreFactory`\n",
    "- ✅ **Comprehensive Testing**: Validated functionality with a full test suite\n",
    "- ✅ **Configuration Examples**: Learned how to configure GraphRAG to use your vector store\n",
    "\n",
    "### Key Takeaways\n",
    "1. **Interface Compliance**: Always implement all methods from `BaseVectorStore`\n",
    "2. **Factory Pattern**: Use `VectorStoreFactory.register()` to make your vector store available\n",
    "3. **Configuration**: Vector stores are configured in GraphRAG settings files\n",
    "4. **Testing**: Thoroughly test all functionality before deploying\n",
    "\n",
    "### Next Steps\n",
    "Check out the API Overview notebook to learn how to index and query data via the graphrag API.\n",
    "\n",
    "### Resources\n",
    "- [GraphRAG Documentation](https://microsoft.github.io/graphrag/)\n",
    "\n",
    "Happy building! 🚀"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "graphrag-venv (3.10.18)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
@ -1,6 +1,7 @@
# API Notebooks

- [API Overview Notebook](../../examples_notebooks/api_overview.ipynb)
- [Bring-Your-Own Vector Store](../../examples_notebooks/custom_vector_store.ipynb)

# Query Engine Notebooks
@ -32,6 +32,7 @@ async def build_index(
    memory_profile: bool = False,
    callbacks: list[WorkflowCallbacks] | None = None,
    additional_context: dict[str, Any] | None = None,
+    verbose: bool = False,
) -> list[PipelineRunResult]:
    """Run the pipeline with the given configuration.

@ -53,7 +54,7 @@ async def build_index(
    list[PipelineRunResult]
        The list of pipeline run results
    """
-    init_loggers(config=config)
+    init_loggers(config=config, verbose=verbose)

    # Create callbacks for pipeline lifecycle events if provided
    workflow_callbacks = (
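With the new flag, callers can opt into verbose logging directly. A minimal usage sketch, assuming a project already initialized at a hypothetical `./my_project` root and the `load_config` helper shown elsewhere in this diff:

```python
import asyncio
from pathlib import Path

import graphrag.api as api
from graphrag.config.load_config import load_config

# Load a project config, then opt into verbose logging via the new parameter.
config = load_config(Path("./my_project"))  # hypothetical project root
results = asyncio.run(api.build_index(config=config, verbose=True))
```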
@ -67,6 +67,7 @@ async def generate_indexing_prompts(
    min_examples_required: PositiveInt = 2,
    n_subset_max: PositiveInt = 300,
    k: PositiveInt = 15,
+    verbose: bool = False,
) -> tuple[str, str, str]:
    """Generate indexing prompts.

@ -89,7 +90,7 @@ async def generate_indexing_prompts(
    -------
    tuple[str, str, str]: entity extraction prompt, entity summarization prompt, community summarization prompt
    """
-    init_loggers(config=config)
+    init_loggers(config=config, verbose=verbose, filename="prompt-tuning.log")

    # Retrieve documents
    logger.info("Chunking documents...")
@ -93,7 +93,7 @@ async def global_search(
    -------
    TODO: Document the search response type and format.
    """
-    init_loggers(config=config, verbose=verbose)
+    init_loggers(config=config, verbose=verbose, filename="query.log")

    callbacks = callbacks or []
    full_response = ""
@ -156,7 +156,7 @@ def global_search_streaming(
    -------
    TODO: Document the search response type and format.
    """
-    init_loggers(config=config, verbose=verbose)
+    init_loggers(config=config, verbose=verbose, filename="query.log")

    communities_ = read_indexer_communities(communities, community_reports)
    reports = read_indexer_reports(
@ -229,7 +229,7 @@ async def multi_index_global_search(
    -------
    TODO: Document the search response type and format.
    """
-    init_loggers(config=config, verbose=verbose)
+    init_loggers(config=config, verbose=verbose, filename="query.log")

    # Streaming not supported yet
    if streaming:
@ -369,7 +369,7 @@ async def local_search(
    -------
    TODO: Document the search response type and format.
    """
-    init_loggers(config=config, verbose=verbose)
+    init_loggers(config=config, verbose=verbose, filename="query.log")

    callbacks = callbacks or []
    full_response = ""
@ -435,7 +435,7 @@ def local_search_streaming(
    -------
    TODO: Document the search response type and format.
    """
-    init_loggers(config=config, verbose=verbose)
+    init_loggers(config=config, verbose=verbose, filename="query.log")

    vector_store_args = {}
    for index, store in config.vector_store.items():
@ -508,7 +508,7 @@ async def multi_index_local_search(
    -------
    TODO: Document the search response type and format.
    """
-    init_loggers(config=config, verbose=verbose)
+    init_loggers(config=config, verbose=verbose, filename="query.log")

    # Streaming not supported yet
    if streaming:
@ -730,7 +730,7 @@ async def drift_search(
    -------
    TODO: Document the search response type and format.
    """
-    init_loggers(config=config, verbose=verbose)
+    init_loggers(config=config, verbose=verbose, filename="query.log")

    callbacks = callbacks or []
    full_response = ""
@ -792,7 +792,7 @@ def drift_search_streaming(
    -------
    TODO: Document the search response type and format.
    """
-    init_loggers(config=config, verbose=verbose)
+    init_loggers(config=config, verbose=verbose, filename="query.log")

    vector_store_args = {}
    for index, store in config.vector_store.items():
@ -872,7 +872,7 @@ async def multi_index_drift_search(
    -------
    TODO: Document the search response type and format.
    """
-    init_loggers(config=config, verbose=verbose)
+    init_loggers(config=config, verbose=verbose, filename="query.log")

    # Streaming not supported yet
    if streaming:
@ -1065,7 +1065,7 @@ async def basic_search(
    -------
    TODO: Document the search response type and format.
    """
-    init_loggers(config=config, verbose=verbose)
+    init_loggers(config=config, verbose=verbose, filename="query.log")

    callbacks = callbacks or []
    full_response = ""
@ -1111,7 +1111,7 @@ def basic_search_streaming(
    -------
    TODO: Document the search response type and format.
    """
-    init_loggers(config=config, verbose=verbose)
+    init_loggers(config=config, verbose=verbose, filename="query.log")

    vector_store_args = {}
    for index, store in config.vector_store.items():
@ -1119,7 +1119,7 @@ def basic_search_streaming(
    msg = f"Vector Store Args: {redact(vector_store_args)}"
    logger.debug(msg)

-    description_embedding_store = get_embedding_store(
+    embedding_store = get_embedding_store(
        config_args=vector_store_args,
        embedding_name=text_unit_text_embedding,
    )
@ -1130,7 +1130,7 @@ def basic_search_streaming(
    search_engine = get_basic_search_engine(
        config=config,
        text_units=read_indexer_text_units(text_units),
-        text_unit_embeddings=description_embedding_store,
+        text_unit_embeddings=embedding_store,
        system_prompt=prompt,
        callbacks=callbacks,
    )
@ -1164,7 +1164,7 @@ async def multi_index_basic_search(
    -------
    TODO: Document the search response type and format.
    """
-    init_loggers(config=config, verbose=verbose)
+    init_loggers(config=config, verbose=verbose, filename="query.log")

    # Streaming not supported yet
    if streaming:
118
graphrag/cache/factory.py
vendored
@ -1,23 +1,24 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

-"""A module containing create_cache method definition."""
+"""Factory functions for creating a cache."""

from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar

-from graphrag.config.enums import CacheType
-from graphrag.storage.blob_pipeline_storage import create_blob_storage
-from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage
-from graphrag.storage.file_pipeline_storage import FilePipelineStorage
-
-if TYPE_CHECKING:
-    from graphrag.cache.pipeline_cache import PipelineCache
+from graphrag.cache.json_pipeline_cache import JsonPipelineCache
+from graphrag.cache.memory_pipeline_cache import InMemoryCache
+from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
+from graphrag.config.enums import CacheType
+from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
+from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
+from graphrag.storage.file_pipeline_storage import FilePipelineStorage
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+    from graphrag.cache.pipeline_cache import PipelineCache


class CacheFactory:
@ -25,39 +26,80 @@ class CacheFactory:

    Includes a method for users to register a custom cache implementation.

-    Configuration arguments are passed to each cache implementation as kwargs (where possible)
+    Configuration arguments are passed to each cache implementation as kwargs
    for individual enforcement of required/optional arguments.
    """

-    cache_types: ClassVar[dict[str, type]] = {}
+    _registry: ClassVar[dict[str, Callable[..., PipelineCache]]] = {}

    @classmethod
-    def register(cls, cache_type: str, cache: type):
-        """Register a custom cache implementation."""
-        cls.cache_types[cache_type] = cache
+    def register(cls, cache_type: str, creator: Callable[..., PipelineCache]) -> None:
+        """Register a custom cache implementation.
+
+        Args:
+            cache_type: The type identifier for the cache.
+            creator: A class or callable that creates an instance of PipelineCache.
+        """
+        cls._registry[cache_type] = creator

    @classmethod
-    def create_cache(
-        cls, cache_type: CacheType | str | None, root_dir: str, kwargs: dict
-    ) -> PipelineCache:
-        """Create or get a cache from the provided type."""
-        if not cache_type:
-            return NoopPipelineCache()
-        match cache_type:
-            case CacheType.none:
-                return NoopPipelineCache()
-            case CacheType.memory:
-                return InMemoryCache()
-            case CacheType.file:
-                return JsonPipelineCache(
-                    FilePipelineStorage(root_dir=root_dir).child(kwargs["base_dir"])
-                )
-            case CacheType.blob:
-                return JsonPipelineCache(create_blob_storage(**kwargs))
-            case CacheType.cosmosdb:
-                return JsonPipelineCache(create_cosmosdb_storage(**kwargs))
-            case _:
-                if cache_type in cls.cache_types:
-                    return cls.cache_types[cache_type](**kwargs)
-                msg = f"Unknown cache type: {cache_type}"
-                raise ValueError(msg)
+    def create_cache(cls, cache_type: str, kwargs: dict) -> PipelineCache:
+        """Create a cache object from the provided type.
+
+        Args:
+            cache_type: The type of cache to create.
+            kwargs: Additional keyword arguments for the cache constructor.
+
+        Returns
+        -------
+        A PipelineCache instance.
+
+        Raises
+        ------
+        ValueError: If the cache type is not registered.
+        """
+        if cache_type not in cls._registry:
+            msg = f"Unknown cache type: {cache_type}"
+            raise ValueError(msg)
+
+        return cls._registry[cache_type](**kwargs)
+
+    @classmethod
+    def get_cache_types(cls) -> list[str]:
+        """Get the registered cache implementations."""
+        return list(cls._registry.keys())
+
+    @classmethod
+    def is_supported_type(cls, cache_type: str) -> bool:
+        """Check if the given cache type is supported."""
+        return cache_type in cls._registry
+
+
+# --- register built-in cache implementations ---
+def create_file_cache(root_dir: str, base_dir: str, **kwargs) -> PipelineCache:
+    """Create a file-based cache implementation."""
+    # Create storage with base_dir in kwargs since FilePipelineStorage expects it there
+    storage_kwargs = {"base_dir": root_dir, **kwargs}
+    storage = FilePipelineStorage(**storage_kwargs).child(base_dir)
+    return JsonPipelineCache(storage)
+
+
+def create_blob_cache(**kwargs) -> PipelineCache:
+    """Create a blob storage-based cache implementation."""
+    storage = BlobPipelineStorage(**kwargs)
+    return JsonPipelineCache(storage)
+
+
+def create_cosmosdb_cache(**kwargs) -> PipelineCache:
+    """Create a CosmosDB-based cache implementation."""
+    storage = CosmosDBPipelineStorage(**kwargs)
+    return JsonPipelineCache(storage)
+
+
+# --- register built-in cache implementations ---
+CacheFactory.register(CacheType.none.value, NoopPipelineCache)
+CacheFactory.register(CacheType.memory.value, InMemoryCache)
+CacheFactory.register(CacheType.file.value, create_file_cache)
+CacheFactory.register(CacheType.blob.value, create_blob_cache)
+CacheFactory.register(CacheType.cosmosdb.value, create_cosmosdb_cache)
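Under the new scheme, third-party caches register exactly the same way the built-ins do. A minimal sketch using only the API shown above (the `my_memory` key and trivial creator are illustrative only):

```python
from graphrag.cache.factory import CacheFactory
from graphrag.cache.memory_pipeline_cache import InMemoryCache
from graphrag.cache.pipeline_cache import PipelineCache


def create_my_cache(**kwargs) -> PipelineCache:
    # Illustrative creator: any callable returning a PipelineCache works.
    return InMemoryCache()


CacheFactory.register("my_memory", create_my_cache)
cache = CacheFactory.create_cache("my_memory", {})
assert CacheFactory.is_supported_type("my_memory")
```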
@ -11,10 +11,9 @@ from pathlib import Path

import graphrag.api as api
from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
-from graphrag.config.enums import CacheType, IndexingMethod, ReportingType
+from graphrag.config.enums import CacheType, IndexingMethod
from graphrag.config.load_config import load_config
from graphrag.index.validate_config import validate_config_names
-from graphrag.logger.standard_logging import DEFAULT_LOG_FILENAME
from graphrag.utils.cli import redact

# Ignore warnings from numba
@ -123,17 +122,6 @@ def _run_index(
    if not cache:
        config.cache.type = CacheType.none

-    # Log the configuration details
-    if config.reporting.type == ReportingType.file:
-        log_dir = Path(config.root_dir) / config.reporting.base_dir
-        log_path = log_dir / DEFAULT_LOG_FILENAME
-        logger.info("Logging enabled at %s", log_path)
-    else:
-        logger.info(
-            "Logging not enabled for config %s",
-            redact(config.model_dump()),
-        )
-
    if not skip_validation:
        validate_config_names(config)

@ -156,6 +144,7 @@ def _run_index(
            is_update_run=is_update_run,
            memory_profile=memprofile,
            callbacks=[ConsoleWorkflowCallbacks(verbose=verbose)],
+            verbose=verbose,
        )
    )
    encountered_errors = any(
@ -7,9 +7,7 @@ import logging
from pathlib import Path

import graphrag.api as api
-from graphrag.config.enums import ReportingType
from graphrag.config.load_config import load_config
-from graphrag.logger.standard_logging import DEFAULT_LOG_FILENAME
from graphrag.prompt_tune.generator.community_report_summarization import (
    COMMUNITY_SUMMARIZATION_FILENAME,
)
@ -74,21 +72,13 @@ async def prompt_tune(
    from graphrag.logger.standard_logging import init_loggers

    # initialize loggers with config
-    init_loggers(
-        config=graph_config,
-        verbose=verbose,
-    )
+    init_loggers(config=graph_config, verbose=verbose, filename="prompt-tuning.log")

-    # log the configuration details
-    if graph_config.reporting.type == ReportingType.file:
-        log_dir = Path(root_path) / graph_config.reporting.base_dir
-        log_path = log_dir / DEFAULT_LOG_FILENAME
-        logger.info("Logging enabled at %s", log_path)
-    else:
-        logger.info(
-            "Logging not enabled for config %s",
-            redact(graph_config.model_dump()),
-        )
    logger.info("Starting prompt tune.")
    logger.info(
        "Using default configuration: %s",
        redact(graph_config.model_dump()),
    )

    prompts = await api.generate_indexing_prompts(
        config=graph_config,
@ -103,6 +93,7 @@ async def prompt_tune(
        min_examples_required=min_examples_required,
        n_subset_max=n_subset_max,
        k=k,
+        verbose=verbose,
    )

    output_path = output.resolve()
@ -18,11 +18,11 @@ from graphrag.config.enums import (
    NounPhraseExtractorType,
    ReportingType,
    StorageType,
+    VectorStoreType,
)
from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import (
    EN_STOP_WORDS,
)
-from graphrag.vector_stores.factory import VectorStoreType

DEFAULT_OUTPUT_BASE_DIR = "output"
DEFAULT_CHAT_MODEL_ID = "default_chat_model"
@ -59,6 +59,14 @@ class StorageType(str, Enum):
        return f'"{self.value}"'


+class VectorStoreType(str, Enum):
+    """The supported vector store types."""
+
+    LanceDB = "lancedb"
+    AzureAISearch = "azure_ai_search"
+    CosmosDB = "cosmosdb"
+
+
class ReportingType(str, Enum):
    """The reporting configuration type for the pipeline."""
@ -82,7 +82,7 @@ cache:
  base_dir: "{graphrag_config_defaults.cache.base_dir}"

reporting:
-  type: {graphrag_config_defaults.reporting.type.value} # [file, blob, cosmosdb]
+  type: {graphrag_config_defaults.reporting.type.value} # [file, blob]
  base_dir: "{graphrag_config_defaults.reporting.base_dir}"

vector_store:
@ -12,7 +12,7 @@ from graphrag.config.enums import CacheType
class CacheConfig(BaseModel):
    """The default configuration section for Cache."""

-    type: CacheType = Field(
+    type: CacheType | str = Field(
        description="The cache type to use.",
        default=graphrag_config_defaults.cache.type,
    )
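Loosening the field to `CacheType | str` means a settings file can now name a registered custom type directly. A hypothetical snippet (the `my_memory` type assumes a matching `CacheFactory.register("my_memory", ...)` call at startup):

```yaml
cache:
  type: my_memory  # custom type string, registered via CacheFactory.register
  base_dir: cache
```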
@ -11,6 +11,7 @@ from pydantic import BaseModel, Field, model_validator

import graphrag.config.defaults as defs
from graphrag.config.defaults import graphrag_config_defaults
+from graphrag.config.enums import VectorStoreType
from graphrag.config.errors import LanguageModelConfigMissingError
from graphrag.config.models.basic_search_config import BasicSearchConfig
from graphrag.config.models.cache_config import CacheConfig
@ -36,7 +37,6 @@ from graphrag.config.models.summarize_descriptions_config import (
from graphrag.config.models.text_embedding_config import TextEmbeddingConfig
from graphrag.config.models.umap_config import UmapConfig
from graphrag.config.models.vector_store_config import VectorStoreConfig
-from graphrag.vector_stores.factory import VectorStoreType


class GraphRagConfig(BaseModel):
@ -12,7 +12,7 @@ from graphrag.config.enums import ReportingType
class ReportingConfig(BaseModel):
    """The default configuration section for Reporting."""

-    type: ReportingType = Field(
+    type: ReportingType | str = Field(
        description="The reporting type to use.",
        default=graphrag_config_defaults.reporting.type,
    )
@ -14,7 +14,7 @@ from graphrag.config.enums import StorageType
class StorageConfig(BaseModel):
    """The default configuration section for storage."""

-    type: StorageType = Field(
+    type: StorageType | str = Field(
        description="The storage type to use.",
        default=graphrag_config_defaults.storage.type,
    )
@ -39,7 +39,7 @@ class SummarizeDescriptionsConfig(BaseModel):
        self, root_dir: str, model_config: LanguageModelConfig
    ) -> dict:
        """Get the resolved description summarization strategy."""
-        from graphrag.index.operations.summarize_descriptions import (
+        from graphrag.index.operations.summarize_descriptions.summarize_descriptions import (
            SummarizeStrategyType,
        )
@ -39,7 +39,7 @@ class TextEmbeddingConfig(BaseModel):

    def resolved_strategy(self, model_config: LanguageModelConfig) -> dict:
        """Get the resolved text embedding strategy."""
-        from graphrag.index.operations.embed_text import (
+        from graphrag.index.operations.embed_text.embed_text import (
            TextEmbedStrategyType,
        )
@ -6,7 +6,7 @@
from pydantic import BaseModel, Field, model_validator

from graphrag.config.defaults import vector_store_defaults
-from graphrag.vector_stores.factory import VectorStoreType
+from graphrag.config.enums import VectorStoreType


class VectorStoreConfig(BaseModel):
@ -2,10 +2,3 @@
# Licensed under the MIT License

"""The Indexing Engine text embed package root."""
-
-from graphrag.index.operations.embed_text.embed_text import (
-    TextEmbedStrategyType,
-    embed_text,
-)
-
-__all__ = ["TextEmbedStrategyType", "embed_text"]
@ -2,17 +2,3 @@
# Licensed under the MIT License

"""Root package for description summarization."""
-
-from graphrag.index.operations.summarize_descriptions.summarize_descriptions import (
-    summarize_descriptions,
-)
-from graphrag.index.operations.summarize_descriptions.typing import (
-    SummarizationStrategy,
-    SummarizeStrategyType,
-)
-
-__all__ = [
-    "SummarizationStrategy",
-    "SummarizeStrategyType",
-    "summarize_descriptions",
-]
@ -15,7 +15,7 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.operations.extract_graph.extract_graph import (
    extract_graph as extractor,
)
-from graphrag.index.operations.summarize_descriptions import (
+from graphrag.index.operations.summarize_descriptions.summarize_descriptions import (
    summarize_descriptions,
)
from graphrag.index.typing.context import PipelineRunContext
@ -21,7 +21,7 @@ from graphrag.config.embeddings import (
)
from graphrag.config.get_embedding_settings import get_embedding_settings
from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.index.operations.embed_text import embed_text
+from graphrag.index.operations.embed_text.embed_text import embed_text
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import (
@ -7,7 +7,7 @@ from collections.abc import Callable
from typing import Any, ClassVar

from graphrag.config.enums import ModelType
-from graphrag.language_model.protocol import ChatModel, EmbeddingModel
+from graphrag.language_model.protocol.base import ChatModel, EmbeddingModel
from graphrag.language_model.providers.fnllm.models import (
    AzureOpenAIChatFNLLM,
    AzureOpenAIEmbeddingFNLLM,
@ -100,15 +100,16 @@ class ModelFactory:

# --- Register default implementations ---
ModelFactory.register_chat(
-    ModelType.AzureOpenAIChat, lambda **kwargs: AzureOpenAIChatFNLLM(**kwargs)
+    ModelType.AzureOpenAIChat.value, lambda **kwargs: AzureOpenAIChatFNLLM(**kwargs)
)
ModelFactory.register_chat(
-    ModelType.OpenAIChat, lambda **kwargs: OpenAIChatFNLLM(**kwargs)
+    ModelType.OpenAIChat.value, lambda **kwargs: OpenAIChatFNLLM(**kwargs)
)

ModelFactory.register_embedding(
-    ModelType.AzureOpenAIEmbedding, lambda **kwargs: AzureOpenAIEmbeddingFNLLM(**kwargs)
+    ModelType.AzureOpenAIEmbedding.value,
+    lambda **kwargs: AzureOpenAIEmbeddingFNLLM(**kwargs),
)
ModelFactory.register_embedding(
-    ModelType.OpenAIEmbedding, lambda **kwargs: OpenAIEmbeddingFNLLM(**kwargs)
+    ModelType.OpenAIEmbedding.value, lambda **kwargs: OpenAIEmbeddingFNLLM(**kwargs)
)
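Registration keys are now plain strings (`ModelType.*.value`) rather than enum members, so custom providers can register under any identifier without touching the enum. A small illustrative sketch, reusing a built-in provider class from this diff under a hypothetical `my_chat` key:

```python
from graphrag.language_model.factory import ModelFactory
from graphrag.language_model.providers.fnllm.models import OpenAIChatFNLLM

# Register an alias for a built-in chat provider under a custom string key.
ModelFactory.register_chat("my_chat", lambda **kwargs: OpenAIChatFNLLM(**kwargs))
```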
@ -16,7 +16,7 @@ from typing_extensions import Self
from graphrag.language_model.factory import ModelFactory

if TYPE_CHECKING:
-    from graphrag.language_model.protocol import ChatModel, EmbeddingModel
+    from graphrag.language_model.protocol.base import ChatModel, EmbeddingModel


class ModelManager:
@ -2,7 +2,3 @@
# Licensed under the MIT License

"""Base protocol definitions for LLMs."""
-
-from .base import ChatModel, EmbeddingModel
-
-__all__ = ["ChatModel", "EmbeddingModel"]
113
graphrag/logger/factory.py
Normal file
@@ -0,0 +1,113 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Factory functions for creating a logger."""

from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING, ClassVar

from graphrag.config.enums import ReportingType

if TYPE_CHECKING:
    from collections.abc import Callable

LOG_FORMAT = "%(asctime)s.%(msecs)04d - %(levelname)s - %(name)s - %(message)s"
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"


class LoggerFactory:
    """A factory class for logger implementations.

    Includes a method for users to register a custom logger implementation.

    Configuration arguments are passed to each logger implementation as kwargs
    for individual enforcement of required/optional arguments.

    Note that because we rely on the built-in Python logging architecture, this
    factory does not return a logger instance; it creates a handler that routes
    log records to your specified storage location.
    """

    _registry: ClassVar[dict[str, Callable[..., logging.Handler]]] = {}

    @classmethod
    def register(
        cls, reporting_type: str, creator: Callable[..., logging.Handler]
    ) -> None:
        """Register a custom logger implementation.

        Args:
            reporting_type: The type identifier for the logger.
            creator: A class or callable that initializes logging.
        """
        cls._registry[reporting_type] = creator

    @classmethod
    def create_logger(cls, reporting_type: str, kwargs: dict) -> logging.Handler:
        """Create a logger from the provided type.

        Args:
            reporting_type: The type of logger to create.
            kwargs: Additional keyword arguments for the constructor.

        Returns
        -------
            A logging.Handler instance.

        Raises
        ------
            ValueError: If the logger type is not registered.
        """
        if reporting_type not in cls._registry:
            msg = f"Unknown reporting type: {reporting_type}"
            raise ValueError(msg)

        return cls._registry[reporting_type](**kwargs)

    @classmethod
    def get_logger_types(cls) -> list[str]:
        """Get the registered logger implementations."""
        return list(cls._registry.keys())

    @classmethod
    def is_supported_type(cls, reporting_type: str) -> bool:
        """Check if the given logger type is supported."""
        return reporting_type in cls._registry


# --- register built-in logger implementations ---
def create_file_logger(**kwargs) -> logging.Handler:
    """Create a file-based logger."""
    root_dir = kwargs["root_dir"]
    base_dir = kwargs["base_dir"]
    filename = kwargs["filename"]
    log_dir = Path(root_dir) / base_dir
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file_path = log_dir / filename

    handler = logging.FileHandler(str(log_file_path), mode="a")

    formatter = logging.Formatter(fmt=LOG_FORMAT, datefmt=DATE_FORMAT)
    handler.setFormatter(formatter)

    return handler


def create_blob_logger(**kwargs) -> logging.Handler:
    """Create a blob storage-based logger."""
    from graphrag.logger.blob_workflow_logger import BlobWorkflowLogger

    return BlobWorkflowLogger(
        connection_string=kwargs["connection_string"],
        container_name=kwargs["container_name"],
        base_dir=kwargs["base_dir"],
        storage_account_blob_url=kwargs["storage_account_blob_url"],
    )


# --- register built-in implementations ---
LoggerFactory.register(ReportingType.file.value, create_file_logger)
LoggerFactory.register(ReportingType.blob.value, create_blob_logger)
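Because the registry stores any callable returning a logging.Handler, custom sinks plug in alongside the built-ins. A hedged sketch, assuming a hypothetical "console" reporting type that is not part of this PR:

# Illustrative only: register a console handler under a custom "console" key.
import logging

from graphrag.logger.factory import DATE_FORMAT, LOG_FORMAT, LoggerFactory


def create_console_logger(**kwargs) -> logging.Handler:
    handler = logging.StreamHandler()  # defaults to sys.stderr
    handler.setFormatter(logging.Formatter(fmt=LOG_FORMAT, datefmt=DATE_FORMAT))
    return handler


LoggerFactory.register("console", create_console_logger)
handler = LoggerFactory.create_logger("console", {})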
@@ -32,15 +32,19 @@ Notes
 All progress logging now uses this standard logging system for consistency.
 """

-import logging
-from pathlib import Path
+from __future__ import annotations

-from graphrag.config.enums import ReportingType
-from graphrag.config.models.graph_rag_config import GraphRagConfig
+import logging
+from typing import TYPE_CHECKING
+
+from graphrag.logger.factory import (
+    LoggerFactory,
+)
+
+if TYPE_CHECKING:
+    from graphrag.config.models.graph_rag_config import GraphRagConfig

 DEFAULT_LOG_FILENAME = "indexing-engine.log"
 LOG_FORMAT = "%(asctime)s.%(msecs)04d - %(levelname)s - %(name)s - %(message)s"
 DATE_FORMAT = "%Y-%m-%d %H:%M:%S"


 def init_loggers(
@@ -50,26 +54,15 @@ def init_loggers(
 ) -> None:
     """Initialize logging handlers for graphrag based on configuration.

-    This function merges the functionality of configure_logging() and create_pipeline_logger()
-    to provide a unified way to set up logging for the graphrag package.
-
     Parameters
     ----------
     config : GraphRagConfig | None, default=None
         The GraphRAG configuration. If None, defaults to file-based reporting.
-    root_dir : str | None, default=None
-        The root directory for file-based logging.
     verbose : bool, default=False
         Whether to enable verbose (DEBUG) logging.
-    log_file : Optional[Union[str, Path]], default=None
-        Path to a specific log file. If provided, takes precedence over config.
+    filename : Optional[str]
+        Log filename on disk. If unset, will use a default name.
     """
-    # import BlobWorkflowLogger here to avoid circular imports
-    from graphrag.logger.blob_workflow_logger import BlobWorkflowLogger
-
-    # extract reporting config from GraphRagConfig if provided
-    reporting_config = config.reporting
-
     logger = logging.getLogger("graphrag")
     log_level = logging.DEBUG if verbose else logging.INFO
     logger.setLevel(log_level)
@@ -82,27 +75,9 @@ def init_loggers(
         handler.close()
     logger.handlers.clear()

-    # create formatter with custom format
-    formatter = logging.Formatter(fmt=LOG_FORMAT, datefmt=DATE_FORMAT)
+    reporting_config = config.reporting
+    config_dict = reporting_config.model_dump()
+    args = {**config_dict, "root_dir": config.root_dir, "filename": filename}

-    # add more handlers based on configuration
-    handler: logging.Handler
-    match reporting_config.type:
-        case ReportingType.file:
-            # use the config-based file path
-            log_dir = Path(config.root_dir) / (reporting_config.base_dir)
-            log_dir.mkdir(parents=True, exist_ok=True)
-            log_file_path = log_dir / filename
-            handler = logging.FileHandler(str(log_file_path), mode="a")
-            handler.setFormatter(formatter)
-            logger.addHandler(handler)
-        case ReportingType.blob:
-            handler = BlobWorkflowLogger(
-                reporting_config.connection_string,
-                reporting_config.container_name,
-                base_dir=reporting_config.base_dir,
-                storage_account_blob_url=reporting_config.storage_account_blob_url,
-            )
-            logger.addHandler(handler)
-        case _:
-            logger.error("Unknown reporting type '%s'.", reporting_config.type)
+    handler = LoggerFactory.create_logger(reporting_config.type, args)
+    logger.addHandler(handler)
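For reference, a sketch of the kwargs that reach a creator through this dispatch; the values below are illustrative, not taken from the PR:

# Illustrative values only: the shape of the args dict that init_loggers
# builds for file reporting. create_file_logger reads "root_dir", "base_dir",
# and "filename"; extra keys such as "type" are absorbed by **kwargs.
args = {
    "type": "file",          # from ReportingConfig.model_dump()
    "base_dir": "logs",      # from ReportingConfig.model_dump()
    "root_dir": "/project",  # merged in by init_loggers
    "filename": "indexing-engine.log",
}
handler = LoggerFactory.create_logger("file", args)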
@@ -29,15 +29,20 @@ class BlobPipelineStorage(PipelineStorage):
     _encoding: str
     _storage_account_blob_url: str | None

-    def __init__(
-        self,
-        connection_string: str | None,
-        container_name: str,
-        encoding: str = "utf-8",
-        path_prefix: str | None = None,
-        storage_account_blob_url: str | None = None,
-    ):
+    def __init__(self, **kwargs: Any) -> None:
         """Create a new BlobStorage instance."""
+        connection_string = kwargs.get("connection_string")
+        storage_account_blob_url = kwargs.get("storage_account_blob_url")
+        path_prefix = kwargs.get("base_dir")
+        container_name = kwargs["container_name"]
+        if container_name is None:
+            msg = "No container name provided for blob storage."
+            raise ValueError(msg)
+        if connection_string is None and storage_account_blob_url is None:
+            msg = "No storage account blob url provided for blob storage."
+            raise ValueError(msg)
+
         logger.info("Creating blob storage at %s", container_name)
         if connection_string:
             self._blob_service_client = BlobServiceClient.from_connection_string(
                 connection_string
@@ -51,7 +56,7 @@ class BlobPipelineStorage(PipelineStorage):
                 account_url=storage_account_blob_url,
                 credential=DefaultAzureCredential(),
             )
-        self._encoding = encoding
+        self._encoding = kwargs.get("encoding", "utf-8")
         self._container_name = container_name
         self._connection_string = connection_string
         self._path_prefix = path_prefix or ""
@@ -271,11 +276,11 @@ class BlobPipelineStorage(PipelineStorage):
             return self
         path = str(Path(self._path_prefix) / name)
         return BlobPipelineStorage(
-            self._connection_string,
-            self._container_name,
-            self._encoding,
-            path,
-            self._storage_account_blob_url,
+            connection_string=self._connection_string,
+            container_name=self._container_name,
+            encoding=self._encoding,
+            base_dir=path,
+            storage_account_blob_url=self._storage_account_blob_url,
         )

     def keys(self) -> list[str]:
@@ -307,27 +312,6 @@ class BlobPipelineStorage(PipelineStorage):
         return ""


-def create_blob_storage(**kwargs: Any) -> PipelineStorage:
-    """Create a blob based storage."""
-    connection_string = kwargs.get("connection_string")
-    storage_account_blob_url = kwargs.get("storage_account_blob_url")
-    base_dir = kwargs.get("base_dir")
-    container_name = kwargs["container_name"]
-    logger.info("Creating blob storage at %s", container_name)
-    if container_name is None:
-        msg = "No container name provided for blob storage."
-        raise ValueError(msg)
-    if connection_string is None and storage_account_blob_url is None:
-        msg = "No storage account blob url provided for blob storage."
-        raise ValueError(msg)
-    return BlobPipelineStorage(
-        connection_string=connection_string,
-        container_name=container_name,
-        path_prefix=base_dir,
-        storage_account_blob_url=storage_account_blob_url,
-    )
-
-
 def validate_blob_container_name(container_name: str):
     """
     Check if the provided blob container name is valid based on Azure rules.
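With the constructor now kwargs-driven and self-validating, the factory can register the class itself; callers pass the same keys the removed create_blob_storage helper accepted, and the identical pattern applies to CosmosDBPipelineStorage below. A sketch with placeholder values:

# Placeholder values; these are the same keys the removed create_blob_storage
# helper accepted, now validated inside __init__ itself.
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage

storage = BlobPipelineStorage(
    connection_string="<your-connection-string>",
    container_name="my-container",
    base_dir="output",  # stored as the path prefix
)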
@@ -39,15 +39,20 @@ class CosmosDBPipelineStorage(PipelineStorage):
     _encoding: str
     _no_id_prefixes: list[str]

-    def __init__(
-        self,
-        database_name: str,
-        container_name: str,
-        cosmosdb_account_url: str | None = None,
-        connection_string: str | None = None,
-        encoding: str = "utf-8",
-    ):
-        """Initialize the CosmosDB Storage."""
+    def __init__(self, **kwargs: Any) -> None:
+        """Create a CosmosDB storage instance."""
+        logger.info("Creating cosmosdb storage")
+        cosmosdb_account_url = kwargs.get("cosmosdb_account_url")
+        connection_string = kwargs.get("connection_string")
+        database_name = kwargs["base_dir"]
+        container_name = kwargs["container_name"]
+        if not database_name:
+            msg = "No base_dir provided for database name"
+            raise ValueError(msg)
+        if connection_string is None and cosmosdb_account_url is None:
+            msg = "connection_string or cosmosdb_account_url is required."
+            raise ValueError(msg)
+
         if connection_string:
             self._cosmos_client = CosmosClient.from_connection_string(connection_string)
         else:
@@ -60,7 +65,7 @@ class CosmosDBPipelineStorage(PipelineStorage):
                 url=cosmosdb_account_url,
                 credential=DefaultAzureCredential(),
             )
-        self._encoding = encoding
+        self._encoding = kwargs.get("encoding", "utf-8")
         self._database_name = database_name
         self._connection_string = connection_string
         self._cosmosdb_account_url = cosmosdb_account_url
@@ -348,29 +353,6 @@ class CosmosDBPipelineStorage(PipelineStorage):
         return ""


-# TODO remove this helper function and have the factory instantiate the class directly
-# once the new config system is in place and will enforce the correct types/existence of certain fields
-def create_cosmosdb_storage(**kwargs: Any) -> PipelineStorage:
-    """Create a CosmosDB storage instance."""
-    logger.info("Creating cosmosdb storage")
-    cosmosdb_account_url = kwargs.get("cosmosdb_account_url")
-    connection_string = kwargs.get("connection_string")
-    base_dir = kwargs["base_dir"]
-    container_name = kwargs["container_name"]
-    if not base_dir:
-        msg = "No base_dir provided for database name"
-        raise ValueError(msg)
-    if connection_string is None and cosmosdb_account_url is None:
-        msg = "connection_string or cosmosdb_account_url is required."
-        raise ValueError(msg)
-    return CosmosDBPipelineStorage(
-        cosmosdb_account_url=cosmosdb_account_url,
-        connection_string=connection_string,
-        database_name=base_dir,
-        container_name=container_name,
-    )
-
-
 def _create_progress_status(
     num_loaded: int, num_filtered: int, num_total: int
 ) -> Progress:
@@ -5,13 +5,12 @@

 from __future__ import annotations

-from contextlib import suppress
 from typing import TYPE_CHECKING, ClassVar

 from graphrag.config.enums import StorageType
-from graphrag.storage.blob_pipeline_storage import create_blob_storage
-from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage
-from graphrag.storage.file_pipeline_storage import create_file_storage
+from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
+from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
+from graphrag.storage.file_pipeline_storage import FilePipelineStorage
 from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage

 if TYPE_CHECKING:
@@ -29,8 +28,7 @@ class StorageFactory:
     for individual enforcement of required/optional arguments.
     """

-    _storage_registry: ClassVar[dict[str, Callable[..., PipelineStorage]]] = {}
-    storage_types: ClassVar[dict[str, type]] = {}  # For backward compatibility
+    _registry: ClassVar[dict[str, Callable[..., PipelineStorage]]] = {}

     @classmethod
     def register(
@@ -40,23 +38,13 @@ class StorageFactory:

         Args:
             storage_type: The type identifier for the storage.
-            creator: A callable that creates an instance of the storage.
+            creator: A class or callable that creates an instance of PipelineStorage.
         """
-        cls._storage_registry[storage_type] = creator
-
-        # For backward compatibility with code that may access storage_types directly
-        if (
-            callable(creator)
-            and hasattr(creator, "__annotations__")
-            and "return" in creator.__annotations__
-        ):
-            with suppress(TypeError, KeyError):
-                cls.storage_types[storage_type] = creator.__annotations__["return"]
+        cls._registry[storage_type] = creator

     @classmethod
-    def create_storage(
-        cls, storage_type: StorageType | str, kwargs: dict
-    ) -> PipelineStorage:
+    def create_storage(cls, storage_type: str, kwargs: dict) -> PipelineStorage:
         """Create a storage object from the provided type.

         Args:
@@ -71,31 +59,25 @@ class StorageFactory:
         ------
             ValueError: If the storage type is not registered.
         """
-        storage_type_str = (
-            storage_type.value
-            if isinstance(storage_type, StorageType)
-            else storage_type
-        )
-
-        if storage_type_str not in cls._storage_registry:
+        if storage_type not in cls._registry:
             msg = f"Unknown storage type: {storage_type}"
             raise ValueError(msg)

-        return cls._storage_registry[storage_type_str](**kwargs)
+        return cls._registry[storage_type](**kwargs)

     @classmethod
     def get_storage_types(cls) -> list[str]:
         """Get the registered storage implementations."""
-        return list(cls._storage_registry.keys())
+        return list(cls._registry.keys())

     @classmethod
-    def is_supported_storage_type(cls, storage_type: str) -> bool:
+    def is_supported_type(cls, storage_type: str) -> bool:
         """Check if the given storage type is supported."""
-        return storage_type in cls._storage_registry
+        return storage_type in cls._registry


-# --- Register default implementations ---
-StorageFactory.register(StorageType.blob.value, create_blob_storage)
-StorageFactory.register(StorageType.cosmosdb.value, create_cosmosdb_storage)
-StorageFactory.register(StorageType.file.value, create_file_storage)
-StorageFactory.register(StorageType.memory.value, lambda **_: MemoryPipelineStorage())
+# --- register built-in storage implementations ---
+StorageFactory.register(StorageType.blob.value, BlobPipelineStorage)
+StorageFactory.register(StorageType.cosmosdb.value, CosmosDBPipelineStorage)
+StorageFactory.register(StorageType.file.value, FilePipelineStorage)
+StorageFactory.register(StorageType.memory.value, MemoryPipelineStorage)
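Usage stays a one-liner for built-ins, and custom backends register the same way. A sketch under the new API, where MyS3Storage is a hypothetical PipelineStorage implementation:

# Sketch: built-ins are created by string key; a custom backend only needs a
# class or callable that accepts kwargs and returns a PipelineStorage.
from graphrag.storage.factory import StorageFactory

file_storage = StorageFactory.create_storage("file", {"base_dir": "/tmp/out"})

StorageFactory.register("s3", lambda **kwargs: MyS3Storage(**kwargs))  # MyS3Storage is hypothetical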
@@ -1,7 +1,7 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-"""A module containing 'FileStorage' and 'FilePipelineStorage' models."""
+"""File-based Storage implementation of PipelineStorage."""

 import logging
 import os
@@ -30,10 +30,11 @@ class FilePipelineStorage(PipelineStorage):
     _root_dir: str
     _encoding: str

-    def __init__(self, root_dir: str = "", encoding: str = "utf-8"):
-        """Init method definition."""
-        self._root_dir = root_dir
-        self._encoding = encoding
+    def __init__(self, **kwargs: Any) -> None:
+        """Create a file based storage."""
+        self._root_dir = kwargs.get("base_dir", "")
+        self._encoding = kwargs.get("encoding", "utf-8")
+        logger.info("Creating file storage at %s", self._root_dir)
         Path(self._root_dir).mkdir(parents=True, exist_ok=True)

     def find(
@@ -148,7 +149,8 @@ class FilePipelineStorage(PipelineStorage):
         """Create a child storage instance."""
         if name is None:
             return self
-        return FilePipelineStorage(str(Path(self._root_dir) / Path(name)))
+        child_path = str(Path(self._root_dir) / Path(name))
+        return FilePipelineStorage(base_dir=child_path, encoding=self._encoding)

     def keys(self) -> list[str]:
         """Return the keys in the storage."""
@@ -167,10 +169,3 @@ class FilePipelineStorage(PipelineStorage):
 def join_path(file_path: str, file_name: str) -> Path:
     """Join a path and a file. Independent of the OS."""
     return Path(file_path) / Path(file_name).parent / Path(file_name).name
-
-
-def create_file_storage(**kwargs: Any) -> PipelineStorage:
-    """Create a file based storage."""
-    base_dir = kwargs["base_dir"]
-    logger.info("Creating file storage at %s", base_dir)
-    return FilePipelineStorage(root_dir=base_dir)
@@ -250,10 +250,10 @@ def create_storage_from_config(output: StorageConfig) -> PipelineStorage:
 def create_cache_from_config(cache: CacheConfig, root_dir: str) -> PipelineCache:
     """Create a cache object from the config."""
     cache_config = cache.model_dump()
+    kwargs = {**cache_config, "root_dir": root_dir}
     return CacheFactory().create_cache(
         cache_type=cache_config["type"],
-        root_dir=root_dir,
-        kwargs=cache_config,
+        kwargs=kwargs,
     )
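The fix above threads root_dir into the kwargs dict rather than passing it as a separate argument the factory no longer takes. Equivalently, with illustrative values:

# Illustrative values: the kwargs dict now carries root_dir alongside the
# dumped CacheConfig fields.
from graphrag.cache.factory import CacheFactory

kwargs = {"type": "file", "base_dir": "cache", "root_dir": "/project"}
cache = CacheFactory().create_cache(cache_type="file", kwargs=kwargs)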
@@ -1,52 +1,88 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-"""A package containing a factory and supported vector store types."""
+"""Factory functions for creating a vector store."""

-from enum import Enum
-from typing import ClassVar
+from __future__ import annotations

+from typing import TYPE_CHECKING, ClassVar
+
+from graphrag.config.enums import VectorStoreType
 from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
-from graphrag.vector_stores.base import BaseVectorStore
 from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore
 from graphrag.vector_stores.lancedb import LanceDBVectorStore

-
-class VectorStoreType(str, Enum):
-    """The supported vector store types."""
-
-    LanceDB = "lancedb"
-    AzureAISearch = "azure_ai_search"
-    CosmosDB = "cosmosdb"
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+    from graphrag.vector_stores.base import BaseVectorStore


 class VectorStoreFactory:
     """A factory for vector stores.

     Includes a method for users to register a custom vector store implementation.

     Configuration arguments are passed to each vector store implementation as kwargs
     for individual enforcement of required/optional arguments.
     """

-    vector_store_types: ClassVar[dict[str, type]] = {}
+    _registry: ClassVar[dict[str, Callable[..., BaseVectorStore]]] = {}

     @classmethod
-    def register(cls, vector_store_type: str, vector_store: type):
-        """Register a custom vector store implementation."""
-        cls.vector_store_types[vector_store_type] = vector_store
+    def register(
+        cls, vector_store_type: str, creator: Callable[..., BaseVectorStore]
+    ) -> None:
+        """Register a custom vector store implementation.
+
+        Args:
+            vector_store_type: The type identifier for the vector store.
+            creator: A class or callable that creates an instance of BaseVectorStore.
+        """
+        cls._registry[vector_store_type] = creator

     @classmethod
     def create_vector_store(
-        cls, vector_store_type: VectorStoreType | str, kwargs: dict
+        cls, vector_store_type: str, kwargs: dict
     ) -> BaseVectorStore:
-        """Create or get a vector store from the provided type."""
-        match vector_store_type:
-            case VectorStoreType.LanceDB:
-                return LanceDBVectorStore(**kwargs)
-            case VectorStoreType.AzureAISearch:
-                return AzureAISearchVectorStore(**kwargs)
-            case VectorStoreType.CosmosDB:
-                return CosmosDBVectorStore(**kwargs)
-            case _:
-                if vector_store_type in cls.vector_store_types:
-                    return cls.vector_store_types[vector_store_type](**kwargs)
-                msg = f"Unknown vector store type: {vector_store_type}"
-                raise ValueError(msg)
+        """Create a vector store object from the provided type.
+
+        Args:
+            vector_store_type: The type of vector store to create.
+            kwargs: Additional keyword arguments for the vector store constructor.
+
+        Returns
+        -------
+            A BaseVectorStore instance.
+
+        Raises
+        ------
+            ValueError: If the vector store type is not registered.
+        """
+        if vector_store_type not in cls._registry:
+            msg = f"Unknown vector store type: {vector_store_type}"
+            raise ValueError(msg)
+
+        return cls._registry[vector_store_type](**kwargs)
+
+    @classmethod
+    def get_vector_store_types(cls) -> list[str]:
+        """Get the registered vector store implementations."""
+        return list(cls._registry.keys())
+
+    @classmethod
+    def is_supported_type(cls, vector_store_type: str) -> bool:
+        """Check if the given vector store type is supported."""
+        return vector_store_type in cls._registry
+
+
+# --- register built-in vector store implementations ---
+VectorStoreFactory.register(VectorStoreType.LanceDB.value, LanceDBVectorStore)
+VectorStoreFactory.register(
+    VectorStoreType.AzureAISearch.value, AzureAISearchVectorStore
+)
+VectorStoreFactory.register(VectorStoreType.CosmosDB.value, CosmosDBVectorStore)
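Custom vector stores follow the same registry contract as storage and cache. A sketch, where MyVectorStore is a hypothetical BaseVectorStore implementation:

# Sketch: MyVectorStore is hypothetical and must implement BaseVectorStore.
from graphrag.vector_stores.factory import VectorStoreFactory

VectorStoreFactory.register("my_store", lambda **kwargs: MyVectorStore(**kwargs))
store = VectorStoreFactory.create_vector_store(
    "my_store", {"collection_name": "entities"}
)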
@@ -75,6 +75,7 @@ dev = [
     "jupyter>=1.1.1",
     "nbconvert>=7.16.4",
     "poethepoet>=0.31.1",
+    "pandas-stubs>=2.3.0.250703",
     "pyright>=1.1.390",
     "pytest>=8.3.4",
     "pytest-asyncio>=0.24.0",
@@ -242,7 +243,7 @@ ignore = [
 [tool.ruff.lint.per-file-ignores]
 "tests/*" = ["S", "D", "ANN", "T201", "ASYNC", "ARG", "PTH", "TRY"]
 "graphrag/index/config/*" = ["TCH"]
-"*.ipynb" = ["T201"]
+"*.ipynb" = ["T201", "S101", "PT015", "B011"]

 [tool.ruff.lint.flake8-builtins]
 builtins-ignorelist = ["input", "id", "bytes"]
tests/integration/cache/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
tests/integration/cache/test_factory.py (new file, 169 lines)
@@ -0,0 +1,169 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""CacheFactory Tests.

These tests will test the CacheFactory class and the creation of each cache type that is natively supported.
"""

import sys

import pytest

from graphrag.cache.factory import CacheFactory
from graphrag.cache.json_pipeline_cache import JsonPipelineCache
from graphrag.cache.memory_pipeline_cache import InMemoryCache
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.config.enums import CacheType

# cspell:disable-next-line well-known-key
WELL_KNOWN_BLOB_STORAGE_KEY = "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;"
# cspell:disable-next-line well-known-key
WELL_KNOWN_COSMOS_CONNECTION_STRING = "AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="


def test_create_noop_cache():
    kwargs = {}
    cache = CacheFactory.create_cache(CacheType.none.value, kwargs)
    assert isinstance(cache, NoopPipelineCache)


def test_create_memory_cache():
    kwargs = {}
    cache = CacheFactory.create_cache(CacheType.memory.value, kwargs)
    assert isinstance(cache, InMemoryCache)


def test_create_file_cache():
    kwargs = {"root_dir": "/tmp", "base_dir": "testcache"}
    cache = CacheFactory.create_cache(CacheType.file.value, kwargs)
    assert isinstance(cache, JsonPipelineCache)


def test_create_blob_cache():
    kwargs = {
        "connection_string": WELL_KNOWN_BLOB_STORAGE_KEY,
        "container_name": "testcontainer",
        "base_dir": "testcache",
    }
    cache = CacheFactory.create_cache(CacheType.blob.value, kwargs)
    assert isinstance(cache, JsonPipelineCache)


@pytest.mark.skipif(
    not sys.platform.startswith("win"),
    reason="cosmosdb emulator is only available on windows runners at this time",
)
def test_create_cosmosdb_cache():
    kwargs = {
        "connection_string": WELL_KNOWN_COSMOS_CONNECTION_STRING,
        "base_dir": "testdatabase",
        "container_name": "testcontainer",
    }
    cache = CacheFactory.create_cache(CacheType.cosmosdb.value, kwargs)
    assert isinstance(cache, JsonPipelineCache)


def test_register_and_create_custom_cache():
    """Test registering and creating a custom cache type."""
    from unittest.mock import MagicMock

    # Create a mock that satisfies the PipelineCache interface
    custom_cache_class = MagicMock(spec=PipelineCache)
    # Make the mock return a mock instance when instantiated
    instance = MagicMock()
    instance.initialized = True
    custom_cache_class.return_value = instance

    CacheFactory.register("custom", lambda **kwargs: custom_cache_class(**kwargs))
    cache = CacheFactory.create_cache("custom", {})

    assert custom_cache_class.called
    assert cache is instance
    # Access the attribute we set on our mock
    assert cache.initialized is True  # type: ignore # Attribute only exists on our mock

    # Check if it's in the list of registered cache types
    assert "custom" in CacheFactory.get_cache_types()
    assert CacheFactory.is_supported_type("custom")


def test_get_cache_types():
    cache_types = CacheFactory.get_cache_types()
    # Check that built-in types are registered
    assert CacheType.none.value in cache_types
    assert CacheType.memory.value in cache_types
    assert CacheType.file.value in cache_types
    assert CacheType.blob.value in cache_types
    assert CacheType.cosmosdb.value in cache_types


def test_create_unknown_cache():
    with pytest.raises(ValueError, match="Unknown cache type: unknown"):
        CacheFactory.create_cache("unknown", {})


def test_is_supported_type():
    # Test built-in types
    assert CacheFactory.is_supported_type(CacheType.none.value)
    assert CacheFactory.is_supported_type(CacheType.memory.value)
    assert CacheFactory.is_supported_type(CacheType.file.value)
    assert CacheFactory.is_supported_type(CacheType.blob.value)
    assert CacheFactory.is_supported_type(CacheType.cosmosdb.value)

    # Test unknown type
    assert not CacheFactory.is_supported_type("unknown")


def test_enum_and_string_compatibility():
    """Test that both enum and string types work for cache creation."""
    kwargs = {}

    # Test with enum
    cache_enum = CacheFactory.create_cache(CacheType.memory, kwargs)
    assert isinstance(cache_enum, InMemoryCache)

    # Test with string
    cache_str = CacheFactory.create_cache("memory", kwargs)
    assert isinstance(cache_str, InMemoryCache)

    # Both should create the same type
    assert type(cache_enum) is type(cache_str)


def test_register_class_directly_works():
    """Test that registering a class directly works (CacheFactory allows this)."""
    from graphrag.cache.pipeline_cache import PipelineCache

    class CustomCache(PipelineCache):
        def __init__(self, **kwargs):
            pass

        async def get(self, key: str):
            return None

        async def set(self, key: str, value, debug_data=None):
            pass

        async def has(self, key: str):
            return False

        async def delete(self, key: str):
            pass

        async def clear(self):
            pass

        def child(self, name: str):
            return self

    # CacheFactory allows registering classes directly (no TypeError)
    CacheFactory.register("custom_class", CustomCache)

    # Verify it was registered
    assert "custom_class" in CacheFactory.get_cache_types()
    assert CacheFactory.is_supported_type("custom_class")

    # Test creating an instance
    cache = CacheFactory.create_cache("custom_class", {})
    assert isinstance(cache, CustomCache)
tests/integration/logging/test_factory.py (new file, 65 lines)
@@ -0,0 +1,65 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""LoggerFactory Tests.

These tests will test the LoggerFactory class and the creation of each reporting type that is natively supported.
"""

import logging

import pytest

from graphrag.config.enums import ReportingType
from graphrag.logger.blob_workflow_logger import BlobWorkflowLogger
from graphrag.logger.factory import LoggerFactory

# cspell:disable-next-line well-known-key
WELL_KNOWN_BLOB_STORAGE_KEY = "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;"
# cspell:disable-next-line well-known-key
WELL_KNOWN_COSMOS_CONNECTION_STRING = "AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="


@pytest.mark.skip(reason="Blob storage emulator is not available in this environment")
def test_create_blob_logger():
    kwargs = {
        "type": "blob",
        "connection_string": WELL_KNOWN_BLOB_STORAGE_KEY,
        "base_dir": "testbasedir",
        "container_name": "testcontainer",
    }
    logger = LoggerFactory.create_logger(ReportingType.blob.value, kwargs)
    assert isinstance(logger, BlobWorkflowLogger)


def test_register_and_create_custom_logger():
    """Test registering and creating a custom logger type."""
    from unittest.mock import MagicMock

    custom_logger_class = MagicMock(spec=logging.Handler)
    instance = MagicMock()
    instance.initialized = True
    custom_logger_class.return_value = instance

    LoggerFactory.register("custom", lambda **kwargs: custom_logger_class(**kwargs))
    logger = LoggerFactory.create_logger("custom", {})

    assert custom_logger_class.called
    assert logger is instance
    # Access the attribute we set on our mock
    assert logger.initialized is True  # type: ignore # Attribute only exists on our mock

    # Check if it's in the list of registered logger types
    assert "custom" in LoggerFactory.get_logger_types()
    assert LoggerFactory.is_supported_type("custom")


def test_get_logger_types():
    logger_types = LoggerFactory.get_logger_types()
    # Check that built-in types are registered
    assert ReportingType.file.value in logger_types
    assert ReportingType.blob.value in logger_types


def test_create_unknown_logger():
    with pytest.raises(ValueError, match="Unknown reporting type: unknown"):
        LoggerFactory.create_logger("unknown", {})
@@ -24,7 +24,7 @@ if not sys.platform.startswith("win"):
 async def test_find():
     storage = CosmosDBPipelineStorage(
         connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
-        database_name="testfind",
+        base_dir="testfind",
         container_name="testfindcontainer",
     )
     try:
@@ -70,7 +70,7 @@ async def test_find():
 async def test_child():
     storage = CosmosDBPipelineStorage(
         connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
-        database_name="testchild",
+        base_dir="testchild",
         container_name="testchildcontainer",
     )
     try:
@@ -83,7 +83,7 @@ async def test_child():
 async def test_clear():
     storage = CosmosDBPipelineStorage(
         connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
-        database_name="testclear",
+        base_dir="testclear",
         container_name="testclearcontainer",
     )
     try:
@@ -113,7 +113,7 @@ async def test_clear():
 async def test_get_creation_date():
     storage = CosmosDBPipelineStorage(
         connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
-        database_name="testclear",
+        base_dir="testclear",
         container_name="testclearcontainer",
     )
     try:
@@ -31,7 +31,7 @@ def test_create_blob_storage():
         "base_dir": "testbasedir",
         "container_name": "testcontainer",
     }
-    storage = StorageFactory.create_storage(StorageType.blob, kwargs)
+    storage = StorageFactory.create_storage(StorageType.blob.value, kwargs)
     assert isinstance(storage, BlobPipelineStorage)


@@ -46,19 +46,19 @@ def test_create_cosmosdb_storage():
         "base_dir": "testdatabase",
         "container_name": "testcontainer",
     }
-    storage = StorageFactory.create_storage(StorageType.cosmosdb, kwargs)
+    storage = StorageFactory.create_storage(StorageType.cosmosdb.value, kwargs)
     assert isinstance(storage, CosmosDBPipelineStorage)


 def test_create_file_storage():
     kwargs = {"type": "file", "base_dir": "/tmp/teststorage"}
-    storage = StorageFactory.create_storage(StorageType.file, kwargs)
+    storage = StorageFactory.create_storage(StorageType.file.value, kwargs)
     assert isinstance(storage, FilePipelineStorage)


 def test_create_memory_storage():
-    kwargs = {"type": "memory"}
-    storage = StorageFactory.create_storage(StorageType.memory, kwargs)
+    kwargs = {}  # MemoryPipelineStorage doesn't accept any constructor parameters
+    storage = StorageFactory.create_storage(StorageType.memory.value, kwargs)
     assert isinstance(storage, MemoryPipelineStorage)


@@ -84,7 +84,7 @@ def test_register_and_create_custom_storage():

     # Check if it's in the list of registered storage types
     assert "custom" in StorageFactory.get_storage_types()
-    assert StorageFactory.is_supported_storage_type("custom")
+    assert StorageFactory.is_supported_type("custom")


 def test_get_storage_types():
@@ -96,13 +96,65 @@ def test_get_storage_types():
     assert StorageType.cosmosdb.value in storage_types


-def test_backward_compatibility():
-    """Test that the storage_types attribute is still accessible for backward compatibility."""
-    assert hasattr(StorageFactory, "storage_types")
-    # The storage_types attribute should be a dict
-    assert isinstance(StorageFactory.storage_types, dict)
-
-
 def test_create_unknown_storage():
     with pytest.raises(ValueError, match="Unknown storage type: unknown"):
         StorageFactory.create_storage("unknown", {})
+
+
+def test_register_class_directly_works():
+    """Test that registering a class directly works (StorageFactory allows this)."""
+    import re
+    from collections.abc import Iterator
+    from typing import Any
+
+    from graphrag.storage.pipeline_storage import PipelineStorage
+
+    class CustomStorage(PipelineStorage):
+        def __init__(self, **kwargs):
+            pass
+
+        def find(
+            self,
+            file_pattern: re.Pattern[str],
+            base_dir: str | None = None,
+            file_filter: dict[str, Any] | None = None,
+            max_count=-1,
+        ) -> Iterator[tuple[str, dict[str, Any]]]:
+            return iter([])
+
+        async def get(
+            self, key: str, as_bytes: bool | None = None, encoding: str | None = None
+        ) -> Any:
+            return None
+
+        async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
+            pass
+
+        async def delete(self, key: str) -> None:
+            pass
+
+        async def has(self, key: str) -> bool:
+            return False
+
+        async def clear(self) -> None:
+            pass
+
+        def child(self, name: str | None) -> "PipelineStorage":
+            return self
+
+        def keys(self) -> list[str]:
+            return []
+
+        async def get_creation_date(self, key: str) -> str:
+            return "2024-01-01 00:00:00 +0000"
+
+    # StorageFactory allows registering classes directly (no TypeError)
+    StorageFactory.register("custom_class", CustomStorage)
+
+    # Verify it was registered
+    assert "custom_class" in StorageFactory.get_storage_types()
+    assert StorageFactory.is_supported_type("custom_class")
+
+    # Test creating an instance
+    storage = StorageFactory.create_storage("custom_class", {})
+    assert isinstance(storage, CustomStorage)
tests/integration/vector_stores/test_factory.py (new file, 144 lines)
@@ -0,0 +1,144 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""VectorStoreFactory Tests.

These tests will test the VectorStoreFactory class and the creation of each vector store type that is natively supported.
"""

import pytest

from graphrag.config.enums import VectorStoreType
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
from graphrag.vector_stores.base import BaseVectorStore
from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore
from graphrag.vector_stores.factory import VectorStoreFactory
from graphrag.vector_stores.lancedb import LanceDBVectorStore


def test_create_lancedb_vector_store():
    kwargs = {
        "collection_name": "test_collection",
        "db_uri": "/tmp/lancedb",
    }
    vector_store = VectorStoreFactory.create_vector_store(
        VectorStoreType.LanceDB.value, kwargs
    )
    assert isinstance(vector_store, LanceDBVectorStore)
    assert vector_store.collection_name == "test_collection"


@pytest.mark.skip(reason="Azure AI Search requires credentials and setup")
def test_create_azure_ai_search_vector_store():
    kwargs = {
        "collection_name": "test_collection",
        "url": "https://test.search.windows.net",
        "api_key": "test_key",
    }
    vector_store = VectorStoreFactory.create_vector_store(
        VectorStoreType.AzureAISearch.value, kwargs
    )
    assert isinstance(vector_store, AzureAISearchVectorStore)


@pytest.mark.skip(reason="CosmosDB requires credentials and setup")
def test_create_cosmosdb_vector_store():
    kwargs = {
        "collection_name": "test_collection",
        "connection_string": "AccountEndpoint=https://test.documents.azure.com:443/;AccountKey=test_key==",
        "database_name": "test_db",
    }
    vector_store = VectorStoreFactory.create_vector_store(
        VectorStoreType.CosmosDB.value, kwargs
    )
    assert isinstance(vector_store, CosmosDBVectorStore)


def test_register_and_create_custom_vector_store():
    """Test registering and creating a custom vector store type."""
    from unittest.mock import MagicMock

    # Create a mock that satisfies the BaseVectorStore interface
    custom_vector_store_class = MagicMock(spec=BaseVectorStore)
    # Make the mock return a mock instance when instantiated
    instance = MagicMock()
    instance.initialized = True
    custom_vector_store_class.return_value = instance

    VectorStoreFactory.register(
        "custom", lambda **kwargs: custom_vector_store_class(**kwargs)
    )
    vector_store = VectorStoreFactory.create_vector_store("custom", {})

    assert custom_vector_store_class.called
    assert vector_store is instance
    # Access the attribute we set on our mock
    assert vector_store.initialized is True  # type: ignore # Attribute only exists on our mock

    # Check if it's in the list of registered vector store types
    assert "custom" in VectorStoreFactory.get_vector_store_types()
    assert VectorStoreFactory.is_supported_type("custom")


def test_get_vector_store_types():
    vector_store_types = VectorStoreFactory.get_vector_store_types()
    # Check that built-in types are registered
    assert VectorStoreType.LanceDB.value in vector_store_types
    assert VectorStoreType.AzureAISearch.value in vector_store_types
    assert VectorStoreType.CosmosDB.value in vector_store_types


def test_create_unknown_vector_store():
    with pytest.raises(ValueError, match="Unknown vector store type: unknown"):
        VectorStoreFactory.create_vector_store("unknown", {})


def test_is_supported_type():
    # Test built-in types
    assert VectorStoreFactory.is_supported_type(VectorStoreType.LanceDB.value)
    assert VectorStoreFactory.is_supported_type(VectorStoreType.AzureAISearch.value)
    assert VectorStoreFactory.is_supported_type(VectorStoreType.CosmosDB.value)

    # Test unknown type
    assert not VectorStoreFactory.is_supported_type("unknown")


def test_register_class_directly_works():
    """Test that registering a class directly works (VectorStoreFactory allows this)."""
    from graphrag.vector_stores.base import BaseVectorStore

    class CustomVectorStore(BaseVectorStore):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def connect(self, **kwargs):
            pass

        def load_documents(self, documents, overwrite=True):
            pass

        def similarity_search_by_vector(self, query_embedding, k=10, **kwargs):
            return []

        def similarity_search_by_text(self, text, text_embedder, k=10, **kwargs):
            return []

        def filter_by_id(self, include_ids):
            return {}

        def search_by_id(self, id):
            from graphrag.vector_stores.base import VectorStoreDocument

            return VectorStoreDocument(id=id, text="test", vector=None)

    # VectorStoreFactory allows registering classes directly (no TypeError)
    VectorStoreFactory.register("custom_class", CustomVectorStore)

    # Verify it was registered
    assert "custom_class" in VectorStoreFactory.get_vector_store_types()
    assert VectorStoreFactory.is_supported_type("custom_class")

    # Test creating an instance
    vector_store = VectorStoreFactory.create_vector_store(
        "custom_class", {"collection_name": "test"}
    )
    assert isinstance(vector_store, CustomVectorStore)
@@ -13,7 +13,7 @@ TEMP_DIR = "./.tmp"


 def create_cache():
-    storage = FilePipelineStorage(root_dir=os.path.join(os.getcwd(), ".tmp"))
+    storage = FilePipelineStorage(base_dir=os.path.join(os.getcwd(), ".tmp"))
     return JsonPipelineCache(storage)