Multi-index querying for API layer (#1644)

* added multi-global-query function header

* ported over code for merging dataframes

* added connection to global streaming api function

* added function header for update context helper

* implemented and incorporated update_context function

* Updated to make sure the 'parent' column in final_communities gets incremented for multi-index.

* first cut at multi_local_search function

* several minor changes and fixes

* Updated multi index local search.

* Cleaned up code.

* fixed lambda function ruff errors

* fixed more ruff errors

* moved query api helpers to util file

* moved index api helpers to util file

* merged in code left out during conflict resolution

* changed GraphRagConfig object to support lists of vector stores

* Updated with fixes for multi_local_search.

* Minor updates.

* Minor updates.

* Updates for ruff check.

* Minor updates.

* removed redundant vector_store_configs arg

* ruff formatting changes

* semversioner

* Minor fix.

* spellcheck fixes

* ruff

* test fix for ci/cd errors

* another test fix

* added explicit typing for ci tests

* added dict type check for vector_store during indexing

* more ruff fixes

* moved type check

* Removed streaming. Added multi-index drift and basic searches.

* Formatting changes.

* Updates for pyright.

* Update for ruff.

* Ruff formatted.

* first cut at fixing vector store typing errors

* got multi local search working with new config

* ruff and test fixes

* added fix for embeddings type error

* renamed multi index api functions

* ruff

* convert config model to dict[str, VectorStoreConfig]

* modified tests to support new vector_store model

* ruff fixes

* changed some test setups to match new model

* changed ci/cd settings files to match new structure

* Fix stderr check

* fixed bug in vector_store_config validation

* ruff

* add database_name field to vectorstoreconfig

* removed print statements

* small refactoring for PR comments

* modified default config in test

* modified vector store config unit test

---------

Co-authored-by: dorbaker <dorbaker@microsoft.com>
Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
KennyZhang1 2025-01-27 17:26:38 -05:00 committed by GitHub
parent 053bf60162
commit 1bbce33f42
14 changed files with 1112 additions and 207 deletions


@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "implemented multi-index querying for api layer"
}


@@ -18,6 +18,7 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.run.run_workflows import run_workflows
from graphrag.index.typing import PipelineRunResult
from graphrag.logger.base import ProgressLogger
from graphrag.utils.api import get_workflows_list
log = logging.getLogger(__name__)
@@ -60,7 +61,7 @@ async def build_index(
if memory_profile:
log.warning("New pipeline does not yet support memory profiling.")
workflows = _get_workflows_list(config)
workflows = get_workflows_list(config)
async for output in run_workflows(
workflows,
@@ -79,20 +80,3 @@ async def build_index(
progress_logger.info(str(output.result))
return outputs
def _get_workflows_list(config: GraphRagConfig) -> list[str]:
return [
"create_base_text_units",
"create_final_documents",
"extract_graph",
"compute_communities",
"create_final_entities",
"create_final_relationships",
"create_final_nodes",
"create_final_communities",
*(["create_final_covariates"] if config.claim_extraction.enabled else []),
"create_final_text_units",
"create_final_community_reports",
"generate_text_embeddings",
]
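The helper itself is unchanged; it simply moves to graphrag.utils.api (full listing below). A minimal usage sketch, assuming a GraphRagConfig already loaded as config:

from graphrag.utils.api import get_workflows_list

# Returns the workflows in pipeline execution order; with claim extraction
# disabled, create_final_covariates is omitted.
workflows = get_workflows_list(config)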

File diff suppressed because it is too large.


@@ -106,6 +106,7 @@ VECTOR_STORE_TYPE = VectorStoreType.LanceDB.value
VECTOR_STORE_DB_URI = str(Path(OUTPUT_BASE_DIR) / "lancedb")
VECTOR_STORE_CONTAINER_NAME = "default"
VECTOR_STORE_OVERWRITE = True
VECTOR_STORE_INDEX_NAME = "output"
# Local Search
LOCAL_SEARCH_TEXT_UNIT_PROP = 0.5


@@ -57,7 +57,14 @@ def get_embedding_settings(
embeddings_llm_settings = settings.get_language_model_config(
settings.embeddings.model_id
)
vector_store_settings = settings.vector_store.model_dump()
num_entries = len(settings.vector_store)
if num_entries == 1:
store = next(iter(settings.vector_store.values()))
vector_store_settings = store.model_dump()
else:
# The vector_store dict should only have more than one entry for multi-index query
vector_store_settings = None
if vector_store_settings is None:
return {
"strategy": settings.embeddings.resolved_strategy(embeddings_llm_settings)


@@ -40,6 +40,7 @@ models:
# deployment_name: <azure_model_deployment_name>
vector_store:
{defs.VECTOR_STORE_INDEX_NAME}:
type: {defs.VECTOR_STORE_TYPE}
db_uri: {defs.VECTOR_STORE_DB_URI}
container_name: {defs.VECTOR_STORE_CONTAINER_NAME}


@@ -224,20 +224,20 @@ class GraphRagConfig(BaseModel):
)
"""The basic search configuration."""
vector_store: VectorStoreConfig = Field(
description="The vector store configuration.", default=VectorStoreConfig()
vector_store: dict[str, VectorStoreConfig] = Field(
description="The vector store configuration.",
default={"output": VectorStoreConfig()},
)
"""The vector store configuration."""
def _validate_vector_store_db_uri(self) -> None:
"""Validate the vector store configuration."""
if self.vector_store.type == VectorStoreType.LanceDB:
if not self.vector_store.db_uri or self.vector_store.db_uri.strip == "":
msg = "Vector store URI is required for LanceDB. Please rerun `graphrag init` and set the vector store configuration."
raise ValueError(msg)
self.vector_store.db_uri = str(
(Path(self.root_dir) / self.vector_store.db_uri).resolve()
)
for store in self.vector_store.values():
if store.type == VectorStoreType.LanceDB:
if not store.db_uri or store.db_uri.strip() == "":
msg = "Vector store URI is required for LanceDB. Please rerun `graphrag init` and set the vector store configuration."
raise ValueError(msg)
store.db_uri = str((Path(self.root_dir) / store.db_uri).resolve())
def get_language_model_config(self, model_id: str) -> LanguageModelConfig:
"""Get a model configuration by ID.


@@ -45,10 +45,16 @@ class VectorStoreConfig(BaseModel):
msg = "vector_store.url is required when vector_store.type == azure_ai_search. Please rerun `graphrag init` and select the correct vector store type."
raise ValueError(msg)
if self.type != VectorStoreType.AzureAISearch and (
if self.type == VectorStoreType.CosmosDB and (
self.url is None or self.url.strip() == ""
):
msg = "vector_store.url is required when vector_store.type == cosmos_db. Please rerun `graphrag init` and select the correct vector store type."
raise ValueError(msg)
if self.type == VectorStoreType.LanceDB and (
self.url is not None and self.url.strip() != ""
):
msg = "vector_store.url is only used when vector_store.type == azure_ai_search. Please rerun `graphrag init` and select the correct vector store type."
msg = "vector_store.url is only used when vector_store.type == azure_ai_search or vector_store.type == cosmos_db. Please rerun `graphrag init` and select the correct vector store type."
raise ValueError(msg)
api_key: str | None = Field(
@@ -62,10 +68,14 @@
)
container_name: str = Field(
description="The database name to use.",
description="The container name to use.",
default=defs.VECTOR_STORE_CONTAINER_NAME,
)
database_name: str | None = Field(
description="The database name to use when type == cosmos_db.", default=None
)
overwrite: bool = Field(
description="Overwrite the existing data.", default=defs.VECTOR_STORE_OVERWRITE
)
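A hedged sketch of a store entry that exercises the new cosmos_db validation branch and the database_name field; all values are placeholders:

store = VectorStoreConfig(
    type="cosmos_db",
    url="https://<account>.documents.azure.com:443/",  # required when type == cosmos_db
    database_name="graphrag",  # new field, only used when type == cosmos_db
    container_name="default",
)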

graphrag/utils/api.py (new file, +251 lines)

@@ -0,0 +1,251 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""API functions for the GraphRAG module."""
from pathlib import Path
from typing import Any
from graphrag.config.embeddings import create_collection_name
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.model.types import TextEmbedder
from graphrag.vector_stores.base import (
BaseVectorStore,
VectorStoreDocument,
VectorStoreSearchResult,
)
from graphrag.vector_stores.factory import VectorStoreFactory
class MultiVectorStore(BaseVectorStore):
"""Multi Vector Store wrapper implementation."""
def __init__(
self,
embedding_stores: list[BaseVectorStore],
index_names: list[str],
):
self.embedding_stores = embedding_stores
self.index_names = index_names
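# The two lists are parallel: embedding_stores[i] serves index_names[i].
# The search methods below fan each query out across every wrapped store.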
def load_documents(
self, documents: list[VectorStoreDocument], overwrite: bool = True
) -> None:
"""Load documents into the vector store."""
msg = "load_documents method not implemented"
raise NotImplementedError(msg)
def connect(self, **kwargs: Any) -> Any:
"""Connect to vector storage."""
msg = "connect method not implemented"
raise NotImplementedError(msg)
def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
"""Build a query filter to filter documents by id."""
msg = "filter_by_id method not implemented"
raise NotImplementedError(msg)
def search_by_id(self, id: str) -> VectorStoreDocument:
"""Search for a document by id."""
# Split on the last hyphen so doc ids that themselves contain hyphens survive.
search_index_id, search_index_name = id.rsplit("-", 1)
for index_name, embedding_store in zip(
self.index_names, self.embedding_stores, strict=False
):
if index_name == search_index_name:
return embedding_store.search_by_id(search_index_id)
else:
message = f"Index {search_index_name} not found."
raise ValueError(message)
def similarity_search_by_vector(
self, query_embedding: list[float], k: int = 10, **kwargs: Any
) -> list[VectorStoreSearchResult]:
"""Perform a vector-based similarity search."""
all_results = []
for index_name, embedding_store in zip(
self.index_names, self.embedding_stores, strict=False
):
results = embedding_store.similarity_search_by_vector(
query_embedding=query_embedding, k=k
)
mod_results = []
for r in results:
r.document.id = str(r.document.id) + f"-{index_name}"
mod_results += [r]
all_results += mod_results
return sorted(all_results, key=lambda x: x.score, reverse=True)[:k]
def similarity_search_by_text(
self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
) -> list[VectorStoreSearchResult]:
"""Perform a text-based similarity search."""
query_embedding = text_embedder(text)
if query_embedding:
return self.similarity_search_by_vector(
query_embedding=query_embedding, k=k
)
return []
def get_embedding_store(
config_args: dict[str, dict],
embedding_name: str,
) -> BaseVectorStore:
"""Get the embedding description store."""
num_indexes = len(config_args)
embedding_stores = []
index_names = []
for index, store in config_args.items():
vector_store_type = store["type"]
collection_name = create_collection_name(
store.get("container_name", "default"), embedding_name
)
embedding_store = VectorStoreFactory().create_vector_store(
vector_store_type=vector_store_type,
kwargs={**store, "collection_name": collection_name},
)
embedding_store.connect(**store)
# If there is only a single index, return the embedding store directly
if num_indexes == 1:
return embedding_store
embedding_stores.append(embedding_store)
index_names.append(index)
return MultiVectorStore(embedding_stores, index_names)
def reformat_context_data(context_data: dict) -> dict:
"""
Reformats context_data for all query responses.
Reformats a dictionary of dataframes into a dictionary of lists.
One list entry for each record. Records are grouped by original
dictionary keys.
Note: depending on which query algorithm is used, the context_data may not
contain the same information (keys). In this case, the default behavior will be to
set these keys as empty lists to preserve a standard output format.
"""
final_format = {
"reports": [],
"entities": [],
"relationships": [],
"claims": [],
"sources": [],
}
for key in context_data:
records = (
context_data[key].to_dict(orient="records")
if context_data[key] is not None and not isinstance(context_data[key], dict)
else context_data[key]
)
if len(records) < 1:
continue
final_format[key] = records
return final_format
def update_context_data(
context_data: Any,
links: dict[str, Any],
) -> Any:
"""
Update context data with the links dict so that it contains both the index name and community id.
Parameters
----------
- context_data (str | list[pd.DataFrame] | dict[str, pd.DataFrame]): The context data to update.
- links (dict[str, Any]): A dictionary of links to the original dataframes.
Returns
-------
str | list[pd.DataFrame] | dict[str, pd.DataFrame]: The updated context data.
"""
updated_context_data = {}
for key in context_data:
updated_entry = []
if key == "reports":
updated_entry = [
dict(
{k: entry[k] for k in entry},
index_name=links["community_reports"][int(entry["id"])][
"index_name"
],
index_id=links["community_reports"][int(entry["id"])]["id"],
)
for entry in context_data[key]
]
if key == "entities":
updated_entry = [
dict(
{k: entry[k] for k in entry},
entity=entry["entity"].split("-")[0],
index_name=links["entities"][int(entry["id"])]["index_name"],
index_id=links["entities"][int(entry["id"])]["id"],
)
for entry in context_data[key]
]
if key == "relationships":
updated_entry = [
dict(
{k: entry[k] for k in entry},
source=entry["source"].split("-")[0],
target=entry["target"].split("-")[0],
index_name=links["relationships"][int(entry["id"])]["index_name"],
index_id=links["relationships"][int(entry["id"])]["id"],
)
for entry in context_data[key]
]
if key == "claims":
updated_entry = [
dict(
{k: entry[k] for k in entry},
entity=entry["entity"].split("-")[0],
index_name=links["covariates"][int(entry["id"])]["index_name"],
index_id=links["covariates"][int(entry["id"])]["id"],
)
for entry in context_data[key]
]
if key == "sources":
updated_entry = [
dict(
{k: entry[k] for k in entry},
index_name=links["text_units"][int(entry["id"])]["index_name"],
index_id=links["text_units"][int(entry["id"])]["id"],
)
for entry in context_data[key]
]
updated_context_data[key] = updated_entry
return updated_context_data
def load_search_prompt(root_dir: str, prompt_config: str | None) -> str | None:
"""
Load the search prompt from disk if configured.
If not, leave it empty - the search functions will load their defaults.
"""
if prompt_config:
prompt_file = Path(root_dir) / prompt_config
if prompt_file.exists():
return prompt_file.read_bytes().decode(encoding="utf-8")
return None
def get_workflows_list(config: GraphRagConfig) -> list[str]:
"""Return a list of workflows for the indexing pipeline."""
return [
"create_base_text_units",
"create_final_documents",
"extract_graph",
"compute_communities",
"create_final_entities",
"create_final_relationships",
"create_final_nodes",
"create_final_communities",
*(["create_final_covariates"] if config.claim_extraction.enabled else []),
"create_final_text_units",
"create_final_community_reports",
"generate_text_embeddings",
]
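Taken together, these helpers let a multi-index search run through a single store object. A usage sketch, assuming two LanceDB-backed indexes at illustrative paths, an illustrative embedding name, and any TextEmbedder-compatible callable named embed_text:

from graphrag.utils.api import get_embedding_store

config_args = {
    "index_a": {"type": "lancedb", "db_uri": "./output/a/lancedb", "container_name": "default"},
    "index_b": {"type": "lancedb", "db_uri": "./output/b/lancedb", "container_name": "default"},
}
store = get_embedding_store(config_args, embedding_name="entity.description")

results = store.similarity_search_by_text("my query", text_embedder=embed_text, k=10)
for result in results:
    # document ids come back namespaced as "<doc_id>-<index_name>"
    print(result.document.id, result.score)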


@@ -3,10 +3,11 @@ claim_extraction:
embeddings:
vector_store:
type: "azure_ai_search"
url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
api_key: ${AZURE_AI_SEARCH_API_KEY}
container_name: "azure_ci"
output:
type: "azure_ai_search"
url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
api_key: ${AZURE_AI_SEARCH_API_KEY}
container_name: "azure_ci"
input:
type: blob


@@ -26,10 +26,11 @@ models:
async_mode: threaded
vector_store:
type: "lancedb"
db_uri: "./tests/fixtures/min-csv/lancedb"
container_name: "lancedb_ci"
overwrite: True
output:
type: "lancedb"
db_uri: "./tests/fixtures/min-csv/lancedb"
container_name: "lancedb_ci"
overwrite: True
input:
file_type: csv


@@ -26,10 +26,11 @@ models:
async_mode: threaded
vector_store:
type: "azure_ai_search"
url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
api_key: ${AZURE_AI_SEARCH_API_KEY}
container_name: "simple_text_ci"
output:
type: "azure_ai_search"
url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
api_key: ${AZURE_AI_SEARCH_API_KEY}
container_name: "simple_text_ci"
claim_extraction:
enabled: true


@@ -288,13 +288,6 @@ class TestIndexer:
result = self.__run_query(root, query)
print(f"Query: {query}\nResponse: {result.stdout}")
# Check stderr because lancedb logs path creating as WARN which leads to false negatives
stderror = (
result.stderr if "No existing dataset at" not in result.stderr else ""
)
assert stderror == "" or stderror.replace("\n", "") in KNOWN_WARNINGS, (
f"Query failed with error: {stderror}"
)
assert result.returncode == 0, "Query failed"
assert result.stdout is not None, "Query returned no output"
assert len(result.stdout) > 0, "Query returned empty output"


@@ -50,13 +50,16 @@ DEFAULT_MODEL_CONFIG = {
DEFAULT_GRAPHRAG_CONFIG_SETTINGS = {
"models": DEFAULT_MODEL_CONFIG,
"vector_store": {
"type": defs.VECTOR_STORE_TYPE,
"db_uri": defs.VECTOR_STORE_DB_URI,
"container_name": defs.VECTOR_STORE_CONTAINER_NAME,
"overwrite": defs.VECTOR_STORE_OVERWRITE,
"url": None,
"api_key": None,
"audience": None,
"output": {
"type": defs.VECTOR_STORE_TYPE,
"db_uri": defs.VECTOR_STORE_DB_URI,
"container_name": defs.VECTOR_STORE_CONTAINER_NAME,
"overwrite": defs.VECTOR_STORE_OVERWRITE,
"url": None,
"api_key": None,
"audience": None,
"database_name": None,
},
},
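# vector_store entries are now keyed by index name; "output" mirrors the
# single-index default in GraphRagConfig.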
"reporting": {
"type": defs.REPORTING_TYPE,
@@ -283,14 +286,24 @@ def assert_language_model_configs(
assert expected.responses is None
def assert_vector_store_configs(actual: VectorStoreConfig, expected: VectorStoreConfig):
assert actual.type == expected.type
assert actual.db_uri == expected.db_uri
assert actual.container_name == expected.container_name
assert actual.overwrite == expected.overwrite
assert actual.url == expected.url
assert actual.api_key == expected.api_key
assert actual.audience == expected.audience
def assert_vector_store_configs(
actual: dict[str, VectorStoreConfig],
expected: dict[str, VectorStoreConfig],
):
assert type(actual) is type(expected)
assert len(actual) == len(expected)
for (index_a, store_a), (index_e, store_e) in zip(
actual.items(), expected.items(), strict=True
):
assert index_a == index_e
assert store_a.type == store_e.type
assert store_a.db_uri == store_e.db_uri
assert store_a.url == store_e.url
assert store_a.api_key == store_e.api_key
assert store_a.audience == store_e.audience
assert store_a.container_name == store_e.container_name
assert store_a.overwrite == store_e.overwrite
assert store_a.database_name == store_e.database_name
def assert_reporting_configs(