Refactor CacheFactory, StorageFactory, and VectorStoreFactory to use consistent registration patterns and add custom vector store documentation (#2006)

* Initial plan

* Refactor VectorStoreFactory to use registration functionality like StorageFactory

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Fix linting issues in VectorStoreFactory refactoring

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Remove backward compatibility support from VectorStoreFactory and StorageFactory

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Run ruff check --fix and ruff format, add semversioner file

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* ruff formatting fixes

* Fix pytest errors in storage factory tests by updating PipelineStorage interface implementation

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* ruff formatting fixes

* update storage factory design

* Refactor CacheFactory to use registration functionality like StorageFactory

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* revert copilot changes

* fix copilot changes

* update comments

* Fix failing pytest compatibility for factory tests

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* update class instantiation issue

* ruff fixes

* fix pytest

* add default value

* ruff formatting changes

* ruff fixes

* revert minor changes

* cleanup cache factory

* Update CacheFactory tests to match consistent factory pattern

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* update pytest thresholds

* adjust threshold levels

* Add custom vector store implementation notebook

Create comprehensive notebook demonstrating how to implement and register custom vector stores with GraphRAG as a plug-and-play framework. Includes:

- Complete implementation of SimpleInMemoryVectorStore
- Registration with VectorStoreFactory
- Testing and validation examples
- Configuration examples for GraphRAG settings
- Advanced features and best practices
- Production considerations checklist

The notebook provides a complete walkthrough for developers to understand and implement their own vector store backends.

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* remove sample notebook for now

* update tests

* fix cache pytests

* add pandas-stub to dev dependencies

* disable warning check for well known key

* skip tests when running on ubuntu

* add documentation for custom vector store implementations

* ignore ruff findings in notebooks

* fix merge breakages

* speedup CLI import statements

* remove unnecessary import statements in init file

* Add str type option on storage/cache type

* Fix store name

* Add LoggerFactory

* Fix up logging setup across CLI/API

* Add LoggerFactory test

* Fix err message

* Semver

* Remove enums from factory methods

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>
Co-authored-by: Josh Bradley <joshbradley@microsoft.com>
Co-authored-by: Nathan Evans <github@talkswithnumbers.com>
Copilot 2025-08-28 13:53:07 -07:00 committed by GitHub
parent 69ad36e735
commit 2030f94eb4
46 changed files with 3184 additions and 1966 deletions

View File

@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Add LoggerFactory and clean up related API."
}

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Make cache, storage, and vector_store factories consistent with similar registration support"
}

View File

@ -0,0 +1,675 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Copyright (c) 2024 Microsoft Corporation.\n",
"# Licensed under the MIT License."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Bring-Your-Own Vector Store\n",
"\n",
"This notebook demonstrates how to implement a custom vector store and register for usage with GraphRAG.\n",
"\n",
"## Overview\n",
"\n",
"GraphRAG uses a plug-and-play architecture that allow for easy integration of custom vector stores (outside of what is natively supported) by following a factory design pattern. This allows you to:\n",
"\n",
"- **Extend functionality**: Add support for new vector database backends\n",
"- **Customize behavior**: Implement specialized search logic or data structures\n",
"- **Integrate existing systems**: Connect GraphRAG to your existing vector database infrastructure\n",
"\n",
"### What You'll Learn\n",
"\n",
"1. Understanding the `BaseVectorStore` interface\n",
"2. Implementing a custom vector store class\n",
"3. Registering your vector store with the `VectorStoreFactory`\n",
"4. Testing and validating your implementation\n",
"5. Configuring GraphRAG to use your custom vector store\n",
"\n",
"Let's get started!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 1: Import Required Dependencies\n",
"\n",
"First, let's import the necessary GraphRAG components and other dependencies we'll need.\n",
"\n",
"```bash\n",
"pip install graphrag\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Any\n",
"\n",
"import numpy as np\n",
"import yaml\n",
"\n",
"from graphrag.data_model.types import TextEmbedder\n",
"\n",
"# GraphRAG vector store components\n",
"from graphrag.vector_stores.base import (\n",
" BaseVectorStore,\n",
" VectorStoreDocument,\n",
" VectorStoreSearchResult,\n",
")\n",
"from graphrag.vector_stores.factory import VectorStoreFactory"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 2: Understand the BaseVectorStore Interface\n",
"\n",
"Before using a custom vector store, let's examine the `BaseVectorStore` interface to understand what methods need to be implemented."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Let's inspect the BaseVectorStore class to understand the required methods\n",
"import inspect\n",
"\n",
"print(\"BaseVectorStore Abstract Methods:\")\n",
"print(\"=\" * 40)\n",
"\n",
"abstract_methods = []\n",
"for name, method in inspect.getmembers(BaseVectorStore, predicate=inspect.isfunction):\n",
" if getattr(method, \"__isabstractmethod__\", False):\n",
" signature = inspect.signature(method)\n",
" abstract_methods.append(f\"• {name}{signature}\")\n",
" print(f\"• {name}{signature}\")\n",
"\n",
"print(f\"\\nTotal abstract methods to implement: {len(abstract_methods)}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 3: Implement a Custom Vector Store\n",
"\n",
"Now let's implement a simple in-memory vector store as an example. This vector store will:\n",
"\n",
"- Store documents and vectors in memory using Python data structures\n",
"- Support all required BaseVectorStore methods\n",
"\n",
"**Note**: This is a simplified example for demonstration. Production vector stores would typically use optimized libraries like FAISS, more sophisticated indexing, and persistent storage."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class SimpleInMemoryVectorStore(BaseVectorStore):\n",
" \"\"\"A simple in-memory vector store implementation for demonstration purposes.\n",
"\n",
" This vector store stores documents and their embeddings in memory and provides\n",
" basic similarity search functionality using cosine similarity.\n",
"\n",
" WARNING: This is for demonstration only - not suitable for production use.\n",
" For production, consider using optimized vector databases like LanceDB,\n",
" Azure AI Search, or other specialized vector stores.\n",
" \"\"\"\n",
"\n",
" # Internal storage for documents and vectors\n",
" documents: dict[str, VectorStoreDocument]\n",
" vectors: dict[str, np.ndarray]\n",
" connected: bool\n",
"\n",
" def __init__(self, **kwargs: Any):\n",
" \"\"\"Initialize the in-memory vector store.\"\"\"\n",
" super().__init__(**kwargs)\n",
"\n",
" self.documents: dict[str, VectorStoreDocument] = {}\n",
" self.vectors: dict[str, np.ndarray] = {}\n",
" self.connected = False\n",
"\n",
" print(\n",
" f\"🚀 SimpleInMemoryVectorStore initialized for collection: {self.collection_name}\"\n",
" )\n",
"\n",
" def connect(self, **kwargs: Any) -> None:\n",
" \"\"\"Connect to the vector storage (no-op for in-memory store).\"\"\"\n",
" self.connected = True\n",
" print(f\"✅ Connected to in-memory vector store: {self.collection_name}\")\n",
"\n",
" def load_documents(\n",
" self, documents: list[VectorStoreDocument], overwrite: bool = True\n",
" ) -> None:\n",
" \"\"\"Load documents into the vector store.\"\"\"\n",
" if not self.connected:\n",
" msg = \"Vector store not connected. Call connect() first.\"\n",
" raise RuntimeError(msg)\n",
"\n",
" if overwrite:\n",
" self.documents.clear()\n",
" self.vectors.clear()\n",
"\n",
" loaded_count = 0\n",
" for doc in documents:\n",
" if doc.vector is not None:\n",
" doc_id = str(doc.id)\n",
" self.documents[doc_id] = doc\n",
" self.vectors[doc_id] = np.array(doc.vector, dtype=np.float32)\n",
" loaded_count += 1\n",
"\n",
" print(f\"📚 Loaded {loaded_count} documents into vector store\")\n",
"\n",
" def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:\n",
" \"\"\"Calculate cosine similarity between two vectors.\"\"\"\n",
" # Normalize vectors\n",
" norm1 = np.linalg.norm(vec1)\n",
" norm2 = np.linalg.norm(vec2)\n",
"\n",
" if norm1 == 0 or norm2 == 0:\n",
" return 0.0\n",
"\n",
" return float(np.dot(vec1, vec2) / (norm1 * norm2))\n",
"\n",
" def similarity_search_by_vector(\n",
" self, query_embedding: list[float], k: int = 10, **kwargs: Any\n",
" ) -> list[VectorStoreSearchResult]:\n",
" \"\"\"Perform similarity search using a query vector.\"\"\"\n",
" if not self.connected:\n",
" msg = \"Vector store not connected. Call connect() first.\"\n",
" raise RuntimeError(msg)\n",
"\n",
" if not self.vectors:\n",
" return []\n",
"\n",
" query_vec = np.array(query_embedding, dtype=np.float32)\n",
" similarities = []\n",
"\n",
" # Calculate similarity with all stored vectors\n",
" for doc_id, stored_vec in self.vectors.items():\n",
" similarity = self._cosine_similarity(query_vec, stored_vec)\n",
" similarities.append((doc_id, similarity))\n",
"\n",
" # Sort by similarity (descending) and take top k\n",
" similarities.sort(key=lambda x: x[1], reverse=True)\n",
" top_k = similarities[:k]\n",
"\n",
" # Create search results\n",
" results = []\n",
" for doc_id, score in top_k:\n",
" document = self.documents[doc_id]\n",
" result = VectorStoreSearchResult(document=document, score=score)\n",
" results.append(result)\n",
"\n",
" return results\n",
"\n",
" def similarity_search_by_text(\n",
" self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any\n",
" ) -> list[VectorStoreSearchResult]:\n",
" \"\"\"Perform similarity search using text (which gets embedded first).\"\"\"\n",
" # Embed the text first\n",
" query_embedding = text_embedder(text)\n",
"\n",
" # Use vector search with the embedding\n",
" return self.similarity_search_by_vector(query_embedding, k, **kwargs)\n",
"\n",
" def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:\n",
" \"\"\"Build a query filter to filter documents by id.\n",
"\n",
" For this simple implementation, we return the list of IDs as the filter.\n",
" \"\"\"\n",
" return [str(id_) for id_ in include_ids]\n",
"\n",
" def search_by_id(self, id: str) -> VectorStoreDocument:\n",
" \"\"\"Search for a document by id.\"\"\"\n",
" doc_id = str(id)\n",
" if doc_id not in self.documents:\n",
" msg = f\"Document with id '{id}' not found\"\n",
" raise KeyError(msg)\n",
"\n",
" return self.documents[doc_id]\n",
"\n",
" def get_stats(self) -> dict[str, Any]:\n",
" \"\"\"Get statistics about the vector store (custom method).\"\"\"\n",
" return {\n",
" \"collection_name\": self.collection_name,\n",
" \"document_count\": len(self.documents),\n",
" \"vector_count\": len(self.vectors),\n",
" \"connected\": self.connected,\n",
" \"vector_dimension\": len(next(iter(self.vectors.values())))\n",
" if self.vectors\n",
" else 0,\n",
" }\n",
"\n",
"\n",
"print(\"✅ SimpleInMemoryVectorStore class defined!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 4: Register the Custom Vector Store\n",
"\n",
"Now let's register our custom vector store with the `VectorStoreFactory` so it can be used throughout GraphRAG."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Register our custom vector store with a unique identifier\n",
"CUSTOM_VECTOR_STORE_TYPE = \"simple_memory\"\n",
"\n",
"# Register the vector store class\n",
"VectorStoreFactory.register(CUSTOM_VECTOR_STORE_TYPE, SimpleInMemoryVectorStore)\n",
"\n",
"print(f\"✅ Registered custom vector store with type: '{CUSTOM_VECTOR_STORE_TYPE}'\")\n",
"\n",
"# Verify registration\n",
"available_types = VectorStoreFactory.get_vector_store_types()\n",
"print(f\"\\n📋 Available vector store types: {available_types}\")\n",
"print(\n",
" f\"🔍 Is our custom type supported? {VectorStoreFactory.is_supported_type(CUSTOM_VECTOR_STORE_TYPE)}\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 5: Test the Custom Vector Store\n",
"\n",
"Let's create some sample data and test our custom vector store implementation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create sample documents with mock embeddings\n",
"def create_mock_embedding(dimension: int = 384) -> list[float]:\n",
" \"\"\"Create a random embedding vector for testing.\"\"\"\n",
" return np.random.normal(0, 1, dimension).tolist()\n",
"\n",
"\n",
"# Sample documents\n",
"sample_documents = [\n",
" VectorStoreDocument(\n",
" id=\"doc_1\",\n",
" text=\"GraphRAG is a powerful knowledge graph extraction and reasoning framework.\",\n",
" vector=create_mock_embedding(),\n",
" attributes={\"category\": \"technology\", \"source\": \"documentation\"},\n",
" ),\n",
" VectorStoreDocument(\n",
" id=\"doc_2\",\n",
" text=\"Vector stores enable efficient similarity search over high-dimensional data.\",\n",
" vector=create_mock_embedding(),\n",
" attributes={\"category\": \"technology\", \"source\": \"research\"},\n",
" ),\n",
" VectorStoreDocument(\n",
" id=\"doc_3\",\n",
" text=\"Machine learning models can process and understand natural language text.\",\n",
" vector=create_mock_embedding(),\n",
" attributes={\"category\": \"AI\", \"source\": \"article\"},\n",
" ),\n",
" VectorStoreDocument(\n",
" id=\"doc_4\",\n",
" text=\"Custom implementations allow for specialized behavior and integration.\",\n",
" vector=create_mock_embedding(),\n",
" attributes={\"category\": \"development\", \"source\": \"tutorial\"},\n",
" ),\n",
"]\n",
"\n",
"print(f\"📝 Created {len(sample_documents)} sample documents\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test creating vector store using the factory\n",
"vector_store_config = {\"collection_name\": \"test_collection\"}\n",
"\n",
"# Create vector store instance using factory\n",
"vector_store = VectorStoreFactory.create_vector_store(\n",
" CUSTOM_VECTOR_STORE_TYPE, vector_store_config\n",
")\n",
"\n",
"print(f\"✅ Created vector store instance: {type(vector_store).__name__}\")\n",
"print(f\"📊 Initial stats: {vector_store.get_stats()}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Connect and load documents\n",
"vector_store.connect()\n",
"vector_store.load_documents(sample_documents)\n",
"\n",
"print(f\"📊 Updated stats: {vector_store.get_stats()}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test similarity search\n",
"query_vector = create_mock_embedding() # Random query vector for testing\n",
"\n",
"search_results = vector_store.similarity_search_by_vector(\n",
" query_vector,\n",
" k=3, # Get top 3 similar documents\n",
")\n",
"\n",
"print(f\"🔍 Found {len(search_results)} similar documents:\\n\")\n",
"\n",
"for i, result in enumerate(search_results, 1):\n",
" doc = result.document\n",
" print(f\"{i}. ID: {doc.id}\")\n",
" print(f\" Text: {doc.text[:60]}...\")\n",
" print(f\" Similarity Score: {result.score:.4f}\")\n",
" print(f\" Category: {doc.attributes.get('category', 'N/A')}\")\n",
" print()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test search by ID\n",
"try:\n",
" found_doc = vector_store.search_by_id(\"doc_2\")\n",
" print(\"✅ Found document by ID:\")\n",
" print(f\" ID: {found_doc.id}\")\n",
" print(f\" Text: {found_doc.text}\")\n",
" print(f\" Attributes: {found_doc.attributes}\")\n",
"except KeyError as e:\n",
" print(f\"❌ Error: {e}\")\n",
"\n",
"# Test filter by ID\n",
"id_filter = vector_store.filter_by_id([\"doc_1\", \"doc_3\"])\n",
"print(f\"\\n🔧 ID filter result: {id_filter}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 6: Configuration for GraphRAG\n",
"\n",
"Now let's see how you would configure GraphRAG to use your custom vector store in a settings file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Example GraphRAG yaml settings\n",
"example_settings = {\n",
" \"vector_store\": {\n",
" \"default_vector_store\": {\n",
" \"type\": CUSTOM_VECTOR_STORE_TYPE, # \"simple_memory\"\n",
" \"collection_name\": \"graphrag_entities\",\n",
" # Add any custom parameters your vector store needs\n",
" \"custom_parameter\": \"custom_value\",\n",
" }\n",
" },\n",
" # Other GraphRAG configuration...\n",
" \"models\": {\n",
" \"default_embedding_model\": {\n",
" \"type\": \"openai_embedding\",\n",
" \"model\": \"text-embedding-3-small\",\n",
" }\n",
" },\n",
"}\n",
"\n",
"# Convert to YAML format for settings.yml\n",
"yaml_config = yaml.dump(example_settings, default_flow_style=False, indent=2)\n",
"\n",
"print(\"📄 Example settings.yml configuration:\")\n",
"print(\"=\" * 40)\n",
"print(yaml_config)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 7: Integration with GraphRAG Pipeline\n",
"\n",
"Here's how your custom vector store would be used in a typical GraphRAG pipeline."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Example of how GraphRAG would use your custom vector store\n",
"def simulate_graphrag_pipeline():\n",
" \"\"\"Simulate how GraphRAG would use the custom vector store.\"\"\"\n",
" print(\"🚀 Simulating GraphRAG pipeline with custom vector store...\\n\")\n",
"\n",
" # 1. GraphRAG creates vector store using factory\n",
" config = {\"collection_name\": \"graphrag_entities\", \"similarity_threshold\": 0.3}\n",
"\n",
" store = VectorStoreFactory.create_vector_store(CUSTOM_VECTOR_STORE_TYPE, config)\n",
" store.connect()\n",
"\n",
" print(\"✅ Step 1: Vector store created and connected\")\n",
"\n",
" # 2. During indexing, GraphRAG loads extracted entities\n",
" entity_documents = [\n",
" VectorStoreDocument(\n",
" id=f\"entity_{i}\",\n",
" text=f\"Entity {i} description: Important concept in the knowledge graph\",\n",
" vector=create_mock_embedding(),\n",
" attributes={\"type\": \"entity\", \"importance\": i % 3 + 1},\n",
" )\n",
" for i in range(10)\n",
" ]\n",
"\n",
" store.load_documents(entity_documents)\n",
" print(f\"✅ Step 2: Loaded {len(entity_documents)} entity documents\")\n",
"\n",
" # 3. During query time, GraphRAG searches for relevant entities\n",
" query_embedding = create_mock_embedding()\n",
" relevant_entities = store.similarity_search_by_vector(query_embedding, k=5)\n",
"\n",
" print(f\"✅ Step 3: Found {len(relevant_entities)} relevant entities for query\")\n",
"\n",
" # 4. GraphRAG uses these entities for context building\n",
" context_entities = [result.document for result in relevant_entities]\n",
"\n",
" print(\"✅ Step 4: Context built using retrieved entities\")\n",
" print(f\"📊 Final stats: {store.get_stats()}\")\n",
"\n",
" return context_entities\n",
"\n",
"\n",
"# Run the simulation\n",
"context = simulate_graphrag_pipeline()\n",
"print(f\"\\n🎯 Retrieved {len(context)} entities for context building\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 8: Testing and Validation\n",
"\n",
"Let's create a comprehensive test suite to ensure our vector store works correctly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def test_custom_vector_store():\n",
" \"\"\"Comprehensive test suite for the custom vector store.\"\"\"\n",
" print(\"🧪 Running comprehensive vector store tests...\\n\")\n",
"\n",
" # Test 1: Basic functionality\n",
" print(\"Test 1: Basic functionality\")\n",
" store = VectorStoreFactory.create_vector_store(\n",
" CUSTOM_VECTOR_STORE_TYPE, {\"collection_name\": \"test\"}\n",
" )\n",
" store.connect()\n",
"\n",
" # Load test documents\n",
" test_docs = sample_documents[:2]\n",
" store.load_documents(test_docs)\n",
"\n",
" assert len(store.documents) == 2, \"Should have 2 documents\"\n",
" assert len(store.vectors) == 2, \"Should have 2 vectors\"\n",
" print(\"✅ Basic functionality test passed\")\n",
"\n",
" # Test 2: Search functionality\n",
" print(\"\\nTest 2: Search functionality\")\n",
" query_vec = create_mock_embedding()\n",
" results = store.similarity_search_by_vector(query_vec, k=5)\n",
"\n",
" assert len(results) <= 2, \"Should not return more results than documents\"\n",
" assert all(isinstance(r, VectorStoreSearchResult) for r in results), (\n",
" \"Should return VectorStoreSearchResult objects\"\n",
" )\n",
" assert all(-1 <= r.score <= 1 for r in results), (\n",
" \"Similarity scores should be between -1 and 1\"\n",
" )\n",
" print(\"✅ Search functionality test passed\")\n",
"\n",
" # Test 3: Search by ID\n",
" print(\"\\nTest 3: Search by ID\")\n",
" found_doc = store.search_by_id(\"doc_1\")\n",
" assert found_doc.id == \"doc_1\", \"Should find correct document\"\n",
"\n",
" try:\n",
" store.search_by_id(\"nonexistent\")\n",
" assert False, \"Should raise KeyError for nonexistent ID\"\n",
" except KeyError:\n",
" pass # Expected\n",
"\n",
" print(\"✅ Search by ID test passed\")\n",
"\n",
" # Test 4: Filter functionality\n",
" print(\"\\nTest 4: Filter functionality\")\n",
" filter_result = store.filter_by_id([\"doc_1\", \"doc_2\"])\n",
" assert filter_result == [\"doc_1\", \"doc_2\"], \"Should return filtered IDs\"\n",
" print(\"✅ Filter functionality test passed\")\n",
"\n",
" # Test 5: Error handling\n",
" print(\"\\nTest 5: Error handling\")\n",
" disconnected_store = VectorStoreFactory.create_vector_store(\n",
" CUSTOM_VECTOR_STORE_TYPE, {\"collection_name\": \"test2\"}\n",
" )\n",
"\n",
" try:\n",
" disconnected_store.load_documents(test_docs)\n",
" assert False, \"Should raise error when not connected\"\n",
" except RuntimeError:\n",
" pass # Expected\n",
"\n",
" try:\n",
" disconnected_store.similarity_search_by_vector(query_vec)\n",
" assert False, \"Should raise error when not connected\"\n",
" except RuntimeError:\n",
" pass # Expected\n",
"\n",
" print(\"✅ Error handling test passed\")\n",
"\n",
" print(\"\\n🎉 All tests passed! Your custom vector store is working correctly.\")\n",
"\n",
"\n",
"# Run the tests\n",
"test_custom_vector_store()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary and Next Steps\n",
"\n",
"Congratulations! You've successfully learned how to implement and register a custom vector store with GraphRAG. Here's what you accomplished:\n",
"\n",
"### What You Built\n",
"- ✅ **Custom Vector Store Class**: Implemented `SimpleInMemoryVectorStore` with all required methods\n",
"- ✅ **Factory Integration**: Registered your vector store with `VectorStoreFactory`\n",
"- ✅ **Comprehensive Testing**: Validated functionality with a full test suite\n",
"- ✅ **Configuration Examples**: Learned how to configure GraphRAG to use your vector store\n",
"\n",
"### Key Takeaways\n",
"1. **Interface Compliance**: Always implement all methods from `BaseVectorStore`\n",
"2. **Factory Pattern**: Use `VectorStoreFactory.register()` to make your vector store available\n",
"3. **Configuration**: Vector stores are configured in GraphRAG settings files\n",
"4. **Testing**: Thoroughly test all functionality before deploying\n",
"\n",
"### Next Steps\n",
"Check out the API Overview notebook to learn how to index and query data via the graphrag API.\n",
"\n",
"### Resources\n",
"- [GraphRAG Documentation](https://microsoft.github.io/graphrag/)\n",
"\n",
"Happy building! 🚀"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "graphrag-venv (3.10.18)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.18"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@ -1,6 +1,7 @@
# API Notebooks
- [API Overview Notebook](../../examples_notebooks/api_overview.ipynb)
- [Bring-Your-Own Vector Store](../../examples_notebooks/custom_vector_store.ipynb)
# Query Engine Notebooks

View File

@ -32,6 +32,7 @@ async def build_index(
memory_profile: bool = False,
callbacks: list[WorkflowCallbacks] | None = None,
additional_context: dict[str, Any] | None = None,
verbose: bool = False,
) -> list[PipelineRunResult]:
"""Run the pipeline with the given configuration.
@ -53,7 +54,7 @@ async def build_index(
list[PipelineRunResult]
The list of pipeline run results
"""
init_loggers(config=config)
init_loggers(config=config, verbose=verbose)
# Create callbacks for pipeline lifecycle events if provided
workflow_callbacks = (

View File

@ -67,6 +67,7 @@ async def generate_indexing_prompts(
min_examples_required: PositiveInt = 2,
n_subset_max: PositiveInt = 300,
k: PositiveInt = 15,
verbose: bool = False,
) -> tuple[str, str, str]:
"""Generate indexing prompts.
@ -89,7 +90,7 @@ async def generate_indexing_prompts(
-------
tuple[str, str, str]: entity extraction prompt, entity summarization prompt, community summarization prompt
"""
init_loggers(config=config)
init_loggers(config=config, verbose=verbose, filename="prompt-tuning.log")
# Retrieve documents
logger.info("Chunking documents...")

View File

@ -93,7 +93,7 @@ async def global_search(
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
init_loggers(config=config, verbose=verbose, filename="query.log")
callbacks = callbacks or []
full_response = ""
@ -156,7 +156,7 @@ def global_search_streaming(
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
init_loggers(config=config, verbose=verbose, filename="query.log")
communities_ = read_indexer_communities(communities, community_reports)
reports = read_indexer_reports(
@ -229,7 +229,7 @@ async def multi_index_global_search(
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
init_loggers(config=config, verbose=verbose, filename="query.log")
# Streaming not supported yet
if streaming:
@ -369,7 +369,7 @@ async def local_search(
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
init_loggers(config=config, verbose=verbose, filename="query.log")
callbacks = callbacks or []
full_response = ""
@ -435,7 +435,7 @@ def local_search_streaming(
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
init_loggers(config=config, verbose=verbose, filename="query.log")
vector_store_args = {}
for index, store in config.vector_store.items():
@ -508,7 +508,7 @@ async def multi_index_local_search(
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
init_loggers(config=config, verbose=verbose, filename="query.log")
# Streaming not supported yet
if streaming:
@ -730,7 +730,7 @@ async def drift_search(
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
init_loggers(config=config, verbose=verbose, filename="query.log")
callbacks = callbacks or []
full_response = ""
@ -792,7 +792,7 @@ def drift_search_streaming(
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
init_loggers(config=config, verbose=verbose, filename="query.log")
vector_store_args = {}
for index, store in config.vector_store.items():
@ -872,7 +872,7 @@ async def multi_index_drift_search(
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
init_loggers(config=config, verbose=verbose, filename="query.log")
# Streaming not supported yet
if streaming:
@ -1065,7 +1065,7 @@ async def basic_search(
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
init_loggers(config=config, verbose=verbose, filename="query.log")
callbacks = callbacks or []
full_response = ""
@ -1111,7 +1111,7 @@ def basic_search_streaming(
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
init_loggers(config=config, verbose=verbose, filename="query.log")
vector_store_args = {}
for index, store in config.vector_store.items():
@ -1119,7 +1119,7 @@ def basic_search_streaming(
msg = f"Vector Store Args: {redact(vector_store_args)}"
logger.debug(msg)
description_embedding_store = get_embedding_store(
embedding_store = get_embedding_store(
config_args=vector_store_args,
embedding_name=text_unit_text_embedding,
)
@ -1130,7 +1130,7 @@ def basic_search_streaming(
search_engine = get_basic_search_engine(
config=config,
text_units=read_indexer_text_units(text_units),
text_unit_embeddings=description_embedding_store,
text_unit_embeddings=embedding_store,
system_prompt=prompt,
callbacks=callbacks,
)
@ -1164,7 +1164,7 @@ async def multi_index_basic_search(
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
init_loggers(config=config, verbose=verbose, filename="query.log")
# Streaming not supported yet
if streaming:

View File

@ -1,23 +1,24 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing create_cache method definition."""
"""Factory functions for creating a cache."""
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from graphrag.config.enums import CacheType
from graphrag.storage.blob_pipeline_storage import create_blob_storage
from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
if TYPE_CHECKING:
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.cache.json_pipeline_cache import JsonPipelineCache
from graphrag.cache.memory_pipeline_cache import InMemoryCache
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
from graphrag.config.enums import CacheType
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
if TYPE_CHECKING:
from collections.abc import Callable
from graphrag.cache.pipeline_cache import PipelineCache
class CacheFactory:
@ -25,39 +26,80 @@ class CacheFactory:
Includes a method for users to register a custom cache implementation.
Configuration arguments are passed to each cache implementation as kwargs (where possible)
Configuration arguments are passed to each cache implementation as kwargs
for individual enforcement of required/optional arguments.
"""
cache_types: ClassVar[dict[str, type]] = {}
_registry: ClassVar[dict[str, Callable[..., PipelineCache]]] = {}
@classmethod
def register(cls, cache_type: str, cache: type):
"""Register a custom cache implementation."""
cls.cache_types[cache_type] = cache
def register(cls, cache_type: str, creator: Callable[..., PipelineCache]) -> None:
"""Register a custom cache implementation.
Args:
cache_type: The type identifier for the cache.
creator: A class or callable that creates an instance of PipelineCache.
"""
cls._registry[cache_type] = creator
@classmethod
def create_cache(
cls, cache_type: CacheType | str | None, root_dir: str, kwargs: dict
) -> PipelineCache:
"""Create or get a cache from the provided type."""
if not cache_type:
return NoopPipelineCache()
match cache_type:
case CacheType.none:
return NoopPipelineCache()
case CacheType.memory:
return InMemoryCache()
case CacheType.file:
return JsonPipelineCache(
FilePipelineStorage(root_dir=root_dir).child(kwargs["base_dir"])
)
case CacheType.blob:
return JsonPipelineCache(create_blob_storage(**kwargs))
case CacheType.cosmosdb:
return JsonPipelineCache(create_cosmosdb_storage(**kwargs))
case _:
if cache_type in cls.cache_types:
return cls.cache_types[cache_type](**kwargs)
msg = f"Unknown cache type: {cache_type}"
raise ValueError(msg)
def create_cache(cls, cache_type: str, kwargs: dict) -> PipelineCache:
"""Create a cache object from the provided type.
Args:
cache_type: The type of cache to create.
kwargs: Additional keyword arguments for the cache constructor.
Returns
-------
A PipelineCache instance.
Raises
------
ValueError: If the cache type is not registered.
"""
if cache_type not in cls._registry:
msg = f"Unknown cache type: {cache_type}"
raise ValueError(msg)
return cls._registry[cache_type](**kwargs)
@classmethod
def get_cache_types(cls) -> list[str]:
"""Get the registered cache implementations."""
return list(cls._registry.keys())
@classmethod
def is_supported_type(cls, cache_type: str) -> bool:
"""Check if the given cache type is supported."""
return cache_type in cls._registry
# --- register built-in cache implementations ---
def create_file_cache(root_dir: str, base_dir: str, **kwargs) -> PipelineCache:
"""Create a file-based cache implementation."""
# Create storage with base_dir in kwargs since FilePipelineStorage expects it there
storage_kwargs = {"base_dir": root_dir, **kwargs}
storage = FilePipelineStorage(**storage_kwargs).child(base_dir)
return JsonPipelineCache(storage)
def create_blob_cache(**kwargs) -> PipelineCache:
"""Create a blob storage-based cache implementation."""
storage = BlobPipelineStorage(**kwargs)
return JsonPipelineCache(storage)
def create_cosmosdb_cache(**kwargs) -> PipelineCache:
"""Create a CosmosDB-based cache implementation."""
storage = CosmosDBPipelineStorage(**kwargs)
return JsonPipelineCache(storage)
# --- register built-in cache implementations ---
CacheFactory.register(CacheType.none.value, NoopPipelineCache)
CacheFactory.register(CacheType.memory.value, InMemoryCache)
CacheFactory.register(CacheType.file.value, create_file_cache)
CacheFactory.register(CacheType.blob.value, create_blob_cache)
CacheFactory.register(CacheType.cosmosdb.value, create_cosmosdb_cache)
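For downstream applications, a custom cache plugs in the same way as the built-ins above. A minimal sketch follows; it assumes `PipelineCache` exposes async `get`/`set`/`has`/`delete`/`clear` plus `child`, which this diff does not show, so treat the method signatures as illustrative:
from typing import Any
from graphrag.cache.factory import CacheFactory
from graphrag.cache.pipeline_cache import PipelineCache
class DictCache(PipelineCache):
    """Illustrative in-process cache backed by a plain dict."""
    def __init__(self, **kwargs: Any) -> None:
        self._store: dict[str, Any] = {}
    async def get(self, key: str) -> Any:
        return self._store.get(key)
    async def set(self, key: str, value: Any, debug_data: dict | None = None) -> None:
        self._store[key] = value
    async def has(self, key: str) -> bool:
        return key in self._store
    async def delete(self, key: str) -> None:
        self._store.pop(key, None)
    async def clear(self) -> None:
        self._store.clear()
    def child(self, name: str) -> "DictCache":
        # A flat namespace is fine for a demo; real caches scope children by name.
        return self
CacheFactory.register("dict", DictCache)
cache = CacheFactory.create_cache("dict", {})  # kwargs are forwarded to DictCache(**kwargs)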

View File

@ -11,10 +11,9 @@ from pathlib import Path
import graphrag.api as api
from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
from graphrag.config.enums import CacheType, IndexingMethod, ReportingType
from graphrag.config.enums import CacheType, IndexingMethod
from graphrag.config.load_config import load_config
from graphrag.index.validate_config import validate_config_names
from graphrag.logger.standard_logging import DEFAULT_LOG_FILENAME
from graphrag.utils.cli import redact
# Ignore warnings from numba
@ -123,17 +122,6 @@ def _run_index(
if not cache:
config.cache.type = CacheType.none
# Log the configuration details
if config.reporting.type == ReportingType.file:
log_dir = Path(config.root_dir) / config.reporting.base_dir
log_path = log_dir / DEFAULT_LOG_FILENAME
logger.info("Logging enabled at %s", log_path)
else:
logger.info(
"Logging not enabled for config %s",
redact(config.model_dump()),
)
if not skip_validation:
validate_config_names(config)
@ -156,6 +144,7 @@ def _run_index(
is_update_run=is_update_run,
memory_profile=memprofile,
callbacks=[ConsoleWorkflowCallbacks(verbose=verbose)],
verbose=verbose,
)
)
encountered_errors = any(

View File

@ -7,9 +7,7 @@ import logging
from pathlib import Path
import graphrag.api as api
from graphrag.config.enums import ReportingType
from graphrag.config.load_config import load_config
from graphrag.logger.standard_logging import DEFAULT_LOG_FILENAME
from graphrag.prompt_tune.generator.community_report_summarization import (
COMMUNITY_SUMMARIZATION_FILENAME,
)
@ -74,21 +72,13 @@ async def prompt_tune(
from graphrag.logger.standard_logging import init_loggers
# initialize loggers with config
init_loggers(
config=graph_config,
verbose=verbose,
)
init_loggers(config=graph_config, verbose=verbose, filename="prompt-tuning.log")
# log the configuration details
if graph_config.reporting.type == ReportingType.file:
log_dir = Path(root_path) / graph_config.reporting.base_dir
log_path = log_dir / DEFAULT_LOG_FILENAME
logger.info("Logging enabled at %s", log_path)
else:
logger.info(
"Logging not enabled for config %s",
redact(graph_config.model_dump()),
)
logger.info("Starting prompt tune.")
logger.info(
"Using default configuration: %s",
redact(graph_config.model_dump()),
)
prompts = await api.generate_indexing_prompts(
config=graph_config,
@ -103,6 +93,7 @@ async def prompt_tune(
min_examples_required=min_examples_required,
n_subset_max=n_subset_max,
k=k,
verbose=verbose,
)
output_path = output.resolve()

View File

@ -18,11 +18,11 @@ from graphrag.config.enums import (
NounPhraseExtractorType,
ReportingType,
StorageType,
VectorStoreType,
)
from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import (
EN_STOP_WORDS,
)
from graphrag.vector_stores.factory import VectorStoreType
DEFAULT_OUTPUT_BASE_DIR = "output"
DEFAULT_CHAT_MODEL_ID = "default_chat_model"

View File

@ -59,6 +59,14 @@ class StorageType(str, Enum):
return f'"{self.value}"'
class VectorStoreType(str, Enum):
"""The supported vector store types."""
LanceDB = "lancedb"
AzureAISearch = "azure_ai_search"
CosmosDB = "cosmosdb"
class ReportingType(str, Enum):
"""The reporting configuration type for the pipeline."""

View File

@ -82,7 +82,7 @@ cache:
base_dir: "{graphrag_config_defaults.cache.base_dir}"
reporting:
type: {graphrag_config_defaults.reporting.type.value} # [file, blob, cosmosdb]
type: {graphrag_config_defaults.reporting.type.value} # [file, blob]
base_dir: "{graphrag_config_defaults.reporting.base_dir}"
vector_store:

View File

@ -12,7 +12,7 @@ from graphrag.config.enums import CacheType
class CacheConfig(BaseModel):
"""The default configuration section for Cache."""
type: CacheType = Field(
type: CacheType | str = Field(
description="The cache type to use.",
default=graphrag_config_defaults.cache.type,
)
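Widening the field to `CacheType | str` means a configuration can reference a custom type registered with `CacheFactory`; a small sketch, assuming pydantic accepts the raw string after this change:
from graphrag.config.models.cache_config import CacheConfig
default_cfg = CacheConfig()            # still defaults to the CacheType enum value
custom_cfg = CacheConfig(type="dict")  # any string registered via CacheFactory.register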

View File

@ -11,6 +11,7 @@ from pydantic import BaseModel, Field, model_validator
import graphrag.config.defaults as defs
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import VectorStoreType
from graphrag.config.errors import LanguageModelConfigMissingError
from graphrag.config.models.basic_search_config import BasicSearchConfig
from graphrag.config.models.cache_config import CacheConfig
@ -36,7 +37,6 @@ from graphrag.config.models.summarize_descriptions_config import (
from graphrag.config.models.text_embedding_config import TextEmbeddingConfig
from graphrag.config.models.umap_config import UmapConfig
from graphrag.config.models.vector_store_config import VectorStoreConfig
from graphrag.vector_stores.factory import VectorStoreType
class GraphRagConfig(BaseModel):

View File

@ -12,7 +12,7 @@ from graphrag.config.enums import ReportingType
class ReportingConfig(BaseModel):
"""The default configuration section for Reporting."""
type: ReportingType = Field(
type: ReportingType | str = Field(
description="The reporting type to use.",
default=graphrag_config_defaults.reporting.type,
)

View File

@ -14,7 +14,7 @@ from graphrag.config.enums import StorageType
class StorageConfig(BaseModel):
"""The default configuration section for storage."""
type: StorageType = Field(
type: StorageType | str = Field(
description="The storage type to use.",
default=graphrag_config_defaults.storage.type,
)

View File

@ -39,7 +39,7 @@ class SummarizeDescriptionsConfig(BaseModel):
self, root_dir: str, model_config: LanguageModelConfig
) -> dict:
"""Get the resolved description summarization strategy."""
from graphrag.index.operations.summarize_descriptions import (
from graphrag.index.operations.summarize_descriptions.summarize_descriptions import (
SummarizeStrategyType,
)

View File

@ -39,7 +39,7 @@ class TextEmbeddingConfig(BaseModel):
def resolved_strategy(self, model_config: LanguageModelConfig) -> dict:
"""Get the resolved text embedding strategy."""
from graphrag.index.operations.embed_text import (
from graphrag.index.operations.embed_text.embed_text import (
TextEmbedStrategyType,
)

View File

@ -6,7 +6,7 @@
from pydantic import BaseModel, Field, model_validator
from graphrag.config.defaults import vector_store_defaults
from graphrag.vector_stores.factory import VectorStoreType
from graphrag.config.enums import VectorStoreType
class VectorStoreConfig(BaseModel):

View File

@ -2,10 +2,3 @@
# Licensed under the MIT License
"""The Indexing Engine text embed package root."""
from graphrag.index.operations.embed_text.embed_text import (
TextEmbedStrategyType,
embed_text,
)
__all__ = ["TextEmbedStrategyType", "embed_text"]

View File

@ -2,17 +2,3 @@
# Licensed under the MIT License
"""Root package for description summarization."""
from graphrag.index.operations.summarize_descriptions.summarize_descriptions import (
summarize_descriptions,
)
from graphrag.index.operations.summarize_descriptions.typing import (
SummarizationStrategy,
SummarizeStrategyType,
)
__all__ = [
"SummarizationStrategy",
"SummarizeStrategyType",
"summarize_descriptions",
]

View File

@ -15,7 +15,7 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.operations.extract_graph.extract_graph import (
extract_graph as extractor,
)
from graphrag.index.operations.summarize_descriptions import (
from graphrag.index.operations.summarize_descriptions.summarize_descriptions import (
summarize_descriptions,
)
from graphrag.index.typing.context import PipelineRunContext

View File

@ -21,7 +21,7 @@ from graphrag.config.embeddings import (
)
from graphrag.config.get_embedding_settings import get_embedding_settings
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.operations.embed_text import embed_text
from graphrag.index.operations.embed_text.embed_text import embed_text
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import (

View File

@ -7,7 +7,7 @@ from collections.abc import Callable
from typing import Any, ClassVar
from graphrag.config.enums import ModelType
from graphrag.language_model.protocol import ChatModel, EmbeddingModel
from graphrag.language_model.protocol.base import ChatModel, EmbeddingModel
from graphrag.language_model.providers.fnllm.models import (
AzureOpenAIChatFNLLM,
AzureOpenAIEmbeddingFNLLM,
@ -100,15 +100,16 @@ class ModelFactory:
# --- Register default implementations ---
ModelFactory.register_chat(
ModelType.AzureOpenAIChat, lambda **kwargs: AzureOpenAIChatFNLLM(**kwargs)
ModelType.AzureOpenAIChat.value, lambda **kwargs: AzureOpenAIChatFNLLM(**kwargs)
)
ModelFactory.register_chat(
ModelType.OpenAIChat, lambda **kwargs: OpenAIChatFNLLM(**kwargs)
ModelType.OpenAIChat.value, lambda **kwargs: OpenAIChatFNLLM(**kwargs)
)
ModelFactory.register_embedding(
ModelType.AzureOpenAIEmbedding, lambda **kwargs: AzureOpenAIEmbeddingFNLLM(**kwargs)
ModelType.AzureOpenAIEmbedding.value,
lambda **kwargs: AzureOpenAIEmbeddingFNLLM(**kwargs),
)
ModelFactory.register_embedding(
ModelType.OpenAIEmbedding, lambda **kwargs: OpenAIEmbeddingFNLLM(**kwargs)
ModelType.OpenAIEmbedding.value, lambda **kwargs: OpenAIEmbeddingFNLLM(**kwargs)
)
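Since registration is now keyed by plain strings, a custom type name can be routed to any provider; a sketch reusing a provider already imported in this file (the alias name is hypothetical):
from graphrag.language_model.factory import ModelFactory
from graphrag.language_model.providers.fnllm.models import OpenAIChatFNLLM
# Hypothetical alias: route a custom type string to an existing chat provider.
ModelFactory.register_chat("openai_chat_alias", lambda **kwargs: OpenAIChatFNLLM(**kwargs))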

View File

@ -16,7 +16,7 @@ from typing_extensions import Self
from graphrag.language_model.factory import ModelFactory
if TYPE_CHECKING:
from graphrag.language_model.protocol import ChatModel, EmbeddingModel
from graphrag.language_model.protocol.base import ChatModel, EmbeddingModel
class ModelManager:

View File

@ -2,7 +2,3 @@
# Licensed under the MIT License
"""Base protocol definitions for LLMs."""
from .base import ChatModel, EmbeddingModel
__all__ = ["ChatModel", "EmbeddingModel"]

graphrag/logger/factory.py Normal file
View File

@ -0,0 +1,113 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Factory functions for creating a logger."""
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING, ClassVar
from graphrag.config.enums import ReportingType
if TYPE_CHECKING:
from collections.abc import Callable
LOG_FORMAT = "%(asctime)s.%(msecs)04d - %(levelname)s - %(name)s - %(message)s"
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
class LoggerFactory:
"""A factory class for logger implementations.
Includes a method for users to register a custom logger implementation.
Configuration arguments are passed to each logger implementation as kwargs
for individual enforcement of required/optional arguments.
Note that because we rely on the built-in Python logging architecture, this factory does not return a logger instance;
it returns a logging.Handler configured for your specified storage location.
"""
_registry: ClassVar[dict[str, Callable[..., logging.Handler]]] = {}
@classmethod
def register(
cls, reporting_type: str, creator: Callable[..., logging.Handler]
) -> None:
"""Register a custom logger implementation.
Args:
reporting_type: The type identifier for the logger.
creator: A class or callable that initializes logging.
"""
cls._registry[reporting_type] = creator
@classmethod
def create_logger(cls, reporting_type: str, kwargs: dict) -> logging.Handler:
"""Create a logger from the provided type.
Args:
reporting_type: The type of logger to create.
kwargs: Additional keyword arguments for the constructor.
Returns
-------
A logging.Handler instance.
Raises
------
ValueError: If the logger type is not registered.
"""
if reporting_type not in cls._registry:
msg = f"Unknown reporting type: {reporting_type}"
raise ValueError(msg)
return cls._registry[reporting_type](**kwargs)
@classmethod
def get_logger_types(cls) -> list[str]:
"""Get the registered logger implementations."""
return list(cls._registry.keys())
@classmethod
def is_supported_type(cls, reporting_type: str) -> bool:
"""Check if the given logger type is supported."""
return reporting_type in cls._registry
# --- register built-in logger implementations ---
def create_file_logger(**kwargs) -> logging.Handler:
"""Create a file-based logger."""
root_dir = kwargs["root_dir"]
base_dir = kwargs["base_dir"]
filename = kwargs["filename"]
log_dir = Path(root_dir) / base_dir
log_dir.mkdir(parents=True, exist_ok=True)
log_file_path = log_dir / filename
handler = logging.FileHandler(str(log_file_path), mode="a")
formatter = logging.Formatter(fmt=LOG_FORMAT, datefmt=DATE_FORMAT)
handler.setFormatter(formatter)
return handler
def create_blob_logger(**kwargs) -> logging.Handler:
"""Create a blob storage-based logger."""
from graphrag.logger.blob_workflow_logger import BlobWorkflowLogger
return BlobWorkflowLogger(
connection_string=kwargs["connection_string"],
container_name=kwargs["container_name"],
base_dir=kwargs["base_dir"],
storage_account_blob_url=kwargs["storage_account_blob_url"],
)
# --- register built-in implementations ---
LoggerFactory.register(ReportingType.file.value, create_file_logger)
LoggerFactory.register(ReportingType.blob.value, create_blob_logger)
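Custom handlers register the same way; for example, a console handler built only from the factory API above and the standard library (the "console" type name is an assumption, not a built-in):
import logging
from graphrag.logger.factory import DATE_FORMAT, LOG_FORMAT, LoggerFactory
def create_console_handler(**kwargs) -> logging.Handler:
    """Create a stderr stream handler; storage-oriented kwargs are ignored."""
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter(fmt=LOG_FORMAT, datefmt=DATE_FORMAT))
    return handler
LoggerFactory.register("console", create_console_handler)
logging.getLogger("graphrag").addHandler(LoggerFactory.create_logger("console", {}))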

View File

@ -32,15 +32,19 @@ Notes
All progress logging now uses this standard logging system for consistency.
"""
import logging
from pathlib import Path
from __future__ import annotations
from graphrag.config.enums import ReportingType
from graphrag.config.models.graph_rag_config import GraphRagConfig
import logging
from typing import TYPE_CHECKING
from graphrag.logger.factory import (
LoggerFactory,
)
if TYPE_CHECKING:
from graphrag.config.models.graph_rag_config import GraphRagConfig
DEFAULT_LOG_FILENAME = "indexing-engine.log"
LOG_FORMAT = "%(asctime)s.%(msecs)04d - %(levelname)s - %(name)s - %(message)s"
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
def init_loggers(
@ -50,26 +54,15 @@ def init_loggers(
) -> None:
"""Initialize logging handlers for graphrag based on configuration.
This function merges the functionality of configure_logging() and create_pipeline_logger()
to provide a unified way to set up logging for the graphrag package.
Parameters
----------
config : GraphRagConfig | None, default=None
The GraphRAG configuration. If None, defaults to file-based reporting.
root_dir : str | None, default=None
The root directory for file-based logging.
verbose : bool, default=False
Whether to enable verbose (DEBUG) logging.
log_file : Optional[Union[str, Path]], default=None
Path to a specific log file. If provided, takes precedence over config.
filename : Optional[str]
Log filename on disk. If unset, will use a default name.
"""
# import BlobWorkflowLogger here to avoid circular imports
from graphrag.logger.blob_workflow_logger import BlobWorkflowLogger
# extract reporting config from GraphRagConfig if provided
reporting_config = config.reporting
logger = logging.getLogger("graphrag")
log_level = logging.DEBUG if verbose else logging.INFO
logger.setLevel(log_level)
@ -82,27 +75,9 @@ def init_loggers(
handler.close()
logger.handlers.clear()
# create formatter with custom format
formatter = logging.Formatter(fmt=LOG_FORMAT, datefmt=DATE_FORMAT)
reporting_config = config.reporting
config_dict = reporting_config.model_dump()
args = {**config_dict, "root_dir": config.root_dir, "filename": filename}
# add more handlers based on configuration
handler: logging.Handler
match reporting_config.type:
case ReportingType.file:
# use the config-based file path
log_dir = Path(config.root_dir) / (reporting_config.base_dir)
log_dir.mkdir(parents=True, exist_ok=True)
log_file_path = log_dir / filename
handler = logging.FileHandler(str(log_file_path), mode="a")
handler.setFormatter(formatter)
logger.addHandler(handler)
case ReportingType.blob:
handler = BlobWorkflowLogger(
reporting_config.connection_string,
reporting_config.container_name,
base_dir=reporting_config.base_dir,
storage_account_blob_url=reporting_config.storage_account_blob_url,
)
logger.addHandler(handler)
case _:
logger.error("Unknown reporting type '%s'.", reporting_config.type)
handler = LoggerFactory.create_logger(reporting_config.type, args)
logger.addHandler(handler)
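A caller then initializes logging once per entry point, mirroring the API call sites in this diff (the `load_config` signature here is an assumption):
from pathlib import Path
from graphrag.config.load_config import load_config
from graphrag.logger.standard_logging import init_loggers
config = load_config(Path("."))  # assumed: load_config(root_dir) resolves the settings file
init_loggers(config=config, verbose=True, filename="query.log")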

View File

@ -29,15 +29,20 @@ class BlobPipelineStorage(PipelineStorage):
_encoding: str
_storage_account_blob_url: str | None
def __init__(
self,
connection_string: str | None,
container_name: str,
encoding: str = "utf-8",
path_prefix: str | None = None,
storage_account_blob_url: str | None = None,
):
def __init__(self, **kwargs: Any) -> None:
"""Create a new BlobStorage instance."""
connection_string = kwargs.get("connection_string")
storage_account_blob_url = kwargs.get("storage_account_blob_url")
path_prefix = kwargs.get("base_dir")
container_name = kwargs["container_name"]
if container_name is None:
msg = "No container name provided for blob storage."
raise ValueError(msg)
if connection_string is None and storage_account_blob_url is None:
msg = "No storage account blob url provided for blob storage."
raise ValueError(msg)
logger.info("Creating blob storage at %s", container_name)
if connection_string:
self._blob_service_client = BlobServiceClient.from_connection_string(
connection_string
@ -51,7 +56,7 @@ class BlobPipelineStorage(PipelineStorage):
account_url=storage_account_blob_url,
credential=DefaultAzureCredential(),
)
self._encoding = encoding
self._encoding = kwargs.get("encoding", "utf-8")
self._container_name = container_name
self._connection_string = connection_string
self._path_prefix = path_prefix or ""
@ -271,11 +276,11 @@ class BlobPipelineStorage(PipelineStorage):
return self
path = str(Path(self._path_prefix) / name)
return BlobPipelineStorage(
self._connection_string,
self._container_name,
self._encoding,
path,
self._storage_account_blob_url,
connection_string=self._connection_string,
container_name=self._container_name,
encoding=self._encoding,
base_dir=path,
storage_account_blob_url=self._storage_account_blob_url,
)
def keys(self) -> list[str]:
@ -307,27 +312,6 @@ class BlobPipelineStorage(PipelineStorage):
return ""
def create_blob_storage(**kwargs: Any) -> PipelineStorage:
"""Create a blob based storage."""
connection_string = kwargs.get("connection_string")
storage_account_blob_url = kwargs.get("storage_account_blob_url")
base_dir = kwargs.get("base_dir")
container_name = kwargs["container_name"]
logger.info("Creating blob storage at %s", container_name)
if container_name is None:
msg = "No container name provided for blob storage."
raise ValueError(msg)
if connection_string is None and storage_account_blob_url is None:
msg = "No storage account blob url provided for blob storage."
raise ValueError(msg)
return BlobPipelineStorage(
connection_string=connection_string,
container_name=container_name,
path_prefix=base_dir,
storage_account_blob_url=storage_account_blob_url,
)
def validate_blob_container_name(container_name: str):
"""
Check if the provided blob container name is valid based on Azure rules.

View File

@ -39,15 +39,20 @@ class CosmosDBPipelineStorage(PipelineStorage):
_encoding: str
_no_id_prefixes: list[str]
def __init__(
self,
database_name: str,
container_name: str,
cosmosdb_account_url: str | None = None,
connection_string: str | None = None,
encoding: str = "utf-8",
):
"""Initialize the CosmosDB Storage."""
def __init__(self, **kwargs: Any) -> None:
"""Create a CosmosDB storage instance."""
logger.info("Creating cosmosdb storage")
cosmosdb_account_url = kwargs.get("cosmosdb_account_url")
connection_string = kwargs.get("connection_string")
database_name = kwargs["base_dir"]
container_name = kwargs["container_name"]
if not database_name:
msg = "No base_dir provided for database name"
raise ValueError(msg)
if connection_string is None and cosmosdb_account_url is None:
msg = "connection_string or cosmosdb_account_url is required."
raise ValueError(msg)
if connection_string:
self._cosmos_client = CosmosClient.from_connection_string(connection_string)
else:
@ -60,7 +65,7 @@ class CosmosDBPipelineStorage(PipelineStorage):
url=cosmosdb_account_url,
credential=DefaultAzureCredential(),
)
self._encoding = encoding
self._encoding = kwargs.get("encoding", "utf-8")
self._database_name = database_name
self._connection_string = connection_string
self._cosmosdb_account_url = cosmosdb_account_url
@ -348,29 +353,6 @@ class CosmosDBPipelineStorage(PipelineStorage):
return ""
# TODO remove this helper function and have the factory instantiate the class directly
# once the new config system is in place and will enforce the correct types/existence of certain fields
def create_cosmosdb_storage(**kwargs: Any) -> PipelineStorage:
"""Create a CosmosDB storage instance."""
logger.info("Creating cosmosdb storage")
cosmosdb_account_url = kwargs.get("cosmosdb_account_url")
connection_string = kwargs.get("connection_string")
base_dir = kwargs["base_dir"]
container_name = kwargs["container_name"]
if not base_dir:
msg = "No base_dir provided for database name"
raise ValueError(msg)
if connection_string is None and cosmosdb_account_url is None:
msg = "connection_string or cosmosdb_account_url is required."
raise ValueError(msg)
return CosmosDBPipelineStorage(
cosmosdb_account_url=cosmosdb_account_url,
connection_string=connection_string,
database_name=base_dir,
container_name=container_name,
)
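The validation that lived in the removed helper now runs in __init__, with base_dir doubling as the database name. A hedged direct-construction sketch matching the integration tests below (it needs the CosmosDB emulator, or a real account, to succeed):

from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage

# base_dir is reused as the CosmosDB database name in the kwargs-based constructor
storage = CosmosDBPipelineStorage(
    connection_string="<cosmosdb-emulator-connection-string>",  # placeholder
    base_dir="testdatabase",
    container_name="testcontainer",
)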
def _create_progress_status(
num_loaded: int, num_filtered: int, num_total: int
) -> Progress:

@@ -5,13 +5,12 @@
from __future__ import annotations
from contextlib import suppress
from typing import TYPE_CHECKING, ClassVar
from graphrag.config.enums import StorageType
from graphrag.storage.blob_pipeline_storage import create_blob_storage
from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage
from graphrag.storage.file_pipeline_storage import create_file_storage
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
if TYPE_CHECKING:
@@ -29,8 +28,7 @@ class StorageFactory:
for individual enforcement of required/optional arguments.
"""
_storage_registry: ClassVar[dict[str, Callable[..., PipelineStorage]]] = {}
storage_types: ClassVar[dict[str, type]] = {} # For backward compatibility
_registry: ClassVar[dict[str, Callable[..., PipelineStorage]]] = {}
@classmethod
def register(
@@ -40,23 +38,13 @@
Args:
storage_type: The type identifier for the storage.
creator: A callable that creates an instance of the storage.
"""
cls._storage_registry[storage_type] = creator
creator: A class or callable that creates an instance of PipelineStorage.
# For backward compatibility with code that may access storage_types directly
if (
callable(creator)
and hasattr(creator, "__annotations__")
and "return" in creator.__annotations__
):
with suppress(TypeError, KeyError):
cls.storage_types[storage_type] = creator.__annotations__["return"]
"""
cls._registry[storage_type] = creator
@classmethod
def create_storage(
cls, storage_type: StorageType | str, kwargs: dict
) -> PipelineStorage:
def create_storage(cls, storage_type: str, kwargs: dict) -> PipelineStorage:
"""Create a storage object from the provided type.
Args:
@@ -71,31 +59,25 @@
------
ValueError: If the storage type is not registered.
"""
storage_type_str = (
storage_type.value
if isinstance(storage_type, StorageType)
else storage_type
)
if storage_type_str not in cls._storage_registry:
if storage_type not in cls._registry:
msg = f"Unknown storage type: {storage_type}"
raise ValueError(msg)
return cls._storage_registry[storage_type_str](**kwargs)
return cls._registry[storage_type](**kwargs)
@classmethod
def get_storage_types(cls) -> list[str]:
"""Get the registered storage implementations."""
return list(cls._storage_registry.keys())
return list(cls._registry.keys())
@classmethod
def is_supported_storage_type(cls, storage_type: str) -> bool:
def is_supported_type(cls, storage_type: str) -> bool:
"""Check if the given storage type is supported."""
return storage_type in cls._storage_registry
return storage_type in cls._registry
# --- Register default implementations ---
StorageFactory.register(StorageType.blob.value, create_blob_storage)
StorageFactory.register(StorageType.cosmosdb.value, create_cosmosdb_storage)
StorageFactory.register(StorageType.file.value, create_file_storage)
StorageFactory.register(StorageType.memory.value, lambda **_: MemoryPipelineStorage())
# --- register built-in storage implementations ---
StorageFactory.register(StorageType.blob.value, BlobPipelineStorage)
StorageFactory.register(StorageType.cosmosdb.value, CosmosDBPipelineStorage)
StorageFactory.register(StorageType.file.value, FilePipelineStorage)
StorageFactory.register(StorageType.memory.value, MemoryPipelineStorage)
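Because creators are plain callables, third-party backends register exactly like the built-ins above. A minimal sketch that reuses MemoryPipelineStorage under a custom key (module paths assumed from the imports in this file); a real plugin would pass its own PipelineStorage subclass instead:

from graphrag.storage.factory import StorageFactory  # module path assumed
from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage

# any class or callable returning a PipelineStorage works as a creator
StorageFactory.register("my-memory", MemoryPipelineStorage)

assert StorageFactory.is_supported_type("my-memory")
storage = StorageFactory.create_storage("my-memory", {})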

@@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'FileStorage' and 'FilePipelineStorage' models."""
"""File-based Storage implementation of PipelineStorage."""
import logging
import os
@@ -30,10 +30,11 @@ class FilePipelineStorage(PipelineStorage):
_root_dir: str
_encoding: str
def __init__(self, root_dir: str = "", encoding: str = "utf-8"):
"""Init method definition."""
self._root_dir = root_dir
self._encoding = encoding
def __init__(self, **kwargs: Any) -> None:
"""Create a file based storage."""
self._root_dir = kwargs.get("base_dir", "")
self._encoding = kwargs.get("encoding", "utf-8")
logger.info("Creating file storage at %s", self._root_dir)
Path(self._root_dir).mkdir(parents=True, exist_ok=True)
def find(
@@ -148,7 +149,8 @@ class FilePipelineStorage(PipelineStorage):
"""Create a child storage instance."""
if name is None:
return self
return FilePipelineStorage(str(Path(self._root_dir) / Path(name)))
child_path = str(Path(self._root_dir) / Path(name))
return FilePipelineStorage(base_dir=child_path, encoding=self._encoding)
def keys(self) -> list[str]:
"""Return the keys in the storage."""
@@ -167,10 +169,3 @@
def join_path(file_path: str, file_name: str) -> Path:
"""Join a path and a file. Independent of the OS."""
return Path(file_path) / Path(file_name).parent / Path(file_name).name
def create_file_storage(**kwargs: Any) -> PipelineStorage:
"""Create a file based storage."""
base_dir = kwargs["base_dir"]
logger.info("Creating file storage at %s", base_dir)
return FilePipelineStorage(root_dir=base_dir)
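For reference, the file backend now reads base_dir and encoding from kwargs, and child() propagates the parent's encoding instead of dropping it:

from graphrag.storage.file_pipeline_storage import FilePipelineStorage

storage = FilePipelineStorage(base_dir="/tmp/teststorage", encoding="utf-8")
child = storage.child("subdir")  # rooted at /tmp/teststorage/subdir, same encoding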

@@ -250,10 +250,10 @@ def create_storage_from_config(output: StorageConfig) -> PipelineStorage:
def create_cache_from_config(cache: CacheConfig, root_dir: str) -> PipelineCache:
"""Create a cache object from the config."""
cache_config = cache.model_dump()
kwargs = {**cache_config, "root_dir": root_dir}
return CacheFactory().create_cache(
cache_type=cache_config["type"],
root_dir=root_dir,
kwargs=cache_config,
kwargs=kwargs,
)
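Folding root_dir into the config-derived kwargs means cache implementations that need it (the file cache, for one) receive it like any other field. The file-cache test below exercises exactly this shape:

from graphrag.cache.factory import CacheFactory
from graphrag.config.enums import CacheType

# root_dir now travels inside kwargs rather than as a separate parameter
cache = CacheFactory.create_cache(
    CacheType.file.value, {"root_dir": "/tmp", "base_dir": "testcache"}
)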

@@ -1,52 +1,88 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A package containing a factory and supported vector store types."""
"""Factory functions for creating a vector store."""
from enum import Enum
from typing import ClassVar
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from graphrag.config.enums import VectorStoreType
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
from graphrag.vector_stores.base import BaseVectorStore
from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore
from graphrag.vector_stores.lancedb import LanceDBVectorStore
if TYPE_CHECKING:
from collections.abc import Callable
class VectorStoreType(str, Enum):
"""The supported vector store types."""
LanceDB = "lancedb"
AzureAISearch = "azure_ai_search"
CosmosDB = "cosmosdb"
from graphrag.vector_stores.base import BaseVectorStore
class VectorStoreFactory:
"""A factory for vector stores.
Includes a method for users to register a custom vector store implementation.
Configuration arguments are passed to each vector store implementation as kwargs
for individual enforcement of required/optional arguments.
"""
vector_store_types: ClassVar[dict[str, type]] = {}
_registry: ClassVar[dict[str, Callable[..., BaseVectorStore]]] = {}
@classmethod
def register(cls, vector_store_type: str, vector_store: type):
"""Register a custom vector store implementation."""
cls.vector_store_types[vector_store_type] = vector_store
def register(
cls, vector_store_type: str, creator: Callable[..., BaseVectorStore]
) -> None:
"""Register a custom vector store implementation.
Args:
vector_store_type: The type identifier for the vector store.
creator: A class or callable that creates an instance of BaseVectorStore.
"""
cls._registry[vector_store_type] = creator
@classmethod
def create_vector_store(
cls, vector_store_type: VectorStoreType | str, kwargs: dict
cls, vector_store_type: str, kwargs: dict
) -> BaseVectorStore:
"""Create or get a vector store from the provided type."""
match vector_store_type:
case VectorStoreType.LanceDB:
return LanceDBVectorStore(**kwargs)
case VectorStoreType.AzureAISearch:
return AzureAISearchVectorStore(**kwargs)
case VectorStoreType.CosmosDB:
return CosmosDBVectorStore(**kwargs)
case _:
if vector_store_type in cls.vector_store_types:
return cls.vector_store_types[vector_store_type](**kwargs)
msg = f"Unknown vector store type: {vector_store_type}"
raise ValueError(msg)
"""Create a vector store object from the provided type.
Args:
vector_store_type: The type of vector store to create.
kwargs: Additional keyword arguments for the vector store constructor.
Returns
-------
A BaseVectorStore instance.
Raises
------
ValueError: If the vector store type is not registered.
"""
if vector_store_type not in cls._registry:
msg = f"Unknown vector store type: {vector_store_type}"
raise ValueError(msg)
return cls._registry[vector_store_type](**kwargs)
@classmethod
def get_vector_store_types(cls) -> list[str]:
"""Get the registered vector store implementations."""
return list(cls._registry.keys())
@classmethod
def is_supported_type(cls, vector_store_type: str) -> bool:
"""Check if the given vector store type is supported."""
return vector_store_type in cls._registry
# --- register built-in vector store implementations ---
VectorStoreFactory.register(VectorStoreType.LanceDB.value, LanceDBVectorStore)
VectorStoreFactory.register(
VectorStoreType.AzureAISearch.value, AzureAISearchVectorStore
)
VectorStoreFactory.register(VectorStoreType.CosmosDB.value, CosmosDBVectorStore)
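This registry is what makes custom vector stores plug-and-play. A minimal sketch that re-registers an existing implementation under a custom key, using only kwargs the LanceDB test below also passes; a real extension would register its own BaseVectorStore subclass the same way:

from graphrag.vector_stores.factory import VectorStoreFactory
from graphrag.vector_stores.lancedb import LanceDBVectorStore

# the creator can be a class or any callable returning a BaseVectorStore
VectorStoreFactory.register("my-lancedb", LanceDBVectorStore)

store = VectorStoreFactory.create_vector_store(
    "my-lancedb", {"collection_name": "entities", "db_uri": "/tmp/lancedb"}
)
assert VectorStoreFactory.is_supported_type("my-lancedb")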

@@ -75,6 +75,7 @@ dev = [
"jupyter>=1.1.1",
"nbconvert>=7.16.4",
"poethepoet>=0.31.1",
"pandas-stubs>=2.3.0.250703",
"pyright>=1.1.390",
"pytest>=8.3.4",
"pytest-asyncio>=0.24.0",
@@ -242,7 +243,7 @@ ignore = [
[tool.ruff.lint.per-file-ignores]
"tests/*" = ["S", "D", "ANN", "T201", "ASYNC", "ARG", "PTH", "TRY"]
"graphrag/index/config/*" = ["TCH"]
"*.ipynb" = ["T201"]
"*.ipynb" = ["T201", "S101", "PT015", "B011"]
[tool.ruff.lint.flake8-builtins]
builtins-ignorelist = ["input", "id", "bytes"]

tests/integration/cache/__init__.py (new file)
@@ -0,0 +1,2 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

tests/integration/cache/test_factory.py (new file)
@@ -0,0 +1,169 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""CacheFactory Tests.
These tests will test the CacheFactory class and the creation of each cache type that is natively supported.
"""
import sys
import pytest
from graphrag.cache.factory import CacheFactory
from graphrag.cache.json_pipeline_cache import JsonPipelineCache
from graphrag.cache.memory_pipeline_cache import InMemoryCache
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.config.enums import CacheType
# cspell:disable-next-line well-known-key
WELL_KNOWN_BLOB_STORAGE_KEY = "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;"
# cspell:disable-next-line well-known-key
WELL_KNOWN_COSMOS_CONNECTION_STRING = "AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="
def test_create_noop_cache():
kwargs = {}
cache = CacheFactory.create_cache(CacheType.none.value, kwargs)
assert isinstance(cache, NoopPipelineCache)
def test_create_memory_cache():
kwargs = {}
cache = CacheFactory.create_cache(CacheType.memory.value, kwargs)
assert isinstance(cache, InMemoryCache)
def test_create_file_cache():
kwargs = {"root_dir": "/tmp", "base_dir": "testcache"}
cache = CacheFactory.create_cache(CacheType.file.value, kwargs)
assert isinstance(cache, JsonPipelineCache)
def test_create_blob_cache():
kwargs = {
"connection_string": WELL_KNOWN_BLOB_STORAGE_KEY,
"container_name": "testcontainer",
"base_dir": "testcache",
}
cache = CacheFactory.create_cache(CacheType.blob.value, kwargs)
assert isinstance(cache, JsonPipelineCache)
@pytest.mark.skipif(
not sys.platform.startswith("win"),
reason="cosmosdb emulator is only available on windows runners at this time",
)
def test_create_cosmosdb_cache():
kwargs = {
"connection_string": WELL_KNOWN_COSMOS_CONNECTION_STRING,
"base_dir": "testdatabase",
"container_name": "testcontainer",
}
cache = CacheFactory.create_cache(CacheType.cosmosdb.value, kwargs)
assert isinstance(cache, JsonPipelineCache)
def test_register_and_create_custom_cache():
"""Test registering and creating a custom cache type."""
from unittest.mock import MagicMock
# Create a mock that satisfies the PipelineCache interface
custom_cache_class = MagicMock(spec=PipelineCache)
# Make the mock return a mock instance when instantiated
instance = MagicMock()
instance.initialized = True
custom_cache_class.return_value = instance
CacheFactory.register("custom", lambda **kwargs: custom_cache_class(**kwargs))
cache = CacheFactory.create_cache("custom", {})
assert custom_cache_class.called
assert cache is instance
# Access the attribute we set on our mock
assert cache.initialized is True # type: ignore # Attribute only exists on our mock
# Check if it's in the list of registered cache types
assert "custom" in CacheFactory.get_cache_types()
assert CacheFactory.is_supported_type("custom")
def test_get_cache_types():
cache_types = CacheFactory.get_cache_types()
# Check that built-in types are registered
assert CacheType.none.value in cache_types
assert CacheType.memory.value in cache_types
assert CacheType.file.value in cache_types
assert CacheType.blob.value in cache_types
assert CacheType.cosmosdb.value in cache_types
def test_create_unknown_cache():
with pytest.raises(ValueError, match="Unknown cache type: unknown"):
CacheFactory.create_cache("unknown", {})
def test_is_supported_type():
# Test built-in types
assert CacheFactory.is_supported_type(CacheType.none.value)
assert CacheFactory.is_supported_type(CacheType.memory.value)
assert CacheFactory.is_supported_type(CacheType.file.value)
assert CacheFactory.is_supported_type(CacheType.blob.value)
assert CacheFactory.is_supported_type(CacheType.cosmosdb.value)
# Test unknown type
assert not CacheFactory.is_supported_type("unknown")
def test_enum_and_string_compatibility():
"""Test that both enum and string types work for cache creation."""
kwargs = {}
# Test with enum
cache_enum = CacheFactory.create_cache(CacheType.memory, kwargs)
assert isinstance(cache_enum, InMemoryCache)
# Test with string
cache_str = CacheFactory.create_cache("memory", kwargs)
assert isinstance(cache_str, InMemoryCache)
# Both should create the same type
assert type(cache_enum) is type(cache_str)
def test_register_class_directly_works():
"""Test that registering a class directly works (CacheFactory allows this)."""
from graphrag.cache.pipeline_cache import PipelineCache
class CustomCache(PipelineCache):
def __init__(self, **kwargs):
pass
async def get(self, key: str):
return None
async def set(self, key: str, value, debug_data=None):
pass
async def has(self, key: str):
return False
async def delete(self, key: str):
pass
async def clear(self):
pass
def child(self, name: str):
return self
# CacheFactory allows registering classes directly (no TypeError)
CacheFactory.register("custom_class", CustomCache)
# Verify it was registered
assert "custom_class" in CacheFactory.get_cache_types()
assert CacheFactory.is_supported_type("custom_class")
# Test creating an instance
cache = CacheFactory.create_cache("custom_class", {})
assert isinstance(cache, CustomCache)

@@ -0,0 +1,65 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""LoggerFactory Tests.
These tests will test the LoggerFactory class and the creation of each reporting type that is natively supported.
"""
import logging
import pytest
from graphrag.config.enums import ReportingType
from graphrag.logger.blob_workflow_logger import BlobWorkflowLogger
from graphrag.logger.factory import LoggerFactory
# cspell:disable-next-line well-known-key
WELL_KNOWN_BLOB_STORAGE_KEY = "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;"
# cspell:disable-next-line well-known-key
WELL_KNOWN_COSMOS_CONNECTION_STRING = "AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="
@pytest.mark.skip(reason="Blob storage emulator is not available in this environment")
def test_create_blob_logger():
kwargs = {
"type": "blob",
"connection_string": WELL_KNOWN_BLOB_STORAGE_KEY,
"base_dir": "testbasedir",
"container_name": "testcontainer",
}
logger = LoggerFactory.create_logger(ReportingType.blob.value, kwargs)
assert isinstance(logger, BlobWorkflowLogger)
def test_register_and_create_custom_logger():
"""Test registering and creating a custom logger type."""
from unittest.mock import MagicMock
custom_logger_class = MagicMock(spec=logging.Handler)
instance = MagicMock()
instance.initialized = True
custom_logger_class.return_value = instance
LoggerFactory.register("custom", lambda **kwargs: custom_logger_class(**kwargs))
logger = LoggerFactory.create_logger("custom", {})
assert custom_logger_class.called
assert logger is instance
# Access the attribute we set on our mock
assert logger.initialized is True # type: ignore # Attribute only exists on our mock
# Check if it's in the list of registered logger types
assert "custom" in LoggerFactory.get_logger_types()
assert LoggerFactory.is_supported_type("custom")
def test_get_logger_types():
logger_types = LoggerFactory.get_logger_types()
# Check that built-in types are registered
assert ReportingType.file.value in logger_types
assert ReportingType.blob.value in logger_types
def test_create_unknown_logger():
with pytest.raises(ValueError, match="Unknown reporting type: unknown"):
LoggerFactory.create_logger("unknown", {})

@@ -24,7 +24,7 @@ if not sys.platform.startswith("win"):
async def test_find():
storage = CosmosDBPipelineStorage(
connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
database_name="testfind",
base_dir="testfind",
container_name="testfindcontainer",
)
try:
@@ -70,7 +70,7 @@ async def test_find():
async def test_child():
storage = CosmosDBPipelineStorage(
connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
database_name="testchild",
base_dir="testchild",
container_name="testchildcontainer",
)
try:
@@ -83,7 +83,7 @@ async def test_child():
async def test_clear():
storage = CosmosDBPipelineStorage(
connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
database_name="testclear",
base_dir="testclear",
container_name="testclearcontainer",
)
try:
@@ -113,7 +113,7 @@ async def test_clear():
async def test_get_creation_date():
storage = CosmosDBPipelineStorage(
connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
database_name="testclear",
base_dir="testclear",
container_name="testclearcontainer",
)
try:

@@ -31,7 +31,7 @@ def test_create_blob_storage():
"base_dir": "testbasedir",
"container_name": "testcontainer",
}
storage = StorageFactory.create_storage(StorageType.blob, kwargs)
storage = StorageFactory.create_storage(StorageType.blob.value, kwargs)
assert isinstance(storage, BlobPipelineStorage)
@@ -46,19 +46,19 @@ def test_create_cosmosdb_storage():
"base_dir": "testdatabase",
"container_name": "testcontainer",
}
storage = StorageFactory.create_storage(StorageType.cosmosdb, kwargs)
storage = StorageFactory.create_storage(StorageType.cosmosdb.value, kwargs)
assert isinstance(storage, CosmosDBPipelineStorage)
def test_create_file_storage():
kwargs = {"type": "file", "base_dir": "/tmp/teststorage"}
storage = StorageFactory.create_storage(StorageType.file, kwargs)
storage = StorageFactory.create_storage(StorageType.file.value, kwargs)
assert isinstance(storage, FilePipelineStorage)
def test_create_memory_storage():
kwargs = {"type": "memory"}
storage = StorageFactory.create_storage(StorageType.memory, kwargs)
kwargs = {} # MemoryPipelineStorage doesn't accept any constructor parameters
storage = StorageFactory.create_storage(StorageType.memory.value, kwargs)
assert isinstance(storage, MemoryPipelineStorage)
@@ -84,7 +84,7 @@ def test_register_and_create_custom_storage():
# Check if it's in the list of registered storage types
assert "custom" in StorageFactory.get_storage_types()
assert StorageFactory.is_supported_storage_type("custom")
assert StorageFactory.is_supported_type("custom")
def test_get_storage_types():
@@ -96,13 +96,65 @@ def test_get_storage_types():
assert StorageType.cosmosdb.value in storage_types
def test_backward_compatibility():
"""Test that the storage_types attribute is still accessible for backward compatibility."""
assert hasattr(StorageFactory, "storage_types")
# The storage_types attribute should be a dict
assert isinstance(StorageFactory.storage_types, dict)
def test_create_unknown_storage():
with pytest.raises(ValueError, match="Unknown storage type: unknown"):
StorageFactory.create_storage("unknown", {})
def test_register_class_directly_works():
"""Test that registering a class directly works (StorageFactory allows this)."""
import re
from collections.abc import Iterator
from typing import Any
from graphrag.storage.pipeline_storage import PipelineStorage
class CustomStorage(PipelineStorage):
def __init__(self, **kwargs):
pass
def find(
self,
file_pattern: re.Pattern[str],
base_dir: str | None = None,
file_filter: dict[str, Any] | None = None,
max_count=-1,
) -> Iterator[tuple[str, dict[str, Any]]]:
return iter([])
async def get(
self, key: str, as_bytes: bool | None = None, encoding: str | None = None
) -> Any:
return None
async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
pass
async def delete(self, key: str) -> None:
pass
async def has(self, key: str) -> bool:
return False
async def clear(self) -> None:
pass
def child(self, name: str | None) -> "PipelineStorage":
return self
def keys(self) -> list[str]:
return []
async def get_creation_date(self, key: str) -> str:
return "2024-01-01 00:00:00 +0000"
# StorageFactory allows registering classes directly (no TypeError)
StorageFactory.register("custom_class", CustomStorage)
# Verify it was registered
assert "custom_class" in StorageFactory.get_storage_types()
assert StorageFactory.is_supported_type("custom_class")
# Test creating an instance
storage = StorageFactory.create_storage("custom_class", {})
assert isinstance(storage, CustomStorage)

@@ -0,0 +1,144 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""VectorStoreFactory Tests.
These tests will test the VectorStoreFactory class and the creation of each vector store type that is natively supported.
"""
import pytest
from graphrag.config.enums import VectorStoreType
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
from graphrag.vector_stores.base import BaseVectorStore
from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore
from graphrag.vector_stores.factory import VectorStoreFactory
from graphrag.vector_stores.lancedb import LanceDBVectorStore
def test_create_lancedb_vector_store():
kwargs = {
"collection_name": "test_collection",
"db_uri": "/tmp/lancedb",
}
vector_store = VectorStoreFactory.create_vector_store(
VectorStoreType.LanceDB.value, kwargs
)
assert isinstance(vector_store, LanceDBVectorStore)
assert vector_store.collection_name == "test_collection"
@pytest.mark.skip(reason="Azure AI Search requires credentials and setup")
def test_create_azure_ai_search_vector_store():
kwargs = {
"collection_name": "test_collection",
"url": "https://test.search.windows.net",
"api_key": "test_key",
}
vector_store = VectorStoreFactory.create_vector_store(
VectorStoreType.AzureAISearch.value, kwargs
)
assert isinstance(vector_store, AzureAISearchVectorStore)
@pytest.mark.skip(reason="CosmosDB requires credentials and setup")
def test_create_cosmosdb_vector_store():
kwargs = {
"collection_name": "test_collection",
"connection_string": "AccountEndpoint=https://test.documents.azure.com:443/;AccountKey=test_key==",
"database_name": "test_db",
}
vector_store = VectorStoreFactory.create_vector_store(
VectorStoreType.CosmosDB.value, kwargs
)
assert isinstance(vector_store, CosmosDBVectorStore)
def test_register_and_create_custom_vector_store():
"""Test registering and creating a custom vector store type."""
from unittest.mock import MagicMock
# Create a mock that satisfies the BaseVectorStore interface
custom_vector_store_class = MagicMock(spec=BaseVectorStore)
# Make the mock return a mock instance when instantiated
instance = MagicMock()
instance.initialized = True
custom_vector_store_class.return_value = instance
VectorStoreFactory.register(
"custom", lambda **kwargs: custom_vector_store_class(**kwargs)
)
vector_store = VectorStoreFactory.create_vector_store("custom", {})
assert custom_vector_store_class.called
assert vector_store is instance
# Access the attribute we set on our mock
assert vector_store.initialized is True # type: ignore # Attribute only exists on our mock
# Check if it's in the list of registered vector store types
assert "custom" in VectorStoreFactory.get_vector_store_types()
assert VectorStoreFactory.is_supported_type("custom")
def test_get_vector_store_types():
vector_store_types = VectorStoreFactory.get_vector_store_types()
# Check that built-in types are registered
assert VectorStoreType.LanceDB.value in vector_store_types
assert VectorStoreType.AzureAISearch.value in vector_store_types
assert VectorStoreType.CosmosDB.value in vector_store_types
def test_create_unknown_vector_store():
with pytest.raises(ValueError, match="Unknown vector store type: unknown"):
VectorStoreFactory.create_vector_store("unknown", {})
def test_is_supported_type():
# Test built-in types
assert VectorStoreFactory.is_supported_type(VectorStoreType.LanceDB.value)
assert VectorStoreFactory.is_supported_type(VectorStoreType.AzureAISearch.value)
assert VectorStoreFactory.is_supported_type(VectorStoreType.CosmosDB.value)
# Test unknown type
assert not VectorStoreFactory.is_supported_type("unknown")
def test_register_class_directly_works():
"""Test that registering a class directly works (VectorStoreFactory allows this)."""
from graphrag.vector_stores.base import BaseVectorStore
class CustomVectorStore(BaseVectorStore):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def connect(self, **kwargs):
pass
def load_documents(self, documents, overwrite=True):
pass
def similarity_search_by_vector(self, query_embedding, k=10, **kwargs):
return []
def similarity_search_by_text(self, text, text_embedder, k=10, **kwargs):
return []
def filter_by_id(self, include_ids):
return {}
def search_by_id(self, id):
from graphrag.vector_stores.base import VectorStoreDocument
return VectorStoreDocument(id=id, text="test", vector=None)
# VectorStoreFactory allows registering classes directly (no TypeError)
VectorStoreFactory.register("custom_class", CustomVectorStore)
# Verify it was registered
assert "custom_class" in VectorStoreFactory.get_vector_store_types()
assert VectorStoreFactory.is_supported_type("custom_class")
# Test creating an instance
vector_store = VectorStoreFactory.create_vector_store(
"custom_class", {"collection_name": "test"}
)
assert isinstance(vector_store, CustomVectorStore)

@@ -13,7 +13,7 @@ TEMP_DIR = "./.tmp"
def create_cache():
storage = FilePipelineStorage(root_dir=os.path.join(os.getcwd(), ".tmp"))
storage = FilePipelineStorage(base_dir=os.path.join(os.getcwd(), ".tmp"))
return JsonPipelineCache(storage)

uv.lock (generated): diff suppressed because it is too large.