From 6aae386b305c4acd79a5977bc31a41076d1eb0eb Mon Sep 17 00:00:00 2001
From: Matthieu Maitre
Date: Mon, 21 Oct 2024 13:03:48 -0500
Subject: [PATCH 1/2] Perf optimizations in map_query_to_entities() (#1276)

* Address perf issue in map_query_to_entities()

* Add semver

---------

Co-authored-by: Matthieu Maitre
Co-authored-by: Alonso Guevara
---
 .../patch-20241014040518441266.json           |   4 +
 .../context_builder/entity_extraction.py      |  19 +-
 graphrag/query/input/retrieval/entities.py    |  21 +-
 .../local_search/mixed_context.py             |   2 +-
 tests/unit/query/context_builder/__init__.py  |   2 +
 .../context_builder/test_entity_extraction.py | 182 ++++++++++++++++++
 tests/unit/query/input/__init__.py            |   2 +
 tests/unit/query/input/retrieval/__init__.py  |   2 +
 .../query/input/retrieval/test_entities.py    | 167 ++++++++++++++++
 9 files changed, 388 insertions(+), 13 deletions(-)
 create mode 100644 .semversioner/next-release/patch-20241014040518441266.json
 create mode 100644 tests/unit/query/context_builder/__init__.py
 create mode 100644 tests/unit/query/context_builder/test_entity_extraction.py
 create mode 100644 tests/unit/query/input/__init__.py
 create mode 100644 tests/unit/query/input/retrieval/__init__.py
 create mode 100644 tests/unit/query/input/retrieval/test_entities.py

diff --git a/.semversioner/next-release/patch-20241014040518441266.json b/.semversioner/next-release/patch-20241014040518441266.json
new file mode 100644
index 00000000..c5831a0c
--- /dev/null
+++ b/.semversioner/next-release/patch-20241014040518441266.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Perf optimizations in map_query_to_entities()"
+}
diff --git a/graphrag/query/context_builder/entity_extraction.py b/graphrag/query/context_builder/entity_extraction.py
index 82a0699c..4b8767b8 100644
--- a/graphrag/query/context_builder/entity_extraction.py
+++ b/graphrag/query/context_builder/entity_extraction.py
@@ -7,6 +7,7 @@ from enum import Enum
 
 from graphrag.model import Entity, Relationship
 from graphrag.query.input.retrieval.entities import (
+    get_entity_by_id,
     get_entity_by_key,
     get_entity_by_name,
 )
@@ -36,7 +37,7 @@ def map_query_to_entities(
     query: str,
     text_embedding_vectorstore: BaseVectorStore,
     text_embedder: BaseTextEmbedding,
-    all_entities: list[Entity],
+    all_entities_dict: dict[str, Entity],
     embedding_vectorstore_key: str = EntityVectorStoreKey.ID,
     include_entity_names: list[str] | None = None,
     exclude_entity_names: list[str] | None = None,
@@ -48,6 +49,7 @@ def map_query_to_entities(
         include_entity_names = []
     if exclude_entity_names is None:
         exclude_entity_names = []
+    all_entities = list(all_entities_dict.values())
     matched_entities = []
     if query != "":
         # get entities with highest semantic similarity to query
@@ -58,11 +60,16 @@ def map_query_to_entities(
             k=k * oversample_scaler,
         )
         for result in search_results:
-            matched = get_entity_by_key(
-                entities=all_entities,
-                key=embedding_vectorstore_key,
-                value=result.document.id,
-            )
+            if embedding_vectorstore_key == EntityVectorStoreKey.ID and isinstance(
+                result.document.id, str
+            ):
+                matched = get_entity_by_id(all_entities_dict, result.document.id)
+            else:
+                matched = get_entity_by_key(
+                    entities=all_entities,
+                    key=embedding_vectorstore_key,
+                    value=result.document.id,
+                )
             if matched:
                 matched_entities.append(matched)
     else:
diff --git a/graphrag/query/input/retrieval/entities.py b/graphrag/query/input/retrieval/entities.py
index 5465f9f5..41c92fab 100644
--- a/graphrag/query/input/retrieval/entities.py
+++ b/graphrag/query/input/retrieval/entities.py
@@ -12,17 +12,26 @@ import pandas as pd
 
 from graphrag.model import Entity
 
 
+def get_entity_by_id(entities: dict[str, Entity], value: str) -> Entity | None:
+    """Get entity by id."""
+    entity = entities.get(value)
+    if entity is None and is_valid_uuid(value):
+        entity = entities.get(value.replace("-", ""))
+    return entity
+
+
 def get_entity_by_key(
     entities: Iterable[Entity], key: str, value: str | int
 ) -> Entity | None:
     """Get entity by key."""
-    for entity in entities:
-        if isinstance(value, str) and is_valid_uuid(value):
-            if getattr(entity, key) == value or getattr(entity, key) == value.replace(
-                "-", ""
-            ):
+    if isinstance(value, str) and is_valid_uuid(value):
+        value_no_dashes = value.replace("-", "")
+        for entity in entities:
+            entity_value = getattr(entity, key)
+            if entity_value in (value, value_no_dashes):
                 return entity
-        else:
+    else:
+        for entity in entities:
             if getattr(entity, key) == value:
                 return entity
     return None
diff --git a/graphrag/query/structured_search/local_search/mixed_context.py b/graphrag/query/structured_search/local_search/mixed_context.py
index e0608e4b..d160fe81 100644
--- a/graphrag/query/structured_search/local_search/mixed_context.py
+++ b/graphrag/query/structured_search/local_search/mixed_context.py
@@ -141,7 +141,7 @@ class LocalSearchMixedContext(LocalContextBuilder):
             query=query,
             text_embedding_vectorstore=self.entity_text_embeddings,
             text_embedder=self.text_embedder,
-            all_entities=list(self.entities.values()),
+            all_entities_dict=self.entities,
             embedding_vectorstore_key=self.embedding_vectorstore_key,
             include_entity_names=include_entity_names,
             exclude_entity_names=exclude_entity_names,
diff --git a/tests/unit/query/context_builder/__init__.py b/tests/unit/query/context_builder/__init__.py
new file mode 100644
index 00000000..0a3e38ad
--- /dev/null
+++ b/tests/unit/query/context_builder/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
diff --git a/tests/unit/query/context_builder/test_entity_extraction.py b/tests/unit/query/context_builder/test_entity_extraction.py
new file mode 100644
index 00000000..de71b880
--- /dev/null
+++ b/tests/unit/query/context_builder/test_entity_extraction.py
@@ -0,0 +1,182 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+from typing import Any
+
+from graphrag.model import Entity
+from graphrag.model.types import TextEmbedder
+from graphrag.query.context_builder.entity_extraction import (
+    EntityVectorStoreKey,
+    map_query_to_entities,
+)
+from graphrag.query.llm.base import BaseTextEmbedding
+from graphrag.vector_stores import (
+    BaseVectorStore,
+    VectorStoreDocument,
+    VectorStoreSearchResult,
+)
+
+
+class MockBaseVectorStore(BaseVectorStore):
+    def __init__(self, documents: list[VectorStoreDocument]) -> None:
+        super().__init__("mock")
+        self.documents = documents
+
+    def connect(self, **kwargs: Any) -> None:
+        raise NotImplementedError
+
+    def load_documents(
+        self, documents: list[VectorStoreDocument], overwrite: bool = True
+    ) -> None:
+        raise NotImplementedError
+
+    def similarity_search_by_vector(
+        self, query_embedding: list[float], k: int = 10, **kwargs: Any
+    ) -> list[VectorStoreSearchResult]:
+        return [
+            VectorStoreSearchResult(document=document, score=1)
+            for document in self.documents[:k]
+        ]
+
+    def similarity_search_by_text(
+        self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
+    ) -> list[VectorStoreSearchResult]:
+        return sorted(
+            [
+                VectorStoreSearchResult(
+                    document=document, score=abs(len(text) - len(document.text or ""))
+                )
+                for document in self.documents
+            ],
+            key=lambda x: x.score,
+        )[:k]
+
+    def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
+        return [document for document in self.documents if document.id in include_ids]
+
+
+class MockBaseTextEmbedding(BaseTextEmbedding):
+    def embed(self, text: str, **kwargs: Any) -> list[float]:
+        return [len(text)]
+
+    async def aembed(self, text: str, **kwargs: Any) -> list[float]:
+        return [len(text)]
+
+
+def test_map_query_to_entities():
+    entities = [
+        Entity(
+            id="2da37c7a-50a8-44d4-aa2c-fd401e19976c",
+            short_id="sid1",
+            title="t1",
+            rank=2,
+        ),
+        Entity(
+            id="c4f93564-4507-4ee4-b102-98add401a965",
+            short_id="sid2",
+            title="t22",
+            rank=4,
+        ),
+        Entity(
+            id="7c6f2bc9-47c9-4453-93a3-d2e174a02cd9",
+            short_id="sid3",
+            title="t333",
+            rank=1,
+        ),
+        Entity(
+            id="8fd6d72a-8e9d-4183-8a97-c38bcc971c83",
+            short_id="sid4",
+            title="t4444",
+            rank=3,
+        ),
+    ]
+
+    assert map_query_to_entities(
+        query="t22",
+        text_embedding_vectorstore=MockBaseVectorStore([
+            VectorStoreDocument(id=entity.id, text=entity.title, vector=None)
+            for entity in entities
+        ]),
+        text_embedder=MockBaseTextEmbedding(),
+        all_entities_dict={entity.id: entity for entity in entities},
+        embedding_vectorstore_key=EntityVectorStoreKey.ID,
+        k=1,
+        oversample_scaler=1,
+    ) == [
+        Entity(
+            id="c4f93564-4507-4ee4-b102-98add401a965",
+            short_id="sid2",
+            title="t22",
+            rank=4,
+        )
+    ]
+
+    assert map_query_to_entities(
+        query="t22",
+        text_embedding_vectorstore=MockBaseVectorStore([
+            VectorStoreDocument(id=entity.title, text=entity.title, vector=None)
+            for entity in entities
+        ]),
+        text_embedder=MockBaseTextEmbedding(),
+        all_entities_dict={entity.id: entity for entity in entities},
+        embedding_vectorstore_key=EntityVectorStoreKey.TITLE,
+        k=1,
+        oversample_scaler=1,
+    ) == [
+        Entity(
+            id="c4f93564-4507-4ee4-b102-98add401a965",
+            short_id="sid2",
+            title="t22",
+            rank=4,
+        )
+    ]
+
+    assert map_query_to_entities(
+        query="",
+        text_embedding_vectorstore=MockBaseVectorStore([
+            VectorStoreDocument(id=entity.id, text=entity.title, vector=None)
+            for entity in entities
+        ]),
+        text_embedder=MockBaseTextEmbedding(),
+        all_entities_dict={entity.id: entity for entity in entities},
+        embedding_vectorstore_key=EntityVectorStoreKey.ID,
+        k=2,
+    ) == [
+        Entity(
+            id="c4f93564-4507-4ee4-b102-98add401a965",
+            short_id="sid2",
+            title="t22",
+            rank=4,
+        ),
+        Entity(
+            id="8fd6d72a-8e9d-4183-8a97-c38bcc971c83",
+            short_id="sid4",
+            title="t4444",
+            rank=3,
+        ),
+    ]
+
+    assert map_query_to_entities(
+        query="",
+        text_embedding_vectorstore=MockBaseVectorStore([
+            VectorStoreDocument(id=entity.id, text=entity.title, vector=None)
+            for entity in entities
+        ]),
+        text_embedder=MockBaseTextEmbedding(),
+        all_entities_dict={entity.id: entity for entity in entities},
+        embedding_vectorstore_key=EntityVectorStoreKey.TITLE,
+        k=2,
+    ) == [
+        Entity(
+            id="c4f93564-4507-4ee4-b102-98add401a965",
+            short_id="sid2",
+            title="t22",
+            rank=4,
+        ),
+        Entity(
+            id="8fd6d72a-8e9d-4183-8a97-c38bcc971c83",
+            short_id="sid4",
+            title="t4444",
+            rank=3,
+        ),
+    ]
diff --git a/tests/unit/query/input/__init__.py b/tests/unit/query/input/__init__.py
new file mode 100644
index 00000000..0a3e38ad
--- /dev/null
+++ b/tests/unit/query/input/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
diff --git a/tests/unit/query/input/retrieval/__init__.py b/tests/unit/query/input/retrieval/__init__.py
new file mode 100644
index 00000000..0a3e38ad
--- /dev/null
+++ b/tests/unit/query/input/retrieval/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
diff --git a/tests/unit/query/input/retrieval/test_entities.py b/tests/unit/query/input/retrieval/test_entities.py
new file mode 100644
index 00000000..a66e3432
--- /dev/null
+++ b/tests/unit/query/input/retrieval/test_entities.py
@@ -0,0 +1,167 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+from graphrag.model import Entity
+from graphrag.query.input.retrieval.entities import (
+    get_entity_by_id,
+    get_entity_by_key,
+)
+
+
+def test_get_entity_by_id():
+    assert (
+        get_entity_by_id(
+            {
+                entity.id: entity
+                for entity in [
+                    Entity(
+                        id="2da37c7a-50a8-44d4-aa2c-fd401e19976c",
+                        short_id="sid1",
+                        title="title1",
+                    ),
+                ]
+            },
+            "00000000-0000-0000-0000-000000000000",
+        )
+        is None
+    )
+
+    assert get_entity_by_id(
+        {
+            entity.id: entity
+            for entity in [
+                Entity(
+                    id="2da37c7a-50a8-44d4-aa2c-fd401e19976c",
+                    short_id="sid1",
+                    title="title1",
+                ),
+                Entity(
+                    id="c4f93564-4507-4ee4-b102-98add401a965",
+                    short_id="sid2",
+                    title="title2",
+                ),
+                Entity(
+                    id="7c6f2bc9-47c9-4453-93a3-d2e174a02cd9",
+                    short_id="sid3",
+                    title="title3",
+                ),
+            ]
+        },
+        "7c6f2bc9-47c9-4453-93a3-d2e174a02cd9",
+    ) == Entity(
+        id="7c6f2bc9-47c9-4453-93a3-d2e174a02cd9", short_id="sid3", title="title3"
+    )
+
+    assert get_entity_by_id(
+        {
+            entity.id: entity
+            for entity in [
+                Entity(
+                    id="2da37c7a50a844d4aa2cfd401e19976c",
+                    short_id="sid1",
+                    title="title1",
+                ),
+                Entity(
+                    id="c4f9356445074ee4b10298add401a965",
+                    short_id="sid2",
+                    title="title2",
+                ),
+                Entity(
+                    id="7c6f2bc947c9445393a3d2e174a02cd9",
+                    short_id="sid3",
+                    title="title3",
+                ),
+            ]
+        },
+        "7c6f2bc9-47c9-4453-93a3-d2e174a02cd9",
+    ) == Entity(id="7c6f2bc947c9445393a3d2e174a02cd9", short_id="sid3", title="title3")
+
+    assert get_entity_by_id(
+        {
+            entity.id: entity
+            for entity in [
+                Entity(id="id1", short_id="sid1", title="title1"),
+                Entity(id="id2", short_id="sid2", title="title2"),
+                Entity(id="id3", short_id="sid3", title="title3"),
+            ]
+        },
+        "id3",
+    ) == Entity(id="id3", short_id="sid3", title="title3")
+
+
+def test_get_entity_by_key():
+    assert (
+        get_entity_by_key(
+            [
+                Entity(
+                    id="2da37c7a-50a8-44d4-aa2c-fd401e19976c",
+                    short_id="sid1",
+                    title="title1",
+                ),
+            ],
+            "id",
+            "00000000-0000-0000-0000-000000000000",
+        )
+        is None
+    )
+
+    assert get_entity_by_key(
+        [
+            Entity(
+                id="2da37c7a-50a8-44d4-aa2c-fd401e19976c",
+                short_id="sid1",
+                title="title1",
+            ),
+            Entity(
+                id="c4f93564-4507-4ee4-b102-98add401a965",
+                short_id="sid2",
+                title="title2",
+            ),
+            Entity(
+                id="7c6f2bc9-47c9-4453-93a3-d2e174a02cd9",
+                short_id="sid3",
+                title="title3",
+            ),
+        ],
+        "id",
+        "7c6f2bc9-47c9-4453-93a3-d2e174a02cd9",
+    ) == Entity(
+        id="7c6f2bc9-47c9-4453-93a3-d2e174a02cd9", short_id="sid3", title="title3"
+    )
+
+    assert get_entity_by_key(
+        [
+            Entity(
+                id="2da37c7a50a844d4aa2cfd401e19976c", short_id="sid1", title="title1"
+            ),
+            Entity(
+                id="c4f9356445074ee4b10298add401a965", short_id="sid2", title="title2"
+            ),
+            Entity(
+                id="7c6f2bc947c9445393a3d2e174a02cd9", short_id="sid3", title="title3"
+            ),
+        ],
+        "id",
+        "7c6f2bc9-47c9-4453-93a3-d2e174a02cd9",
+    ) == Entity(id="7c6f2bc947c9445393a3d2e174a02cd9", short_id="sid3", title="title3")
+
+    assert get_entity_by_key(
+        [
+            Entity(id="id1", short_id="sid1", title="title1"),
+            Entity(id="id2", short_id="sid2", title="title2"),
+            Entity(id="id3", short_id="sid3", title="title3"),
+        ],
+        "id",
+        "id3",
+    ) == Entity(id="id3", short_id="sid3", title="title3")
+
+    assert get_entity_by_key(
+        [
+            Entity(id="id1", short_id="sid1", title="title1", rank=1),
+            Entity(id="id2", short_id="sid2", title="title2a", rank=2),
+            Entity(id="id3", short_id="sid3", title="title3", rank=3),
+            Entity(id="id2", short_id="sid2", title="title2b", rank=2),
+        ],
+        "rank",
+        2,
+    ) == Entity(id="id2", short_id="sid2", title="title2a", rank=2)

From e0840a2dc4787632791e40de51321feb5b0da866 Mon Sep 17 00:00:00 2001
From: KennyZhang1 <90438893+KennyZhang1@users.noreply.github.com>
Date: Mon, 21 Oct 2024 16:56:56 -0400
Subject: [PATCH 2/2] Fix vector store logic and refactor audience parameter (#1259)

---
 .gitignore                                    |  2 +-
 .../patch-20241008161248831044.json           |  4 +
 docs/config/json_yaml.md                      | 14 ++-
 graphrag/api/__init__.py                      |  6 +-
 graphrag/api/{index_api.py => index.py}       | 13 ++-
 .../{prompt_tune_api.py => prompt_tune.py}    |  0
 graphrag/api/{query_api.py => query.py}       | 97 +++++-------------
 graphrag/config/create_graphrag_config.py     | 32 +++---
 graphrag/config/defaults.py                   | 13 ++-
 .../input_models/llm_parameters_input.py      |  2 +-
 graphrag/config/models/llm_parameters.py      |  5 +-
 .../config/models/text_embedding_config.py    |  2 +-
 graphrag/index/cli.py                         | 36 +------
 graphrag/index/init_content.py                | 17 +++-
 graphrag/index/llm/load_llm.py                |  2 +-
 .../index/operations/embed_text/embed_text.py |  2 +-
 graphrag/llm/openai/create_openai_client.py   | 11 ++-
 graphrag/llm/openai/openai_configuration.py   |  8 +-
 graphrag/query/factories.py                   | 23 ++---
 graphrag/utils/cli.py                         | 31 ++++++
 graphrag/vector_stores/__init__.py            | 12 +--
 graphrag/vector_stores/azure_ai_search.py     | 22 +++--
 .../vector_stores/{typing.py => factory.py}   |  4 +-
 graphrag/vector_stores/lancedb.py             | 32 ++++--
 tests/fixtures/azure/settings.yml             |  2 -
 tests/fixtures/min-csv/settings.yml           |  3 +-
 tests/fixtures/text/settings.yml              |  2 -
 27 files changed, 203 insertions(+), 194 deletions(-)
 create mode 100644 .semversioner/next-release/patch-20241008161248831044.json
 rename graphrag/api/{index_api.py => index.py} (80%)
 rename graphrag/api/{prompt_tune_api.py => prompt_tune.py} (100%)
 rename graphrag/api/{query_api.py => query.py} (82%)
 rename graphrag/vector_stores/{typing.py => factory.py} (92%)

diff --git a/.gitignore b/.gitignore
index 9f0c6529..2ddb2935 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
 # Python Artifacts
 python/*/lib/
 dist/
+
 # Test Output
 .coverage
 coverage/
@@ -20,7 +21,6 @@ venv/
 .conda
 .tmp
-
 .env
 
 build.zip
diff --git a/.semversioner/next-release/patch-20241008161248831044.json b/.semversioner/next-release/patch-20241008161248831044.json
new file mode 100644
index 00000000..64deb4e9
--- /dev/null
+++ b/.semversioner/next-release/patch-20241008161248831044.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "refactor use of vector stores and update support for managed identity"
+}
diff --git a/docs/config/json_yaml.md b/docs/config/json_yaml.md
index 35d456b6..bc2c5646 100644
--- a/docs/config/json_yaml.md
+++ b/docs/config/json_yaml.md
@@ -52,7 +52,7 @@ This is the base LLM configuration section. Other steps may override this config
 - `api_version` **str** - The API version
 - `organization` **str** - The client organization.
 - `proxy` **str** - The proxy URL to use.
-- `cognitive_services_endpoint` **str** - The url endpoint for cognitive services.
+- `audience` **str** - (Azure OpenAI only) The URI of the target Azure resource/service for which a managed identity token is requested. Used if `api_key` is not defined. Default=`https://cognitiveservices.azure.com/.default`
 - `deployment_name` **str** - The deployment name to use (Azure).
 - `model_supports_json` **bool** - Whether the model supports JSON-mode output.
 - `tokens_per_minute` **int** - Set a leaky-bucket throttle on tokens-per-minute.
@@ -84,9 +84,17 @@ This is the base LLM configuration section. Other steps may override this config
 - `parallelization` (see Parallelization top-level config)
 - `async_mode` (see Async Mode top-level config)
 - `batch_size` **int** - The maximum batch size to use.
-- `batch_max_tokens` **int** - The maximum batch #-tokens.
+- `batch_max_tokens` **int** - The maximum number of tokens per batch.
 - `target` **required|all** - Determines which set of embeddings to emit.
 - `skip` **list[str]** - Which embeddings to skip.
+- `vector_store` **dict** - The vector store to use. Configured for lancedb by default.
+  - `type` **str** - `lancedb` or `azure_ai_search`. Default=`lancedb`
+  - `db_uri` **str** (only for lancedb) - The database URI. Default=`storage.base_dir/lancedb`
+  - `url` **str** (only for AI Search) - The AI Search endpoint.
+  - `api_key` **str** (optional - only for AI Search) - The AI Search API key to use.
+  - `audience` **str** (only for AI Search) - Audience for the managed identity token if managed identity authentication is used.
+  - `overwrite` **bool** (only used at index creation time) - Overwrite the collection if it exists. Default=`True`
+  - `collection_name` **str** - The name of a vector collection. Default=`entity_description_embeddings`
 - `strategy` **dict** - Fully override the text-embedding strategy.
 
 ## chunks
@@ -214,7 +222,7 @@ This is the base LLM configuration section. Other steps may override this config
 
 ## encoding_model
 
-**str** - The text encoding model to use. Default is `cl100k_base`.
+**str** - The text encoding model to use. Default=`cl100k_base`.
 
 ## skip_workflows
diff --git a/graphrag/api/__init__.py b/graphrag/api/__init__.py
index 120e3e41..15b005e4 100644
--- a/graphrag/api/__init__.py
+++ b/graphrag/api/__init__.py
@@ -7,9 +7,9 @@ WARNING: This API is under development and may undergo changes in future release
 Backwards compatibility is not guaranteed at this time.
""" -from .index_api import build_index -from .prompt_tune_api import DocSelectionType, generate_indexing_prompts -from .query_api import ( +from graphrag.api.index import build_index +from graphrag.api.prompt_tune import DocSelectionType, generate_indexing_prompts +from graphrag.api.query import ( global_search, global_search_streaming, local_search, diff --git a/graphrag/api/index_api.py b/graphrag/api/index.py similarity index 80% rename from graphrag/api/index_api.py rename to graphrag/api/index.py index 4b9a36d1..e938abdb 100644 --- a/graphrag/api/index_api.py +++ b/graphrag/api/index.py @@ -8,6 +8,8 @@ WARNING: This API is under development and may undergo changes in future release Backwards compatibility is not guaranteed at this time. """ +from pathlib import Path + from graphrag.config import CacheType, GraphRagConfig from graphrag.index.cache.noop_pipeline_cache import NoopPipelineCache from graphrag.index.create_pipeline_config import create_pipeline_config @@ -15,6 +17,7 @@ from graphrag.index.emit.types import TableEmitterType from graphrag.index.run import run_pipeline_with_config from graphrag.index.typing import PipelineRunResult from graphrag.logging import ProgressReporter +from graphrag.vector_stores.factory import VectorStoreType async def build_index( @@ -30,7 +33,7 @@ async def build_index( Parameters ---------- - config : PipelineConfig + config : GraphRagConfig The configuration. run_id : str The run id. Creates a output directory with this name. @@ -55,6 +58,14 @@ async def build_index( msg = "Cannot resume and update a run at the same time." raise ValueError(msg) + # TODO: must update filepath of lancedb (if used) until the new config engine has been implemented + # TODO: remove the type ignore annotations below once the new config engine has been refactored + vector_store_type = config.embeddings.vector_store["type"] # type: ignore + if vector_store_type == VectorStoreType.LanceDB: + db_uri = config.embeddings.vector_store["db_uri"] # type: ignore + lancedb_dir = Path(config.root_dir).resolve() / db_uri + config.embeddings.vector_store["db_uri"] = str(lancedb_dir) # type: ignore + pipeline_config = create_pipeline_config(config) pipeline_cache = ( NoopPipelineCache() if config.cache.type == CacheType.none is None else None diff --git a/graphrag/api/prompt_tune_api.py b/graphrag/api/prompt_tune.py similarity index 100% rename from graphrag/api/prompt_tune_api.py rename to graphrag/api/prompt_tune.py diff --git a/graphrag/api/query_api.py b/graphrag/api/query.py similarity index 82% rename from graphrag/api/query_api.py rename to graphrag/api/query.py index 9e18ca88..2897217c 100644 --- a/graphrag/api/query_api.py +++ b/graphrag/api/query.py @@ -26,7 +26,6 @@ from pydantic import validate_call from graphrag.config import GraphRagConfig from graphrag.logging import PrintProgressReporter -from graphrag.model.entity import Entity from graphrag.query.factories import get_global_search_engine, get_local_search_engine from graphrag.query.indexer_adapters import ( read_indexer_covariates, @@ -35,10 +34,9 @@ from graphrag.query.indexer_adapters import ( read_indexer_reports, read_indexer_text_units, ) -from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings from graphrag.query.structured_search.base import SearchResult # noqa: TCH001 -from graphrag.vector_stores.lancedb import LanceDBVectorStore -from graphrag.vector_stores.typing import VectorStoreFactory, VectorStoreType +from graphrag.utils.cli import redact +from graphrag.vector_stores import 
VectorStoreFactory, VectorStoreType reporter = PrintProgressReporter("") @@ -184,24 +182,20 @@ async def local_search( ------ TODO: Document any exceptions to expect. """ - vector_store_args = ( - config.embeddings.vector_store if config.embeddings.vector_store else {} + # TODO: must update filepath of lancedb (if used) until the new config engine has been implemented + # TODO: remove the type ignore annotations below once the new config engine has been refactored + vector_store_type = config.embeddings.vector_store.get("type") # type: ignore + vector_store_args = config.embeddings.vector_store + if vector_store_type == "lancedb": + db_uri = config.embeddings.vector_store["db_uri"] # type: ignore + lancedb_dir = Path(config.root_dir).resolve() / db_uri + vector_store_args["db_uri"] = str(lancedb_dir) # type: ignore + reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore + description_embedding_store = _get_embedding_description_store( + config_args=vector_store_args, # type: ignore ) - reporter.info(f"Vector Store Args: {vector_store_args}") - - vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB) _entities = read_indexer_entities(nodes, entities, community_level) - - lancedb_dir = Path(config.storage.base_dir) / "lancedb" - - vector_store_args.update({"db_uri": str(lancedb_dir)}) - description_embedding_store = _get_embedding_description_store( - entities=_entities, - vector_store_type=vector_store_type, - config_args=vector_store_args, - ) - _covariates = read_indexer_covariates(covariates) if covariates is not None else [] search_engine = get_local_search_engine( @@ -257,24 +251,20 @@ async def local_search_streaming( ------ TODO: Document any exceptions to expect. """ - vector_store_args = ( - config.embeddings.vector_store if config.embeddings.vector_store else {} + # TODO: must update filepath of lancedb (if used) until the new config engine has been implemented + # TODO: remove the type ignore annotations below once the new config engine has been refactored + vector_store_type = config.embeddings.vector_store["type"] # type: ignore + vector_store_args = config.embeddings.vector_store + if vector_store_type == VectorStoreType.LanceDB: + db_uri = config.embeddings.vector_store["db_uri"] # type: ignore + lancedb_dir = Path(config.root_dir).resolve() / db_uri + vector_store_args["db_uri"] = str(lancedb_dir) # type: ignore + reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore + description_embedding_store = _get_embedding_description_store( + config_args=vector_store_args, # type: ignore ) - reporter.info(f"Vector Store Args: {vector_store_args}") - - vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB) _entities = read_indexer_entities(nodes, entities, community_level) - - lancedb_dir = Path(config.storage.base_dir) / "lancedb" - - vector_store_args.update({"db_uri": str(lancedb_dir)}) - description_embedding_store = _get_embedding_description_store( - entities=_entities, - vector_store_type=vector_store_type, - config_args=vector_store_args, - ) - _covariates = read_indexer_covariates(covariates) if covariates is not None else [] search_engine = get_local_search_engine( @@ -303,49 +293,14 @@ async def local_search_streaming( def _get_embedding_description_store( - entities: list[Entity], - vector_store_type: str = VectorStoreType.LanceDB, - config_args: dict | None = None, + config_args: dict, ): """Get the embedding description store.""" - if not config_args: - config_args = {} - - 
collection_name = config_args.get( - "query_collection_name", "entity_description_embeddings" - ) - config_args.update({"collection_name": collection_name}) + vector_store_type = config_args["type"] description_embedding_store = VectorStoreFactory.get_vector_store( vector_store_type=vector_store_type, kwargs=config_args ) - description_embedding_store.connect(**config_args) - - if config_args.get("overwrite", True): - # this step assumes the embeddings were originally stored in a file rather - # than a vector database - - # dump embeddings from the entities list to the description_embedding_store - store_entity_semantic_embeddings( - entities=entities, vectorstore=description_embedding_store - ) - else: - # load description embeddings to an in-memory lancedb vectorstore - # and connect to a remote db, specify url and port values. - description_embedding_store = LanceDBVectorStore( - collection_name=collection_name - ) - description_embedding_store.connect( - db_uri=config_args.get("db_uri", "./lancedb") - ) - - # load data from an existing table - description_embedding_store.document_collection = ( - description_embedding_store.db_connection.open_table( - description_embedding_store.collection_name - ) - ) - return description_embedding_store diff --git a/graphrag/config/create_graphrag_config.py b/graphrag/config/create_graphrag_config.py index a883658c..82601e52 100644 --- a/graphrag/config/create_graphrag_config.py +++ b/graphrag/config/create_graphrag_config.py @@ -83,10 +83,7 @@ def create_graphrag_config( llm_type = LLMType(llm_type) if llm_type else base.type api_key = reader.str(Fragment.api_key) or base.api_key api_base = reader.str(Fragment.api_base) or base.api_base - cognitive_services_endpoint = ( - reader.str(Fragment.cognitive_services_endpoint) - or base.cognitive_services_endpoint - ) + audience = reader.str(Fragment.audience) or base.audience deployment_name = ( reader.str(Fragment.deployment_name) or base.deployment_name ) @@ -119,7 +116,7 @@ def create_graphrag_config( or base.model_supports_json, request_timeout=reader.float(Fragment.request_timeout) or base.request_timeout, - cognitive_services_endpoint=cognitive_services_endpoint, + audience=audience, deployment_name=deployment_name, tokens_per_minute=reader.int("tokens_per_minute", Fragment.tpm) or base.tokens_per_minute, @@ -141,7 +138,7 @@ def create_graphrag_config( api_type = LLMType(api_type) if api_type else defs.LLM_TYPE api_key = reader.str(Fragment.api_key) or base.api_key - # In a unique events where: + # Account for various permutations of config settings such as: # - same api_bases for LLM and embeddings (both Azure) # - different api_bases for LLM and embeddings (both Azure) # - LLM uses Azure OpenAI, while embeddings uses base OpenAI (this one is important) @@ -158,10 +155,7 @@ def create_graphrag_config( ) api_organization = reader.str("organization") or base.organization api_proxy = reader.str("proxy") or base.proxy - cognitive_services_endpoint = ( - reader.str(Fragment.cognitive_services_endpoint) - or base.cognitive_services_endpoint - ) + audience = reader.str(Fragment.audience) or base.audience deployment_name = reader.str(Fragment.deployment_name) if api_key is None and not _is_azure(api_type): @@ -186,7 +180,7 @@ def create_graphrag_config( model=reader.str(Fragment.model) or defs.EMBEDDING_MODEL, request_timeout=reader.float(Fragment.request_timeout) or defs.LLM_REQUEST_TIMEOUT, - cognitive_services_endpoint=cognitive_services_endpoint, + audience=audience, deployment_name=deployment_name, 
             tokens_per_minute=reader.int("tokens_per_minute", Fragment.tpm)
             or defs.LLM_TOKENS_PER_MINUTE,
@@ -237,9 +231,7 @@ def create_graphrag_config(
         api_base = reader.str(Fragment.api_base) or fallback_oai_base
         api_version = reader.str(Fragment.api_version) or fallback_oai_version
         api_proxy = reader.str(Fragment.api_proxy) or fallback_oai_proxy
-        cognitive_services_endpoint = reader.str(
-            Fragment.cognitive_services_endpoint
-        )
+        audience = reader.str(Fragment.audience)
         deployment_name = reader.str(Fragment.deployment_name)
 
         if api_key is None and not _is_azure(llm_type):
@@ -270,7 +262,7 @@ def create_graphrag_config(
             model_supports_json=reader.bool(Fragment.model_supports_json),
             request_timeout=reader.float(Fragment.request_timeout)
             or defs.LLM_REQUEST_TIMEOUT,
-            cognitive_services_endpoint=cognitive_services_endpoint,
+            audience=audience,
             deployment_name=deployment_name,
             tokens_per_minute=reader.int(Fragment.tpm)
             or defs.LLM_TOKENS_PER_MINUTE,
@@ -294,13 +286,15 @@ def create_graphrag_config(
     embeddings_config = values.get("embeddings") or {}
     with reader.envvar_prefix(Section.embedding), reader.use(embeddings_config):
         embeddings_target = reader.str("target")
+        # TODO: remove the type ignore annotations below once the new config engine has been refactored
         embeddings_model = TextEmbeddingConfig(
-            llm=hydrate_embeddings_params(embeddings_config, llm_model),
+            llm=hydrate_embeddings_params(embeddings_config, llm_model),  # type: ignore
             parallelization=hydrate_parallelization_params(
-                embeddings_config, llm_parallelization_model
+                embeddings_config,  # type: ignore
+                llm_parallelization_model,  # type: ignore
             ),
             vector_store=embeddings_config.get("vector_store", None),
-            async_mode=hydrate_async_type(embeddings_config, async_mode),
+            async_mode=hydrate_async_type(embeddings_config, async_mode),  # type: ignore
             target=(
                 TextEmbeddingTarget(embeddings_target)
                 if embeddings_target
@@ -579,8 +573,8 @@ class Fragment(str, Enum):
     api_organization = "API_ORGANIZATION"
     api_proxy = "API_PROXY"
     async_mode = "ASYNC_MODE"
+    audience = "AUDIENCE"
     base_dir = "BASE_DIR"
-    cognitive_services_endpoint = "COGNITIVE_SERVICES_ENDPOINT"
     concurrent_requests = "CONCURRENT_REQUESTS"
     conn_string = "CONNECTION_STRING"
     container_name = "CONTAINER_NAME"
diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py
index c7880616..b02d58ec 100644
--- a/graphrag/config/defaults.py
+++ b/graphrag/config/defaults.py
@@ -3,8 +3,12 @@
 
 """Common default configuration values."""
 
+from pathlib import Path
+
 from datashaper import AsyncType
 
+from graphrag.vector_stores import VectorStoreType
+
 from .enums import (
     CacheType,
     InputFileType,
@@ -74,7 +78,7 @@ NODE2VEC_WINDOW_SIZE = 2
 NODE2VEC_ITERATIONS = 3
 NODE2VEC_RANDOM_SEED = 597832
 REPORTING_TYPE = ReportingType.file
-REPORTING_BASE_DIR = "output"
+REPORTING_BASE_DIR = "logs"
 SNAPSHOTS_GRAPHML = False
 SNAPSHOTS_RAW_ENTITIES = False
 SNAPSHOTS_TOP_LEVEL_NODES = False
@@ -83,6 +87,13 @@ STORAGE_TYPE = StorageType.file
 SUMMARIZE_DESCRIPTIONS_MAX_LENGTH = 500
 UMAP_ENABLED = False
 
+VECTOR_STORE = f"""
+    type: {VectorStoreType.LanceDB.value}
+    db_uri: '{(Path(STORAGE_BASE_DIR) / "lancedb")!s}'
+    collection_name: entity_description_embeddings
+    overwrite: true\
+"""
+
 # Local Search
 LOCAL_SEARCH_TEXT_UNIT_PROP = 0.5
 LOCAL_SEARCH_COMMUNITY_PROP = 0.1
diff --git a/graphrag/config/input_models/llm_parameters_input.py b/graphrag/config/input_models/llm_parameters_input.py
index c89c6c09..b99d7e9b 100644
--- a/graphrag/config/input_models/llm_parameters_input.py
+++ b/graphrag/config/input_models/llm_parameters_input.py
@@ -20,7 +20,7 @@ class LLMParametersInput(TypedDict):
     api_version: NotRequired[str | None]
     organization: NotRequired[str | None]
     proxy: NotRequired[str | None]
-    cognitive_services_endpoint: NotRequired[str | None]
+    audience: NotRequired[str | None]
     deployment_name: NotRequired[str | None]
     model_supports_json: NotRequired[bool | str | None]
     tokens_per_minute: NotRequired[int | str | None]
diff --git a/graphrag/config/models/llm_parameters.py b/graphrag/config/models/llm_parameters.py
index df81138a..4f18ded0 100644
--- a/graphrag/config/models/llm_parameters.py
+++ b/graphrag/config/models/llm_parameters.py
@@ -52,8 +52,9 @@ class LLMParameters(BaseModel):
     proxy: str | None = Field(
         description="The proxy to use for the LLM service.", default=None
     )
-    cognitive_services_endpoint: str | None = Field(
-        description="The endpoint to reach cognitives services.", default=None
+    audience: str | None = Field(
+        description="Azure resource URI to use with managed identity for the llm connection.",
+        default=None,
     )
     deployment_name: str | None = Field(
         description="The deployment name to use for the LLM service.", default=None
diff --git a/graphrag/config/models/text_embedding_config.py b/graphrag/config/models/text_embedding_config.py
index abd2f2bf..815263bb 100644
--- a/graphrag/config/models/text_embedding_config.py
+++ b/graphrag/config/models/text_embedding_config.py
@@ -27,7 +27,7 @@ class TextEmbeddingConfig(LLMConfig):
     )
     skip: list[str] = Field(description="The specific embeddings to skip.", default=[])
     vector_store: dict | None = Field(
-        description="The vector storage configuration", default=None
+        description="The vector storage configuration", default=defs.VECTOR_STORE
     )
     strategy: dict | None = Field(
         description="The override strategy to use.", default=None
diff --git a/graphrag/index/cli.py b/graphrag/index/cli.py
index 72835f66..d20166db 100644
--- a/graphrag/index/cli.py
+++ b/graphrag/index/cli.py
@@ -4,7 +4,6 @@
 """Main definition."""
 
 import asyncio
-import json
 import logging
 import sys
 import time
@@ -19,6 +18,7 @@ from graphrag.config import (
     resolve_paths,
 )
 from graphrag.logging import ProgressReporter, ReporterType, create_progress_reporter
+from graphrag.utils.cli import redact
 
 from .emit.types import TableEmitterType
 from .graph.extractors.claims.prompts import CLAIM_EXTRACTION_PROMPT
@@ -34,36 +34,6 @@ warnings.filterwarnings("ignore", message=".*NumbaDeprecationWarning.*")
 log = logging.getLogger(__name__)
 
 
-def _redact(input: dict) -> str:
-    """Sanitize the config json."""
-
-    # Redact any sensitive configuration
-    def redact_dict(input: dict) -> dict:
-        if not isinstance(input, dict):
-            return input
-
-        result = {}
-        for key, value in input.items():
-            if key in {
-                "api_key",
-                "connection_string",
-                "container_name",
-                "organization",
-            }:
-                if value is not None:
-                    result[key] = "==== REDACTED ===="
-            elif isinstance(value, dict):
-                result[key] = redact_dict(value)
-            elif isinstance(value, list):
-                result[key] = [redact_dict(i) for i in value]
-            else:
-                result[key] = value
-        return result
-
-    redacted_dict = redact_dict(input)
-    return json.dumps(redacted_dict, indent=4)
-
-
 def _logger(reporter: ProgressReporter):
     def info(msg: str, verbose: bool = False):
         log.info(msg)
@@ -140,7 +110,7 @@ def index_cli(
             info(f"Logging enabled at {log_path}", True)
         else:
             info(
-                f"Logging not enabled for config {_redact(config.model_dump())}",
+                f"Logging not enabled for config {redact(config.model_dump())}",
                 True,
             )
@@ -149,7 +119,7 @@ def index_cli(
     info(f"Starting pipeline run for: {run_id}, {dryrun=}", verbose)
     info(
-        f"Using default configuration: {_redact(config.model_dump())}",
+        f"Using default configuration: {redact(config.model_dump())}",
         verbose,
     )
diff --git a/graphrag/index/init_content.py b/graphrag/index/init_content.py
index 18646ca7..113cc17e 100644
--- a/graphrag/index/init_content.py
+++ b/graphrag/index/init_content.py
@@ -1,10 +1,11 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
-"""Content for the init CLI command."""
+
+"""Content for the init CLI command to generate a default configuration."""
 
 import graphrag.config.defaults as defs
 
-INIT_YAML = f"""
+INIT_YAML = f"""\
 encoding_model: cl100k_base
 skip_workflows: []
 llm:
@@ -12,6 +13,7 @@ llm:
   type: {defs.LLM_TYPE.value} # or azure_openai_chat
   model: {defs.LLM_MODEL}
   model_supports_json: true # recommended if this is available for your model.
+  # audience: "https://cognitiveservices.azure.com/.default"
  # max_tokens: {defs.LLM_MAX_TOKENS}
   # request_timeout: {defs.LLM_REQUEST_TIMEOUT}
   # api_base: https://<instance>.openai.azure.com
@@ -40,12 +42,21 @@ embeddings:
   # target: {defs.EMBEDDING_TARGET.value} # or all
   # batch_size: {defs.EMBEDDING_BATCH_SIZE} # the number of documents to send in a single request
   # batch_max_tokens: {defs.EMBEDDING_BATCH_MAX_TOKENS} # the maximum number of tokens to send in a single request
+  vector_store:{defs.VECTOR_STORE}
+  # vector_store: # configuration for AI Search
+  #   type: azure_ai_search
+  #   url:
+  #   api_key: # if not set, will attempt to use managed identity. Expects the `Search Index Data Contributor` RBAC role in this case.
+  #   audience: # if using managed identity, the audience to use for the token
+  #   overwrite: true # or false. Only applicable at index creation time
+  #   collection_name: # the name of the collection to use
   llm:
     api_key: ${{GRAPHRAG_API_KEY}}
     type: {defs.EMBEDDING_TYPE.value} # or azure_openai_embedding
     model: {defs.EMBEDDING_MODEL}
     # api_base: https://<instance>.openai.azure.com
     # api_version: 2024-02-15-preview
+    # audience: "https://cognitiveservices.azure.com/.default"
     # organization:
     # deployment_name:
     # tokens_per_minute: 150_000 # set a leaky bucket throttle
@@ -160,6 +171,6 @@ global_search:
   # concurrency: {defs.GLOBAL_SEARCH_CONCURRENCY}
 """
 
-INIT_DOTENV = """
+INIT_DOTENV = """\
 GRAPHRAG_API_KEY=
 """
diff --git a/graphrag/index/llm/load_llm.py b/graphrag/index/llm/load_llm.py
index f321d767..a7eda31a 100644
--- a/graphrag/index/llm/load_llm.py
+++ b/graphrag/index/llm/load_llm.py
@@ -199,7 +199,7 @@ def _get_base_config(config: dict[str, Any]) -> dict[str, Any]:
         "model_supports_json": config.get("model_supports_json"),
         "concurrent_requests": config.get("concurrent_requests", 4),
         "encoding_model": config.get("encoding_model", "cl100k_base"),
-        "cognitive_services_endpoint": config.get("cognitive_services_endpoint"),
+        "audience": config.get("audience"),
     }
diff --git a/graphrag/index/operations/embed_text/embed_text.py b/graphrag/index/operations/embed_text/embed_text.py
index b77fdd6f..df4f400e 100644
--- a/graphrag/index/operations/embed_text/embed_text.py
+++ b/graphrag/index/operations/embed_text/embed_text.py
@@ -231,7 +231,7 @@ def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str:
     collection_names = vector_store_config.get("collection_names", {})
     collection_name = collection_names.get(embedding_name, embedding_name)
 
-    msg = f"using {vector_store_config.get('type')} collection_name {collection_name} for embedding {embedding_name}"
+    msg = f"using vector store {vector_store_config.get('type')} with collection_name {collection_name} for embedding {embedding_name}"
     log.info(msg)
     return collection_name
diff --git a/graphrag/llm/openai/create_openai_client.py b/graphrag/llm/openai/create_openai_client.py
index 40d7d649..40d15cad 100644
--- a/graphrag/llm/openai/create_openai_client.py
+++ b/graphrag/llm/openai/create_openai_client.py
@@ -32,15 +32,16 @@ def create_openai_client(
             api_base,
             configuration.deployment_name,
         )
-        if configuration.cognitive_services_endpoint is None:
-            cognitive_services_endpoint = "https://cognitiveservices.azure.com/.default"
-        else:
-            cognitive_services_endpoint = configuration.cognitive_services_endpoint
+        audience = (
+            configuration.audience
+            if configuration.audience
+            else "https://cognitiveservices.azure.com/.default"
+        )
 
         return AsyncAzureOpenAI(
             api_key=configuration.api_key if configuration.api_key else None,
             azure_ad_token_provider=get_bearer_token_provider(
-                DefaultAzureCredential(), cognitive_services_endpoint
+                DefaultAzureCredential(), audience
             )
             if not configuration.api_key
             else None,
diff --git a/graphrag/llm/openai/openai_configuration.py b/graphrag/llm/openai/openai_configuration.py
index 1bcd5694..cbcc5409 100644
--- a/graphrag/llm/openai/openai_configuration.py
+++ b/graphrag/llm/openai/openai_configuration.py
@@ -26,7 +26,7 @@ class OpenAIConfiguration(Hashable, LLMConfig):
 
     _api_base: str | None
     _api_version: str | None
-    _cognitive_services_endpoint: str | None
+    _audience: str | None
     _deployment_name: str | None
     _organization: str | None
     _proxy: str | None
@@ -103,7 +103,7 @@ class OpenAIConfiguration(Hashable, LLMConfig):
         self._deployment_name = lookup_str("deployment_name")
         self._api_base = lookup_str("api_base")
         self._api_version = lookup_str("api_version")
-        self._cognitive_services_endpoint = lookup_str("cognitive_services_endpoint")
+        self._audience = lookup_str("audience")
         self._organization = lookup_str("organization")
         self._proxy = lookup_str("proxy")
         self._n = lookup_int("n")
@@ -156,9 +156,9 @@ class OpenAIConfiguration(Hashable, LLMConfig):
         return _non_blank(self._api_version)
 
     @property
-    def cognitive_services_endpoint(self) -> str | None:
+    def audience(self) -> str | None:
         """API version property definition."""
-        return _non_blank(self._cognitive_services_endpoint)
+        return _non_blank(self._audience)
 
     @property
     def organization(self) -> str | None:
diff --git a/graphrag/query/factories.py b/graphrag/query/factories.py
index ae7803cd..7d07e1d7 100644
--- a/graphrag/query/factories.py
+++ b/graphrag/query/factories.py
@@ -44,17 +44,16 @@ def get_llm(config: GraphRagConfig) -> ChatOpenAI:
         **config.llm.model_dump(),
         "api_key": f"REDACTED,len={len(debug_llm_key)}",
     }
-    if config.llm.cognitive_services_endpoint is None:
-        cognitive_services_endpoint = "https://cognitiveservices.azure.com/.default"
-    else:
-        cognitive_services_endpoint = config.llm.cognitive_services_endpoint
+    audience = (
+        config.llm.audience
+        if config.llm.audience
+        else "https://cognitiveservices.azure.com/.default"
+    )
     print(f"creating llm client with {llm_debug_info}")  # noqa T201
     return ChatOpenAI(
         api_key=config.llm.api_key,
         azure_ad_token_provider=(
-            get_bearer_token_provider(
-                DefaultAzureCredential(), cognitive_services_endpoint
-            )
+            get_bearer_token_provider(DefaultAzureCredential(), audience)
             if is_azure_client and not config.llm.api_key
             else None
         ),
@@ -77,17 +76,15 @@ def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
         **config.embeddings.llm.model_dump(),
         "api_key": f"REDACTED,len={len(debug_embedding_api_key)}",
     }
-    if config.embeddings.llm.cognitive_services_endpoint is None:
-        cognitive_services_endpoint = "https://cognitiveservices.azure.com/.default"
+    if config.embeddings.llm.audience is None:
+        audience = "https://cognitiveservices.azure.com/.default"
     else:
-        cognitive_services_endpoint = config.embeddings.llm.cognitive_services_endpoint
+        audience = config.embeddings.llm.audience
     print(f"creating embedding llm client with {llm_debug_info}")  # noqa T201
     return OpenAIEmbedding(
         api_key=config.embeddings.llm.api_key,
         azure_ad_token_provider=(
-            get_bearer_token_provider(
-                DefaultAzureCredential(), cognitive_services_endpoint
-            )
+            get_bearer_token_provider(DefaultAzureCredential(), audience)
             if is_azure_client and not config.embeddings.llm.api_key
             else None
         ),
diff --git a/graphrag/utils/cli.py b/graphrag/utils/cli.py
index 16e01284..5a9c1666 100644
--- a/graphrag/utils/cli.py
+++ b/graphrag/utils/cli.py
@@ -4,6 +4,7 @@
 """CLI functions for the GraphRAG module."""
 
 import argparse
+import json
 from pathlib import Path
 
 
@@ -21,3 +22,33 @@ def dir_exist(path):
         msg = f"Directory not found: {path}"
         raise argparse.ArgumentTypeError(msg)
     return path
+
+
+def redact(config: dict) -> str:
+    """Sanitize secrets in a config object."""
+
+    # Redact any sensitive configuration
+    def redact_dict(config: dict) -> dict:
+        if not isinstance(config, dict):
+            return config
+
+        result = {}
+        for key, value in config.items():
+            if key in {
+                "api_key",
+                "connection_string",
+                "container_name",
+                "organization",
+            }:
+                if value is not None:
+                    result[key] = "==== REDACTED ===="
+            elif isinstance(value, dict):
+                result[key] = redact_dict(value)
+            elif isinstance(value, list):
+                result[key] = [redact_dict(i) for i in value]
+            else:
+                result[key] = value
+        return result
+
+    redacted_dict = redact_dict(config)
+    return json.dumps(redacted_dict, indent=4)
diff --git a/graphrag/vector_stores/__init__.py b/graphrag/vector_stores/__init__.py
index 764d51b1..560db063 100644
--- a/graphrag/vector_stores/__init__.py
+++ b/graphrag/vector_stores/__init__.py
@@ -3,15 +3,15 @@
 
 """A module containing vector storage implementations."""
 
-from .azure_ai_search import AzureAISearch
-from .base import BaseVectorStore, VectorStoreDocument, VectorStoreSearchResult
-from .lancedb import LanceDBVectorStore
-from .typing import VectorStoreFactory, VectorStoreType
+from graphrag.vector_stores.base import (
+    BaseVectorStore,
+    VectorStoreDocument,
+    VectorStoreSearchResult,
+)
+from graphrag.vector_stores.factory import VectorStoreFactory, VectorStoreType
 
 __all__ = [
-    "AzureAISearch",
     "BaseVectorStore",
-    "LanceDBVectorStore",
     "VectorStoreDocument",
     "VectorStoreFactory",
     "VectorStoreSearchResult",
diff --git a/graphrag/vector_stores/azure_ai_search.py b/graphrag/vector_stores/azure_ai_search.py
index 46bddad6..38078082 100644
--- a/graphrag/vector_stores/azure_ai_search.py
+++ b/graphrag/vector_stores/azure_ai_search.py
@@ -35,13 +35,16 @@ from .base import (
 
 
 class AzureAISearch(BaseVectorStore):
-    """The Azure AI Search vector storage implementation."""
+    """Azure AI Search vector storage implementation."""
 
     index_client: SearchIndexClient
 
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+
     def connect(self, **kwargs: Any) -> Any:
-        """Connect to the AzureAI vector store."""
-        url = kwargs.get("url")
+        """Connect to AI Search vector storage."""
+        url = kwargs["url"]
         api_key = kwargs.get("api_key")
         audience = kwargs.get("audience")
         self.vector_size = kwargs.get("vector_size", DEFAULT_VECTOR_SIZE)
@@ -51,7 +54,7 @@ class AzureAISearch(BaseVectorStore):
         )
 
         if url:
-            audience_arg = {"audience": audience} if audience else {}
+            audience_arg = {"audience": audience} if audience and not api_key else {}
             self.db_connection = SearchClient(
                 endpoint=url,
                 index_name=self.collection_name,
@@ -68,18 +71,18 @@ class AzureAISearch(BaseVectorStore):
                 **audience_arg,
             )
         else:
-            not_supported_error = "AAISearchDBClient is not supported on local host."
+            not_supported_error = "Azure AI Search expects `url`."
             raise ValueError(not_supported_error)
 
     def load_documents(
         self, documents: list[VectorStoreDocument], overwrite: bool = True
     ) -> None:
-        """Load documents into the Azure AI Search index."""
+        """Load documents into an Azure AI Search index."""
         if overwrite:
             if self.collection_name in self.index_client.list_index_names():
                 self.index_client.delete_index(self.collection_name)
 
-            # Configure the vector search profile
+            # Configure vector search profile
             vector_search = VectorSearch(
                 algorithms=[
                     HnswAlgorithmConfiguration(
@@ -96,7 +99,7 @@ class AzureAISearch(BaseVectorStore):
                     )
                 ],
             )
-
+            # Configure the index
             index = SearchIndex(
                 name=self.collection_name,
                 fields=[
@@ -120,7 +123,6 @@ class AzureAISearch(BaseVectorStore):
                 ],
                 vector_search=vector_search,
             )
-
             self.index_client.create_or_update_index(
                 index,
             )
@@ -136,7 +138,7 @@ class AzureAISearch(BaseVectorStore):
             if doc.vector is not None
         ]
 
-        if batch and len(batch) > 0:
+        if len(batch) > 0:
             self.db_connection.upload_documents(batch)
 
     def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
diff --git a/graphrag/vector_stores/typing.py b/graphrag/vector_stores/factory.py
similarity index 92%
rename from graphrag/vector_stores/typing.py
rename to graphrag/vector_stores/factory.py
index 0b5a5cd1..564533ba 100644
--- a/graphrag/vector_stores/typing.py
+++ b/graphrag/vector_stores/factory.py
@@ -1,7 +1,7 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
 
-"""A package containing the supported vector store types."""
+"""A package containing a factory and supported vector store types."""
 
 from enum import Enum
 from typing import ClassVar
@@ -18,7 +18,7 @@ class VectorStoreType(str, Enum):
 
 
 class VectorStoreFactory:
-    """A factory class for creating vector stores."""
+    """A factory class for vector stores."""
 
     vector_store_types: ClassVar[dict[str, type]] = {}
 
diff --git a/graphrag/vector_stores/lancedb.py b/graphrag/vector_stores/lancedb.py
index 0c9ea17f..9d8b24af 100644
--- a/graphrag/vector_stores/lancedb.py
+++ b/graphrag/vector_stores/lancedb.py
@@ -3,14 +3,14 @@
 
 """The LanceDB vector storage implementation package."""
 
-import lancedb as lancedb  # noqa: I001 (Ruff was breaking on this file imports, even tho they were sorted and passed local tests)
-from graphrag.model.types import TextEmbedder
-
 import json
 from typing import Any
 
+import lancedb as lancedb
 import pyarrow as pa
 
+from graphrag.model.types import TextEmbedder
+
 from .base import (
     BaseVectorStore,
     VectorStoreDocument,
@@ -19,12 +19,21 @@ from .base import (
 
 
 class LanceDBVectorStore(BaseVectorStore):
-    """The LanceDB vector storage implementation."""
+    """LanceDB vector storage implementation."""
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
 
     def connect(self, **kwargs: Any) -> Any:
         """Connect to the vector storage."""
-        db_uri = kwargs.get("db_uri", "./lancedb")
-        self.db_connection = lancedb.connect(db_uri)  # type: ignore
+        self.db_connection = lancedb.connect(kwargs["db_uri"])
+        if (
+            self.collection_name
+            and self.collection_name in self.db_connection.table_names()
+        ):
+            self.document_collection = self.db_connection.open_table(
+                self.collection_name
+            )
 
     def load_documents(
         self, documents: list[VectorStoreDocument], overwrite: bool = True
@@ -50,6 +59,9 @@ class LanceDBVectorStore(BaseVectorStore):
             pa.field("vector", pa.list_(pa.float64())),
             pa.field("attributes", pa.string()),
         ])
+        # NOTE: If modifying the next section of code, ensure that the schema remains the same.
+        # The pyarrow format of the 'vector' field may change if the order of operations is changed
+        # and will break vector search.
         if overwrite:
             if data:
                 self.document_collection = self.db_connection.create_table(
@@ -87,14 +99,18 @@ class LanceDBVectorStore(BaseVectorStore):
         """Perform a vector-based similarity search."""
         if self.query_filter:
             docs = (
-                self.document_collection.search(query=query_embedding)
+                self.document_collection.search(
+                    query=query_embedding, vector_column_name="vector"
+                )
                 .where(self.query_filter, prefilter=True)
                 .limit(k)
                 .to_list()
             )
         else:
             docs = (
-                self.document_collection.search(query=query_embedding)
+                self.document_collection.search(
+                    query=query_embedding, vector_column_name="vector"
+                )
                 .limit(k)
                 .to_list()
             )
diff --git a/tests/fixtures/azure/settings.yml b/tests/fixtures/azure/settings.yml
index 2ebe953f..cb61192b 100644
--- a/tests/fixtures/azure/settings.yml
+++ b/tests/fixtures/azure/settings.yml
@@ -7,8 +7,6 @@ embeddings:
     url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
     api_key: ${AZURE_AI_SEARCH_API_KEY}
     collection_name: "azure_ci"
-    query_collection_name: "azure_ci_query"
-
 entity_name_description:
   title_column: "name"
diff --git a/tests/fixtures/min-csv/settings.yml b/tests/fixtures/min-csv/settings.yml
index a6393c02..06543742 100644
--- a/tests/fixtures/min-csv/settings.yml
+++ b/tests/fixtures/min-csv/settings.yml
@@ -4,7 +4,8 @@ input:
 embeddings:
   vector_store:
     type: "lancedb"
-    uri_db: "./tests/fixtures/min-csv/lancedb"
+    db_uri: "./tests/fixtures/min-csv/lancedb"
+    collection_name: "lancedb_ci"
     store_in_table: True
 
 entity_name_description:
diff --git a/tests/fixtures/text/settings.yml b/tests/fixtures/text/settings.yml
index 4076e8fb..37d7f09c 100644
--- a/tests/fixtures/text/settings.yml
+++ b/tests/fixtures/text/settings.yml
@@ -7,9 +7,7 @@ embeddings:
     url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
     api_key: ${AZURE_AI_SEARCH_API_KEY}
     collection_name: "simple_text_ci"
-    query_collection_name: "simple_text_ci_query"
     store_in_table: True
-
 entity_name_description:
   title_column: "name"
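
Note on the lookup change in PATCH 1/2: map_query_to_entities() previously called
get_entity_by_key() once per vector search result, and every call scanned the full
entity list while recomputing the dashless form of a UUID key for each candidate.
Keying the entities by id turns each match into a constant-time dict lookup with a
single dashless-UUID fallback. Below is a minimal, self-contained sketch of that
lookup; the Entity dataclass here is a simplified, hypothetical stand-in for
graphrag.model.Entity, not the real class.

# Sketch of the dict-based lookup (hypothetical Entity stand-in).
from dataclasses import dataclass
from uuid import UUID


@dataclass
class Entity:
    id: str
    title: str


def is_valid_uuid(value: str) -> bool:
    """Return True if value parses as a UUID (with or without dashes)."""
    try:
        UUID(value)
    except ValueError:
        return False
    return True


def get_entity_by_id(entities: dict[str, Entity], value: str) -> Entity | None:
    # O(1) dict hit first; fall back once to the dashless UUID form,
    # instead of comparing both forms against every entity in a list.
    entity = entities.get(value)
    if entity is None and is_valid_uuid(value):
        entity = entities.get(value.replace("-", ""))
    return entity


store = {
    "7c6f2bc947c9445393a3d2e174a02cd9": Entity(
        id="7c6f2bc947c9445393a3d2e174a02cd9", title="t3"
    )
}
assert get_entity_by_id(store, "7c6f2bc9-47c9-4453-93a3-d2e174a02cd9") is not None

For R search results over N entities this is O(R) expected work instead of the old
O(N * R) comparisons; the list-based get_entity_by_key() path is kept for vector
stores keyed by a field other than id.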