graphrag/graphrag/api/query.py
Copilot e84df28e64
Improve internal logging functionality by using Python's standard logging module (#1956)
* Initial plan for issue

* Implement standard logging module and integrate with existing loggers

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Add test cases and improve documentation for standard logging

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Apply ruff formatting and add semversioner file for logging improvements

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Remove custom logger classes and refactor to use standard logging only

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Apply ruff formatting to resolve CI/CD test failures

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Add semversioner file and fix linting issues

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* ruff fixes

* fix spelling error

* Remove StandardProgressLogger and refactor to use standard logging

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Remove LoggerFactory and custom loggers, refactor to use standard logging

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Fix pyright error: use logger.info() instead of calling logger as function in cosmosdb_pipeline_storage.py

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* ruff fixes

* Remove deprecated logger files that were marked as deprecated placeholders

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Replace custom get_logger with standard Python logging

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Fix linting issues found by ruff check --fix

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* apply ruff check fixes

* add word to dictionary

* Fix type checker error in ModelManager.__new__ method

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Refactor multiple logging.getLogger() calls to use single logger per file

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Remove progress_logger parameter from build_index() and logger parameter from generate_indexing_prompts()

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Remove logger parameter from run_pipeline and standardize logger naming

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Replace logger parameter with log_level parameter in CLI commands

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Fix import ordering in notebook files to pass poetry poe check

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Remove --logger parameter from smoke test command

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Fix Windows CI/CD issue with log file cleanup in tests

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Add StreamHandler to root logger in __main__.py for CLI logging

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Only add StreamHandler if root logger doesn't have existing StreamHandler

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Fix import ordering in notebook files to pass ruff checks

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Replace logging.StreamHandler with colorlog.StreamHandler for colorized log output

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Regenerate poetry.lock file after adding colorlog dependency

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Fix import ordering in notebook files to pass ruff checks

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* move printing of dataframes to debug level

* remove colorlog for now

* Refactor workflow callbacks to inherit from logging.Handler

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Fix linting issues in workflow callback handlers

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Fix pyright type errors in blob and file workflow callbacks

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Refactor pipeline logging to use pure logging.Handler subclasses

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Rename workflow callback classes to workflow logger classes and move to logger directory

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* update dictionary

* apply ruff fixes

* fix function name

* simplify logger code

* update

* Remove error, warning, and log methods from WorkflowCallbacks and replace with standard logging

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* ruff fixes

* Fix pyright errors by removing WorkflowCallbacks from strategy type signatures

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Remove ConsoleWorkflowLogger and apply consistent formatter to all handlers

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* apply ruff fixes

* Refactor pipeline_logger.py to use standard FileHandler and remove FileWorkflowLogger

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Remove conditional azure import checks from blob_workflow_logger.py

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Fix pyright type checking errors in mock_provider.py and utils.py

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Run ruff check --fix to fix import ordering in notebooks

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Merge configure_logging and create_pipeline_logger into init_loggers function

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Remove configure_logging and create_pipeline_logger functions, replace all usage with init_loggers

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* apply ruff fixes

* cleanup unused code

* Update init_loggers to accept GraphRagConfig instead of ReportingConfig

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* apply ruff check fixes

* Fix test failures by providing valid GraphRagConfig with required model configurations

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* apply ruff fixes

* remove logging_workflow_callback

* cleanup logging messages

* Add logging to track progress of pandas DataFrame apply operation in create_base_text_units

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* cleanup logger logic throughout codebase

* update

* more cleanup of old loggers

* small logger cleanup

* final code cleanup and added loggers to query

* add verbose logging to query

* minor code cleanup

* Fix broken unit tests for chunk_text and standard_logging

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* apply ruff fixes

* Fix test_chunk_text by mocking progress_ticker function instead of ProgressTicker class

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* remove unnecessary logger

* remove rich and fix type annotation

* revert test formatting changes made by copilot

* promote graphrag logs to root logger

* add correct semversioner file

* revert change to file

* revert formatting changes that have no effect

* fix changes after merge with main

* revert unnecessary copilot changes

* remove whitespace

* cleanup docstring

* simplify some logic with less code

* update poetry lock file

* ruff fixes

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>
Co-authored-by: Josh Bradley <joshbradley@microsoft.com>
2025-07-09 18:29:03 -06:00

1212 lines
46 KiB
Python

# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Query Engine API.
This API provides access to the query engine of graphrag, allowing external applications
to hook into graphrag and run queries over a knowledge graph generated by graphrag.
Contains the following functions:
- global_search: Perform a global search.
- global_search_streaming: Perform a global search and stream results back.
- local_search: Perform a local search.
- local_search_streaming: Perform a local search and stream results back.
WARNING: This API is under development and may undergo changes in future releases.
Backwards compatibility is not guaranteed at this time.
"""
import logging
from collections.abc import AsyncGenerator
from typing import Any
import pandas as pd
from pydantic import validate_call
from graphrag.callbacks.noop_query_callbacks import NoopQueryCallbacks
from graphrag.callbacks.query_callbacks import QueryCallbacks
from graphrag.config.embeddings import (
community_full_content_embedding,
entity_description_embedding,
text_unit_text_embedding,
)
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.logger.standard_logging import init_loggers
from graphrag.query.factory import (
get_basic_search_engine,
get_drift_search_engine,
get_global_search_engine,
get_local_search_engine,
)
from graphrag.query.indexer_adapters import (
read_indexer_communities,
read_indexer_covariates,
read_indexer_entities,
read_indexer_relationships,
read_indexer_report_embeddings,
read_indexer_reports,
read_indexer_text_units,
)
from graphrag.utils.api import (
get_embedding_store,
load_search_prompt,
truncate,
update_context_data,
)
from graphrag.utils.cli import redact
# Initialize standard logger
logger = logging.getLogger(__name__)
@validate_call(config={"arbitrary_types_allowed": True})
async def global_search(
config: GraphRagConfig,
entities: pd.DataFrame,
communities: pd.DataFrame,
community_reports: pd.DataFrame,
community_level: int | None,
dynamic_community_selection: bool,
response_type: str,
query: str,
callbacks: list[QueryCallbacks] | None = None,
verbose: bool = False,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
"""Perform a global search and return the context data and response.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- entities (pd.DataFrame): A DataFrame containing the final entities (from entities.parquet)
- communities (pd.DataFrame): A DataFrame containing the final communities (from communities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from community_reports.parquet)
- community_level (int): The community level to search at.
- dynamic_community_selection (bool): Enable dynamic community selection instead of using all community reports at a fixed level. Note that you can still provide community_level to cap the maximum level to search.
- response_type (str): The type of response to return.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
callbacks = callbacks or []
full_response = ""
context_data = {}
def on_context(context: Any) -> None:
nonlocal context_data
context_data = context
local_callbacks = NoopQueryCallbacks()
local_callbacks.on_context = on_context
callbacks.append(local_callbacks)
logger.debug("Executing global search query: %s", query)
async for chunk in global_search_streaming(
config=config,
entities=entities,
communities=communities,
community_reports=community_reports,
community_level=community_level,
dynamic_community_selection=dynamic_community_selection,
response_type=response_type,
query=query,
callbacks=callbacks,
):
full_response += chunk
logger.debug("Query response: %s", truncate(full_response, 400))
return full_response, context_data
@validate_call(config={"arbitrary_types_allowed": True})
def global_search_streaming(
config: GraphRagConfig,
entities: pd.DataFrame,
communities: pd.DataFrame,
community_reports: pd.DataFrame,
community_level: int | None,
dynamic_community_selection: bool,
response_type: str,
query: str,
callbacks: list[QueryCallbacks] | None = None,
verbose: bool = False,
) -> AsyncGenerator:
"""Perform a global search and return the context data and response via a generator.
Context data is returned as a dictionary of lists, with one list entry for each record.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- entities (pd.DataFrame): A DataFrame containing the final entities (from entities.parquet)
- communities (pd.DataFrame): A DataFrame containing the final communities (from communities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from community_reports.parquet)
- community_level (int): The community level to search at.
- dynamic_community_selection (bool): Enable dynamic community selection instead of using all community reports at a fixed level. Note that you can still provide community_level to cap the maximum level to search.
- response_type (str): The type of response to return.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
communities_ = read_indexer_communities(communities, community_reports)
reports = read_indexer_reports(
community_reports,
communities,
community_level=community_level,
dynamic_community_selection=dynamic_community_selection,
)
entities_ = read_indexer_entities(
entities, communities, community_level=community_level
)
map_prompt = load_search_prompt(config.root_dir, config.global_search.map_prompt)
reduce_prompt = load_search_prompt(
config.root_dir, config.global_search.reduce_prompt
)
knowledge_prompt = load_search_prompt(
config.root_dir, config.global_search.knowledge_prompt
)
logger.debug("Executing streaming global search query: %s", query)
search_engine = get_global_search_engine(
config,
reports=reports,
entities=entities_,
communities=communities_,
response_type=response_type,
dynamic_community_selection=dynamic_community_selection,
map_system_prompt=map_prompt,
reduce_system_prompt=reduce_prompt,
general_knowledge_inclusion_prompt=knowledge_prompt,
callbacks=callbacks,
)
return search_engine.stream_search(query=query)
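# ---------------------------------------------------------------------------
# Note on the multi_index_* functions below: artifacts from each index are
# concatenated into single DataFrames, so numeric ids are shifted by the
# running maximum (max_vals) of the previously merged indexes to avoid
# collisions, and string ids are suffixed with "-<index_name>". The `links`
# dict records the reverse mapping so merged ids can be traced back to their
# source index. For example, if the first index contributes communities 0..4,
# the second index's community 0 is renumbered to 5 and
# links["communities"][5] == {"index_name": "<second index>", "id": "0"}.
# ---------------------------------------------------------------------------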
@validate_call(config={"arbitrary_types_allowed": True})
async def multi_index_global_search(
config: GraphRagConfig,
entities_list: list[pd.DataFrame],
communities_list: list[pd.DataFrame],
community_reports_list: list[pd.DataFrame],
index_names: list[str],
community_level: int | None,
dynamic_community_selection: bool,
response_type: str,
streaming: bool,
query: str,
callbacks: list[QueryCallbacks] | None = None,
verbose: bool = False,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
"""Perform a global search across multiple indexes and return the context data and response.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- entities_list (list[pd.DataFrame]): A list of DataFrames containing the final entities (from entities.parquet)
- communities_list (list[pd.DataFrame]): A list of DataFrames containing the final communities (from communities.parquet)
- community_reports_list (list[pd.DataFrame]): A list of DataFrames containing the final community reports (from community_reports.parquet)
- index_names (list[str]): A list of index names.
- community_level (int): The community level to search at.
- dynamic_community_selection (bool): Enable dynamic community selection instead of using all community reports at a fixed level. Note that you can still provide community_level to cap the maximum level to search.
- response_type (str): The type of response to return.
- streaming (bool): Whether to stream the results or not.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
# Streaming not supported yet
if streaming:
message = "Streaming not yet implemented for multi_global_search"
raise NotImplementedError(message)
links = {
"communities": {},
"community_reports": {},
"entities": {},
}
max_vals = {
"communities": -1,
"community_reports": -1,
"entities": -1,
}
communities_dfs = []
community_reports_dfs = []
entities_dfs = []
for idx, index_name in enumerate(index_names):
# Prepare each index's community reports dataframe for merging
community_reports_df = community_reports_list[idx]
community_reports_df["community"] = community_reports_df["community"].astype(
int
)
for i in community_reports_df["community"]:
links["community_reports"][i + max_vals["community_reports"] + 1] = {
"index_name": index_name,
"id": str(i),
}
community_reports_df["community"] += max_vals["community_reports"] + 1
community_reports_df["human_readable_id"] += max_vals["community_reports"] + 1
max_vals["community_reports"] = int(community_reports_df["community"].max())
community_reports_dfs.append(community_reports_df)
# Prepare each index's communities dataframe for merging
communities_df = communities_list[idx]
communities_df["community"] = communities_df["community"].astype(int)
communities_df["parent"] = communities_df["parent"].astype(int)
for i in communities_df["community"]:
links["communities"][i + max_vals["communities"] + 1] = {
"index_name": index_name,
"id": str(i),
}
communities_df["community"] += max_vals["communities"] + 1
communities_df["parent"] = communities_df["parent"].apply(
lambda x: x if x == -1 else x + max_vals["communities"] + 1
)
communities_df["human_readable_id"] += max_vals["communities"] + 1
# concat the index name to the entity_ids, since this is used for joining later
communities_df["entity_ids"] = communities_df["entity_ids"].apply(
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
)
max_vals["communities"] = int(communities_df["community"].max())
communities_dfs.append(communities_df)
# Prepare each index's entities dataframe for merging
entities_df = entities_list[idx]
for i in entities_df["human_readable_id"]:
links["entities"][i + max_vals["entities"] + 1] = {
"index_name": index_name,
"id": i,
}
entities_df["human_readable_id"] += max_vals["entities"] + 1
entities_df["title"] = entities_df["title"].apply(
lambda x, index_name=index_name: x + f"-{index_name}"
)
entities_df["text_unit_ids"] = entities_df["text_unit_ids"].apply(
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
)
max_vals["entities"] = int(entities_df["human_readable_id"].max())
entities_dfs.append(entities_df)
# Merge the dataframes
community_reports_combined = pd.concat(
community_reports_dfs, axis=0, ignore_index=True, sort=False
)
entities_combined = pd.concat(entities_dfs, axis=0, ignore_index=True, sort=False)
communities_combined = pd.concat(
communities_dfs, axis=0, ignore_index=True, sort=False
)
logger.debug("Executing multi-index global search query: %s", query)
result = await global_search(
config,
entities=entities_combined,
communities=communities_combined,
community_reports=community_reports_combined,
community_level=community_level,
dynamic_community_selection=dynamic_community_selection,
response_type=response_type,
query=query,
callbacks=callbacks,
)
# Update the context data by linking index names and community ids
context = update_context_data(result[1], links)
logger.debug("Query response: %s", truncate(result[0], 400)) # type: ignore
return (result[0], context)
@validate_call(config={"arbitrary_types_allowed": True})
async def local_search(
config: GraphRagConfig,
entities: pd.DataFrame,
communities: pd.DataFrame,
community_reports: pd.DataFrame,
text_units: pd.DataFrame,
relationships: pd.DataFrame,
covariates: pd.DataFrame | None,
community_level: int,
response_type: str,
query: str,
callbacks: list[QueryCallbacks] | None = None,
verbose: bool = False,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
"""Perform a local search and return the context data and response.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- entities (pd.DataFrame): A DataFrame containing the final entities (from entities.parquet)
- communities (pd.DataFrame): A DataFrame containing the final communities (from communities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from community_reports.parquet)
- text_units (pd.DataFrame): A DataFrame containing the final text units (from text_units.parquet)
- relationships (pd.DataFrame): A DataFrame containing the final relationships (from relationships.parquet)
- covariates (pd.DataFrame): A DataFrame containing the final covariates (from covariates.parquet)
- community_level (int): The community level to search at.
- response_type (str): The response type to return.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
callbacks = callbacks or []
full_response = ""
context_data = {}
def on_context(context: Any) -> None:
nonlocal context_data
context_data = context
local_callbacks = NoopQueryCallbacks()
local_callbacks.on_context = on_context
callbacks.append(local_callbacks)
logger.debug("Executing local search query: %s", query)
async for chunk in local_search_streaming(
config=config,
entities=entities,
communities=communities,
community_reports=community_reports,
text_units=text_units,
relationships=relationships,
covariates=covariates,
community_level=community_level,
response_type=response_type,
query=query,
callbacks=callbacks,
):
full_response += chunk
logger.debug("Query response: %s", truncate(full_response, 400))
return full_response, context_data
@validate_call(config={"arbitrary_types_allowed": True})
def local_search_streaming(
config: GraphRagConfig,
entities: pd.DataFrame,
communities: pd.DataFrame,
community_reports: pd.DataFrame,
text_units: pd.DataFrame,
relationships: pd.DataFrame,
covariates: pd.DataFrame | None,
community_level: int,
response_type: str,
query: str,
callbacks: list[QueryCallbacks] | None = None,
verbose: bool = False,
) -> AsyncGenerator:
"""Perform a local search and return the context data and response via a generator.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- entities (pd.DataFrame): A DataFrame containing the final entities (from entities.parquet)
- communities (pd.DataFrame): A DataFrame containing the final communities (from communities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from community_reports.parquet)
- text_units (pd.DataFrame): A DataFrame containing the final text units (from text_units.parquet)
- relationships (pd.DataFrame): A DataFrame containing the final relationships (from relationships.parquet)
- covariates (pd.DataFrame): A DataFrame containing the final covariates (from covariates.parquet)
- community_level (int): The community level to search at.
- response_type (str): The response type to return.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
vector_store_args = {}
for index, store in config.vector_store.items():
vector_store_args[index] = store.model_dump()
msg = f"Vector Store Args: {redact(vector_store_args)}"
logger.debug(msg)
description_embedding_store = get_embedding_store(
config_args=vector_store_args,
embedding_name=entity_description_embedding,
)
entities_ = read_indexer_entities(entities, communities, community_level)
covariates_ = read_indexer_covariates(covariates) if covariates is not None else []
prompt = load_search_prompt(config.root_dir, config.local_search.prompt)
logger.debug("Executing streaming local search query: %s", query)
search_engine = get_local_search_engine(
config=config,
reports=read_indexer_reports(community_reports, communities, community_level),
text_units=read_indexer_text_units(text_units),
entities=entities_,
relationships=read_indexer_relationships(relationships),
covariates={"claims": covariates_},
description_embedding_store=description_embedding_store,
response_type=response_type,
system_prompt=prompt,
callbacks=callbacks,
)
return search_engine.stream_search(query=query)
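# ---------------------------------------------------------------------------
# Streaming usage sketch (illustrative only): callers that want incremental
# output consume the *_streaming generators directly and can register a
# QueryCallbacks instance (e.g. NoopQueryCallbacks with on_context assigned)
# to capture the context data, mirroring what the non-streaming wrappers in
# this module do. Argument values below are placeholders, not defaults.
#
#     callbacks = NoopQueryCallbacks()
#     callbacks.on_context = lambda context: print("context received")
#     async for chunk in local_search_streaming(
#         config=config,
#         entities=entities,
#         communities=communities,
#         community_reports=community_reports,
#         text_units=text_units,
#         relationships=relationships,
#         covariates=None,
#         community_level=2,
#         response_type="Multiple Paragraphs",
#         query="Who is Scrooge?",
#         callbacks=[callbacks],
#     ):
#         print(chunk, end="")
# ---------------------------------------------------------------------------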
@validate_call(config={"arbitrary_types_allowed": True})
async def multi_index_local_search(
config: GraphRagConfig,
entities_list: list[pd.DataFrame],
communities_list: list[pd.DataFrame],
community_reports_list: list[pd.DataFrame],
text_units_list: list[pd.DataFrame],
relationships_list: list[pd.DataFrame],
covariates_list: list[pd.DataFrame] | None,
index_names: list[str],
community_level: int,
response_type: str,
streaming: bool,
query: str,
callbacks: list[QueryCallbacks] | None = None,
verbose: bool = False,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
"""Perform a local search across multiple indexes and return the context data and response.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- entities_list (list[pd.DataFrame]): A list of DataFrames containing the final entities (from entities.parquet)
- communities_list (list[pd.DataFrame]): A list of DataFrames containing the final communities (from communities.parquet)
- community_reports_list (list[pd.DataFrame]): A list of DataFrames containing the final community reports (from community_reports.parquet)
- text_units_list (list[pd.DataFrame]): A list of DataFrames containing the final text units (from text_units.parquet)
- relationships_list (list[pd.DataFrame]): A list of DataFrames containing the final relationships (from relationships.parquet)
- covariates_list (list[pd.DataFrame]): [Optional] A list of DataFrames containing the final covariates (from covariates.parquet)
- index_names (list[str]): A list of index names.
- community_level (int): The community level to search at.
- response_type (str): The response type to return.
- streaming (bool): Whether to stream the results or not.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
# Streaming not supported yet
if streaming:
message = "Streaming not yet implemented for multi_index_local_search"
raise NotImplementedError(message)
links = {
"community_reports": {},
"communities": {},
"entities": {},
"text_units": {},
"relationships": {},
"covariates": {},
}
max_vals = {
"community_reports": -1,
"communities": -1,
"entities": -1,
"text_units": 0,
"relationships": -1,
"covariates": 0,
}
community_reports_dfs = []
communities_dfs = []
entities_dfs = []
relationships_dfs = []
text_units_dfs = []
covariates_dfs = []
for idx, index_name in enumerate(index_names):
# Prepare each index's communities dataframe for merging
communities_df = communities_list[idx]
communities_df["community"] = communities_df["community"].astype(int)
for i in communities_df["community"]:
links["communities"][i + max_vals["communities"] + 1] = {
"index_name": index_name,
"id": str(i),
}
communities_df["community"] += max_vals["communities"] + 1
communities_df["human_readable_id"] += max_vals["communities"] + 1
# concat the index name to the entity_ids, since this is used for joining later
communities_df["entity_ids"] = communities_df["entity_ids"].apply(
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
)
max_vals["communities"] = int(communities_df["community"].max())
communities_dfs.append(communities_df)
# Prepare each index's community reports dataframe for merging
community_reports_df = community_reports_list[idx]
community_reports_df["community"] = community_reports_df["community"].astype(
int
)
for i in community_reports_df["community"]:
links["community_reports"][i + max_vals["community_reports"] + 1] = {
"index_name": index_name,
"id": str(i),
}
community_reports_df["community"] += max_vals["community_reports"] + 1
community_reports_df["human_readable_id"] += max_vals["community_reports"] + 1
max_vals["community_reports"] = int(community_reports_df["community"].max())
community_reports_dfs.append(community_reports_df)
# Prepare each index's entities dataframe for merging
entities_df = entities_list[idx]
for i in entities_df["human_readable_id"]:
links["entities"][i + max_vals["entities"] + 1] = {
"index_name": index_name,
"id": i,
}
entities_df["human_readable_id"] += max_vals["entities"] + 1
entities_df["title"] = entities_df["title"].apply(
lambda x, index_name=index_name: x + f"-{index_name}"
)
entities_df["id"] = entities_df["id"].apply(
lambda x, index_name=index_name: x + f"-{index_name}"
)
entities_df["text_unit_ids"] = entities_df["text_unit_ids"].apply(
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
)
max_vals["entities"] = int(entities_df["human_readable_id"].max())
entities_dfs.append(entities_df)
# Prepare each index's relationships dataframe for merging
relationships_df = relationships_list[idx]
for i in relationships_df["human_readable_id"].astype(int):
links["relationships"][i + max_vals["relationships"] + 1] = {
"index_name": index_name,
"id": i,
}
if max_vals["relationships"] != -1:
col = (
relationships_df["human_readable_id"].astype(int)
+ max_vals["relationships"]
+ 1
)
relationships_df["human_readable_id"] = col.astype(str)
relationships_df["source"] = relationships_df["source"].apply(
lambda x, index_name=index_name: x + f"-{index_name}"
)
relationships_df["target"] = relationships_df["target"].apply(
lambda x, index_name=index_name: x + f"-{index_name}"
)
relationships_df["text_unit_ids"] = relationships_df["text_unit_ids"].apply(
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
)
max_vals["relationships"] = int(relationships_df["human_readable_id"].max())
relationships_dfs.append(relationships_df)
# Prepare each index's text units dataframe for merging
text_units_df = text_units_list[idx]
for i in range(text_units_df.shape[0]):
links["text_units"][i + max_vals["text_units"]] = {
"index_name": index_name,
"id": i,
}
text_units_df["id"] = text_units_df["id"].apply(
lambda x, index_name=index_name: f"{x}-{index_name}"
)
text_units_df["human_readable_id"] = (
text_units_df["human_readable_id"] + max_vals["text_units"]
)
max_vals["text_units"] += text_units_df.shape[0]
text_units_dfs.append(text_units_df)
# If present, prepare each index's covariates dataframe for merging
if covariates_list is not None:
covariates_df = covariates_list[idx]
for i in covariates_df["human_readable_id"].astype(int):
links["covariates"][i + max_vals["covariates"]] = {
"index_name": index_name,
"id": i,
}
covariates_df["id"] = covariates_df["id"].apply(
lambda x, index_name=index_name: f"{x}-{index_name}"
)
covariates_df["human_readable_id"] = (
covariates_df["human_readable_id"] + max_vals["covariates"]
)
covariates_df["text_unit_id"] = covariates_df["text_unit_id"].apply(
lambda x, index_name=index_name: x + f"-{index_name}"
)
covariates_df["subject_id"] = covariates_df["subject_id"].apply(
lambda x, index_name=index_name: x + f"-{index_name}"
)
max_vals["covariates"] += covariates_df.shape[0]
covariates_dfs.append(covariates_df)
# Merge the dataframes
communities_combined = pd.concat(
communities_dfs, axis=0, ignore_index=True, sort=False
)
community_reports_combined = pd.concat(
community_reports_dfs, axis=0, ignore_index=True, sort=False
)
entities_combined = pd.concat(entities_dfs, axis=0, ignore_index=True, sort=False)
relationships_combined = pd.concat(
relationships_dfs, axis=0, ignore_index=True, sort=False
)
text_units_combined = pd.concat(
text_units_dfs, axis=0, ignore_index=True, sort=False
)
covariates_combined = None
if len(covariates_dfs) > 0:
covariates_combined = pd.concat(
covariates_dfs, axis=0, ignore_index=True, sort=False
)
logger.debug("Executing multi-index local search query: %s", query)
result = await local_search(
config,
entities=entities_combined,
communities=communities_combined,
community_reports=community_reports_combined,
text_units=text_units_combined,
relationships=relationships_combined,
covariates=covariates_combined,
community_level=community_level,
response_type=response_type,
query=query,
callbacks=callbacks,
)
# Update the context data by linking index names and community ids
context = update_context_data(result[1], links)
logger.debug("Query response: %s", truncate(result[0], 400)) # type: ignore
return (result[0], context)
@validate_call(config={"arbitrary_types_allowed": True})
async def drift_search(
config: GraphRagConfig,
entities: pd.DataFrame,
communities: pd.DataFrame,
community_reports: pd.DataFrame,
text_units: pd.DataFrame,
relationships: pd.DataFrame,
community_level: int,
response_type: str,
query: str,
callbacks: list[QueryCallbacks] | None = None,
verbose: bool = False,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
"""Perform a DRIFT search and return the context data and response.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- entities (pd.DataFrame): A DataFrame containing the final entities (from entities.parquet)
- communities (pd.DataFrame): A DataFrame containing the final communities (from communities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from community_reports.parquet)
- text_units (pd.DataFrame): A DataFrame containing the final text units (from text_units.parquet)
- relationships (pd.DataFrame): A DataFrame containing the final relationships (from relationships.parquet)
- community_level (int): The community level to search at.
- response_type (str): The type of response to return.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
callbacks = callbacks or []
full_response = ""
context_data = {}
def on_context(context: Any) -> None:
nonlocal context_data
context_data = context
local_callbacks = NoopQueryCallbacks()
local_callbacks.on_context = on_context
callbacks.append(local_callbacks)
logger.debug("Executing drift search query: %s", query)
async for chunk in drift_search_streaming(
config=config,
entities=entities,
communities=communities,
community_reports=community_reports,
text_units=text_units,
relationships=relationships,
community_level=community_level,
response_type=response_type,
query=query,
callbacks=callbacks,
):
full_response += chunk
logger.debug("Query response: %s", truncate(full_response, 400))
return full_response, context_data
@validate_call(config={"arbitrary_types_allowed": True})
def drift_search_streaming(
config: GraphRagConfig,
entities: pd.DataFrame,
communities: pd.DataFrame,
community_reports: pd.DataFrame,
text_units: pd.DataFrame,
relationships: pd.DataFrame,
community_level: int,
response_type: str,
query: str,
callbacks: list[QueryCallbacks] | None = None,
verbose: bool = False,
) -> AsyncGenerator:
"""Perform a DRIFT search and return the context data and response.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- entities (pd.DataFrame): A DataFrame containing the final entities (from entities.parquet)
- communities (pd.DataFrame): A DataFrame containing the final communities (from communities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from community_reports.parquet)
- text_units (pd.DataFrame): A DataFrame containing the final text units (from text_units.parquet)
- relationships (pd.DataFrame): A DataFrame containing the final relationships (from relationships.parquet)
- community_level (int): The community level to search at.
- response_type (str): The type of response to return.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
vector_store_args = {}
for index, store in config.vector_store.items():
vector_store_args[index] = store.model_dump()
msg = f"Vector Store Args: {redact(vector_store_args)}"
logger.debug(msg)
description_embedding_store = get_embedding_store(
config_args=vector_store_args,
embedding_name=entity_description_embedding,
)
full_content_embedding_store = get_embedding_store(
config_args=vector_store_args,
embedding_name=community_full_content_embedding,
)
entities_ = read_indexer_entities(entities, communities, community_level)
reports = read_indexer_reports(community_reports, communities, community_level)
read_indexer_report_embeddings(reports, full_content_embedding_store)
prompt = load_search_prompt(config.root_dir, config.drift_search.prompt)
reduce_prompt = load_search_prompt(
config.root_dir, config.drift_search.reduce_prompt
)
logger.debug("Executing streaming drift search query: %s", query)
search_engine = get_drift_search_engine(
config=config,
reports=reports,
text_units=read_indexer_text_units(text_units),
entities=entities_,
relationships=read_indexer_relationships(relationships),
description_embedding_store=description_embedding_store,
local_system_prompt=prompt,
reduce_system_prompt=reduce_prompt,
response_type=response_type,
callbacks=callbacks,
)
return search_engine.stream_search(query=query)
@validate_call(config={"arbitrary_types_allowed": True})
async def multi_index_drift_search(
config: GraphRagConfig,
entities_list: list[pd.DataFrame],
communities_list: list[pd.DataFrame],
community_reports_list: list[pd.DataFrame],
text_units_list: list[pd.DataFrame],
relationships_list: list[pd.DataFrame],
index_names: list[str],
community_level: int,
response_type: str,
streaming: bool,
query: str,
callbacks: list[QueryCallbacks] | None = None,
verbose: bool = False,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
"""Perform a DRIFT search across multiple indexes and return the context data and response.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- entities_list (list[pd.DataFrame]): A list of DataFrames containing the final entities (from entities.parquet)
- communities_list (list[pd.DataFrame]): A list of DataFrames containing the final communities (from communities.parquet)
- community_reports_list (list[pd.DataFrame]): A list of DataFrames containing the final community reports (from community_reports.parquet)
- text_units_list (list[pd.DataFrame]): A list of DataFrames containing the final text units (from text_units.parquet)
- relationships_list (list[pd.DataFrame]): A list of DataFrames containing the final relationships (from relationships.parquet)
- index_names (list[str]): A list of index names.
- community_level (int): The community level to search at.
- response_type (str): The response type to return.
- streaming (bool): Whether to stream the results or not.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
# Streaming not supported yet
if streaming:
message = "Streaming not yet implemented for multi_drift_search"
raise NotImplementedError(message)
links = {
"community_reports": {},
"communities": {},
"entities": {},
"text_units": {},
"relationships": {},
}
max_vals = {
"community_reports": -1,
"communities": -1,
"entities": -1,
"text_units": 0,
"relationships": -1,
}
communities_dfs = []
community_reports_dfs = []
entities_dfs = []
relationships_dfs = []
text_units_dfs = []
for idx, index_name in enumerate(index_names):
# Prepare each index's communities dataframe for merging
communities_df = communities_list[idx]
communities_df["community"] = communities_df["community"].astype(int)
for i in communities_df["community"]:
links["communities"][i + max_vals["communities"] + 1] = {
"index_name": index_name,
"id": str(i),
}
communities_df["community"] += max_vals["communities"] + 1
communities_df["human_readable_id"] += max_vals["communities"] + 1
# concat the index name to the entity_ids, since this is used for joining later
communities_df["entity_ids"] = communities_df["entity_ids"].apply(
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
)
max_vals["communities"] = int(communities_df["community"].max())
communities_dfs.append(communities_df)
# Prepare each index's community reports dataframe for merging
community_reports_df = community_reports_list[idx]
community_reports_df["community"] = community_reports_df["community"].astype(
int
)
for i in community_reports_df["community"]:
links["community_reports"][i + max_vals["community_reports"] + 1] = {
"index_name": index_name,
"id": str(i),
}
community_reports_df["community"] += max_vals["community_reports"] + 1
community_reports_df["human_readable_id"] += max_vals["community_reports"] + 1
community_reports_df["id"] = community_reports_df["id"].apply(
lambda x, index_name=index_name: x + f"-{index_name}"
)
max_vals["community_reports"] = int(community_reports_df["community"].max())
community_reports_dfs.append(community_reports_df)
# Prepare each index's entities dataframe for merging
entities_df = entities_list[idx]
for i in entities_df["human_readable_id"]:
links["entities"][i + max_vals["entities"] + 1] = {
"index_name": index_name,
"id": i,
}
entities_df["human_readable_id"] += max_vals["entities"] + 1
entities_df["title"] = entities_df["title"].apply(
lambda x, index_name=index_name: x + f"-{index_name}"
)
entities_df["id"] = entities_df["id"].apply(
lambda x, index_name=index_name: x + f"-{index_name}"
)
entities_df["text_unit_ids"] = entities_df["text_unit_ids"].apply(
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
)
max_vals["entities"] = int(entities_df["human_readable_id"].max())
entities_dfs.append(entities_df)
# Prepare each index's relationships dataframe for merging
relationships_df = relationships_list[idx]
for i in relationships_df["human_readable_id"].astype(int):
links["relationships"][i + max_vals["relationships"] + 1] = {
"index_name": index_name,
"id": i,
}
if max_vals["relationships"] != -1:
col = (
relationships_df["human_readable_id"].astype(int)
+ max_vals["relationships"]
+ 1
)
relationships_df["human_readable_id"] = col.astype(str)
relationships_df["source"] = relationships_df["source"].apply(
lambda x, index_name=index_name: x + f"-{index_name}"
)
relationships_df["target"] = relationships_df["target"].apply(
lambda x, index_name=index_name: x + f"-{index_name}"
)
relationships_df["text_unit_ids"] = relationships_df["text_unit_ids"].apply(
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
)
max_vals["relationships"] = int(
relationships_df["human_readable_id"].astype(int).max()
)
relationships_dfs.append(relationships_df)
# Prepare each index's text units dataframe for merging
text_units_df = text_units_list[idx]
for i in range(text_units_df.shape[0]):
links["text_units"][i + max_vals["text_units"]] = {
"index_name": index_name,
"id": i,
}
text_units_df["id"] = text_units_df["id"].apply(
lambda x, index_name=index_name: f"{x}-{index_name}"
)
text_units_df["human_readable_id"] = (
text_units_df["human_readable_id"] + max_vals["text_units"]
)
max_vals["text_units"] += text_units_df.shape[0]
text_units_dfs.append(text_units_df)
# Merge the dataframes
communities_combined = pd.concat(
communities_dfs, axis=0, ignore_index=True, sort=False
)
community_reports_combined = pd.concat(
community_reports_dfs, axis=0, ignore_index=True, sort=False
)
entities_combined = pd.concat(entities_dfs, axis=0, ignore_index=True, sort=False)
relationships_combined = pd.concat(
relationships_dfs, axis=0, ignore_index=True, sort=False
)
text_units_combined = pd.concat(
text_units_dfs, axis=0, ignore_index=True, sort=False
)
logger.debug("Executing multi-index drift search query: %s", query)
result = await drift_search(
config,
entities=entities_combined,
communities=communities_combined,
community_reports=community_reports_combined,
text_units=text_units_combined,
relationships=relationships_combined,
community_level=community_level,
response_type=response_type,
query=query,
callbacks=callbacks,
)
# Update the context data by linking index names and community ids
context = {}
if type(result[1]) is dict:
for key in result[1]:
context[key] = update_context_data(result[1][key], links)
else:
context = result[1]
logger.debug("Query response: %s", truncate(result[0], 400)) # type: ignore
return (result[0], context)
@validate_call(config={"arbitrary_types_allowed": True})
async def basic_search(
config: GraphRagConfig,
text_units: pd.DataFrame,
query: str,
callbacks: list[QueryCallbacks] | None = None,
verbose: bool = False,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
"""Perform a basic search and return the context data and response.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- text_units (pd.DataFrame): A DataFrame containing the final text units (from text_units.parquet)
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
callbacks = callbacks or []
full_response = ""
context_data = {}
def on_context(context: Any) -> None:
nonlocal context_data
context_data = context
local_callbacks = NoopQueryCallbacks()
local_callbacks.on_context = on_context
callbacks.append(local_callbacks)
logger.debug("Executing basic search query: %s", query)
async for chunk in basic_search_streaming(
config=config,
text_units=text_units,
query=query,
callbacks=callbacks,
):
full_response += chunk
logger.debug("Query response: %s", truncate(full_response, 400))
return full_response, context_data
@validate_call(config={"arbitrary_types_allowed": True})
def basic_search_streaming(
config: GraphRagConfig,
text_units: pd.DataFrame,
query: str,
callbacks: list[QueryCallbacks] | None = None,
verbose: bool = False,
) -> AsyncGenerator:
"""Perform a local search and return the context data and response via a generator.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- text_units (pd.DataFrame): A DataFrame containing the final text units (from text_units.parquet)
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
vector_store_args = {}
for index, store in config.vector_store.items():
vector_store_args[index] = store.model_dump()
msg = f"Vector Store Args: {redact(vector_store_args)}"
logger.debug(msg)
description_embedding_store = get_embedding_store(
config_args=vector_store_args,
embedding_name=text_unit_text_embedding,
)
prompt = load_search_prompt(config.root_dir, config.basic_search.prompt)
logger.debug("Executing streaming basic search query: %s", query)
search_engine = get_basic_search_engine(
config=config,
text_units=read_indexer_text_units(text_units),
text_unit_embeddings=description_embedding_store,
system_prompt=prompt,
callbacks=callbacks,
)
return search_engine.stream_search(query=query)
@validate_call(config={"arbitrary_types_allowed": True})
async def multi_index_basic_search(
config: GraphRagConfig,
text_units_list: list[pd.DataFrame],
index_names: list[str],
streaming: bool,
query: str,
callbacks: list[QueryCallbacks] | None = None,
verbose: bool = False,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
"""Perform a basic search across multiple indexes and return the context data and response.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- text_units_list (list[pd.DataFrame]): A list of DataFrames containing the final text units (from text_units.parquet)
- index_names (list[str]): A list of index names.
- streaming (bool): Whether to stream the results or not.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
# Streaming not supported yet
if streaming:
message = "Streaming not yet implemented for multi_basic_search"
raise NotImplementedError(message)
links = {
"text_units": {},
}
max_vals = {
"text_units": 0,
}
text_units_dfs = []
for idx, index_name in enumerate(index_names):
# Prepare each index's text units dataframe for merging
text_units_df = text_units_list[idx]
for i in range(text_units_df.shape[0]):
links["text_units"][i + max_vals["text_units"]] = {
"index_name": index_name,
"id": i,
}
text_units_df["id"] = text_units_df["id"].apply(
lambda x, index_name=index_name: f"{x}-{index_name}"
)
text_units_df["human_readable_id"] = (
text_units_df["human_readable_id"] + max_vals["text_units"]
)
max_vals["text_units"] += text_units_df.shape[0]
text_units_dfs.append(text_units_df)
# Merge the dataframes
text_units_combined = pd.concat(
text_units_dfs, axis=0, ignore_index=True, sort=False
)
logger.debug("Executing multi-index basic search query: %s", query)
return await basic_search(
config,
text_units=text_units_combined,
query=query,
callbacks=callbacks,
)