# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""
|
|
Query Engine API.
|
|
|
|
This API provides access to the query engine of graphrag, allowing external applications
|
|
to hook into graphrag and run queries over a knowledge graph generated by graphrag.
|
|
|
|
Contains the following functions:
|
|
- global_search: Perform a global search.
|
|
- global_search_streaming: Perform a global search and stream results back.
|
|
- local_search: Perform a local search.
|
|
- local_search_streaming: Perform a local search and stream results back.
|
|
|
|
WARNING: This API is under development and may undergo changes in future releases.
|
|
Backwards compatibility is not guaranteed at this time.
|
|
"""
|
|
|
|
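
# Example usage (a minimal sketch, not a tested recipe): it assumes an index has already
# been built and that `config` is a GraphRagConfig loaded elsewhere; the parquet file names
# follow the defaults referenced in the docstrings below, and the "output/" directory is
# illustrative only.
#
#   import asyncio
#   import pandas as pd
#
#   entities = pd.read_parquet("output/entities.parquet")
#   communities = pd.read_parquet("output/communities.parquet")
#   community_reports = pd.read_parquet("output/community_reports.parquet")
#
#   response, context = asyncio.run(
#       global_search(
#           config=config,
#           entities=entities,
#           communities=communities,
#           community_reports=community_reports,
#           community_level=2,
#           dynamic_community_selection=False,
#           response_type="Multiple Paragraphs",
#           query="What are the main themes in this dataset?",
#       )
#   )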
import logging

from collections.abc import AsyncGenerator
from typing import Any

import pandas as pd
from pydantic import validate_call

from graphrag.callbacks.noop_query_callbacks import NoopQueryCallbacks
from graphrag.callbacks.query_callbacks import QueryCallbacks
from graphrag.config.embeddings import (
    community_full_content_embedding,
    entity_description_embedding,
    text_unit_text_embedding,
)
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.logger.standard_logging import init_loggers
from graphrag.query.factory import (
    get_basic_search_engine,
    get_drift_search_engine,
    get_global_search_engine,
    get_local_search_engine,
)
from graphrag.query.indexer_adapters import (
    read_indexer_communities,
    read_indexer_covariates,
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_report_embeddings,
    read_indexer_reports,
    read_indexer_text_units,
)
from graphrag.utils.api import (
    get_embedding_store,
    load_search_prompt,
    truncate,
    update_context_data,
)
from graphrag.utils.cli import redact

# Initialize standard logger
logger = logging.getLogger(__name__)

@validate_call(config={"arbitrary_types_allowed": True})
|
|
async def global_search(
|
|
config: GraphRagConfig,
|
|
entities: pd.DataFrame,
|
|
communities: pd.DataFrame,
|
|
community_reports: pd.DataFrame,
|
|
community_level: int | None,
|
|
dynamic_community_selection: bool,
|
|
response_type: str,
|
|
query: str,
|
|
callbacks: list[QueryCallbacks] | None = None,
|
|
verbose: bool = False,
|
|
) -> tuple[
|
|
str | dict[str, Any] | list[dict[str, Any]],
|
|
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
|
|
]:
|
|
"""Perform a global search and return the context data and response.
|
|
|
|
Parameters
|
|
----------
|
|
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
|
|
- entities (pd.DataFrame): A DataFrame containing the final entities (from entities.parquet)
|
|
- communities (pd.DataFrame): A DataFrame containing the final communities (from communities.parquet)
|
|
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from community_reports.parquet)
|
|
- community_level (int): The community level to search at.
|
|
- dynamic_community_selection (bool): Enable dynamic community selection instead of using all community reports at a fixed level. Note that you can still provide community_level to cap the maximum level to search.
|
|
- response_type (str): The type of response to return.
|
|
- query (str): The user query to search for.
|
|
|
|
Returns
|
|
-------
|
|
TODO: Document the search response type and format.
|
|
"""
|
|
init_loggers(config=config, verbose=verbose)
|
|
|
|
callbacks = callbacks or []
|
|
full_response = ""
|
|
context_data = {}
|
|
|
|
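# A NoopQueryCallbacks instance is appended below purely to capture the final context
# data reported by the underlying streaming search, so it can be returned alongside the
# accumulated response.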
def on_context(context: Any) -> None:
|
|
nonlocal context_data
|
|
context_data = context
|
|
|
|
local_callbacks = NoopQueryCallbacks()
|
|
local_callbacks.on_context = on_context
|
|
callbacks.append(local_callbacks)
|
|
|
|
logger.debug("Executing global search query: %s", query)
|
|
async for chunk in global_search_streaming(
|
|
config=config,
|
|
entities=entities,
|
|
communities=communities,
|
|
community_reports=community_reports,
|
|
community_level=community_level,
|
|
dynamic_community_selection=dynamic_community_selection,
|
|
response_type=response_type,
|
|
query=query,
|
|
callbacks=callbacks,
|
|
):
|
|
full_response += chunk
|
|
logger.debug("Query response: %s", truncate(full_response, 400))
|
|
return full_response, context_data
|
|
|
|
|
|
@validate_call(config={"arbitrary_types_allowed": True})
|
|
def global_search_streaming(
|
|
config: GraphRagConfig,
|
|
entities: pd.DataFrame,
|
|
communities: pd.DataFrame,
|
|
community_reports: pd.DataFrame,
|
|
community_level: int | None,
|
|
dynamic_community_selection: bool,
|
|
response_type: str,
|
|
query: str,
|
|
callbacks: list[QueryCallbacks] | None = None,
|
|
verbose: bool = False,
|
|
) -> AsyncGenerator:
|
|
"""Perform a global search and return the context data and response via a generator.
|
|
|
|
Context data is returned as a dictionary of lists, with one list entry for each record.
|
|
|
|
Parameters
|
|
----------
|
|
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
|
|
- entities (pd.DataFrame): A DataFrame containing the final entities (from entities.parquet)
|
|
- communities (pd.DataFrame): A DataFrame containing the final communities (from communities.parquet)
|
|
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from community_reports.parquet)
|
|
- community_level (int): The community level to search at.
|
|
- dynamic_community_selection (bool): Enable dynamic community selection instead of using all community reports at a fixed level. Note that you can still provide community_level to cap the maximum level to search.
|
|
- response_type (str): The type of response to return.
|
|
- query (str): The user query to search for.
|
|
|
|
Returns
|
|
-------
|
|
TODO: Document the search response type and format.
|
|
"""
|
|
init_loggers(config=config, verbose=verbose)
|
|
|
|
communities_ = read_indexer_communities(communities, community_reports)
|
|
reports = read_indexer_reports(
|
|
community_reports,
|
|
communities,
|
|
community_level=community_level,
|
|
dynamic_community_selection=dynamic_community_selection,
|
|
)
|
|
entities_ = read_indexer_entities(
|
|
entities, communities, community_level=community_level
|
|
)
|
|
map_prompt = load_search_prompt(config.root_dir, config.global_search.map_prompt)
|
|
reduce_prompt = load_search_prompt(
|
|
config.root_dir, config.global_search.reduce_prompt
|
|
)
|
|
knowledge_prompt = load_search_prompt(
|
|
config.root_dir, config.global_search.knowledge_prompt
|
|
)
|
|
|
|
logger.debug("Executing streaming global search query: %s", query)
|
|
search_engine = get_global_search_engine(
|
|
config,
|
|
reports=reports,
|
|
entities=entities_,
|
|
communities=communities_,
|
|
response_type=response_type,
|
|
dynamic_community_selection=dynamic_community_selection,
|
|
map_system_prompt=map_prompt,
|
|
reduce_system_prompt=reduce_prompt,
|
|
general_knowledge_inclusion_prompt=knowledge_prompt,
|
|
callbacks=callbacks,
|
|
)
|
|
return search_engine.stream_search(query=query)
|
|
|
|
|
|
@validate_call(config={"arbitrary_types_allowed": True})
|
|
async def multi_index_global_search(
|
|
config: GraphRagConfig,
|
|
entities_list: list[pd.DataFrame],
|
|
communities_list: list[pd.DataFrame],
|
|
community_reports_list: list[pd.DataFrame],
|
|
index_names: list[str],
|
|
community_level: int | None,
|
|
dynamic_community_selection: bool,
|
|
response_type: str,
|
|
streaming: bool,
|
|
query: str,
|
|
callbacks: list[QueryCallbacks] | None = None,
|
|
verbose: bool = False,
|
|
) -> tuple[
|
|
str | dict[str, Any] | list[dict[str, Any]],
|
|
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
|
|
]:
|
|
"""Perform a global search across multiple indexes and return the context data and response.
|
|
|
|
Parameters
|
|
----------
|
|
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
|
|
- entities_list (list[pd.DataFrame]): A list of DataFrames containing the final entities (from entities.parquet)
|
|
- communities_list (list[pd.DataFrame]): A list of DataFrames containing the final communities (from communities.parquet)
|
|
- community_reports_list (list[pd.DataFrame]): A list of DataFrames containing the final community reports (from community_reports.parquet)
|
|
- index_names (list[str]): A list of index names.
|
|
- community_level (int): The community level to search at.
|
|
- dynamic_community_selection (bool): Enable dynamic community selection instead of using all community reports at a fixed level. Note that you can still provide community_level to cap the maximum level to search.
|
|
- response_type (str): The type of response to return.
|
|
- streaming (bool): Whether to stream the results or not.
|
|
- query (str): The user query to search for.
|
|
|
|
Returns
|
|
-------
|
|
TODO: Document the search response type and format.
|
|
"""
|
|
init_loggers(config=config, verbose=verbose)
|
|
|
|
# Streaming not supported yet
|
|
if streaming:
|
|
message = "Streaming not yet implemented for multi_global_search"
|
|
raise NotImplementedError(message)
|
|
|
|
links = {
|
|
"communities": {},
|
|
"community_reports": {},
|
|
"entities": {},
|
|
}
|
|
max_vals = {
|
|
"communities": -1,
|
|
"community_reports": -1,
|
|
"entities": -1,
|
|
}
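# Illustrative example of the id-offset scheme used in the loop below: if index "a" has
# communities 0..4 and index "b" has communities 0..2, then "a" is shifted by
# max_vals["communities"] + 1 = 0 (no change), max becomes 4, and "b" is shifted by 5 so
# its communities become 5..7 in the merged frame, with
# links["communities"][5] = {"index_name": "b", "id": "0"} recording the mapping back to
# the original id. The same scheme is applied to community_reports and entities.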
|
|
|
|
communities_dfs = []
|
|
community_reports_dfs = []
|
|
entities_dfs = []
|
|
|
|
for idx, index_name in enumerate(index_names):
|
|
# Prepare each index's community reports dataframe for merging
|
|
community_reports_df = community_reports_list[idx]
|
|
community_reports_df["community"] = community_reports_df["community"].astype(
|
|
int
|
|
)
|
|
for i in community_reports_df["community"]:
|
|
links["community_reports"][i + max_vals["community_reports"] + 1] = {
|
|
"index_name": index_name,
|
|
"id": str(i),
|
|
}
|
|
community_reports_df["community"] += max_vals["community_reports"] + 1
|
|
community_reports_df["human_readable_id"] += max_vals["community_reports"] + 1
|
|
max_vals["community_reports"] = int(community_reports_df["community"].max())
|
|
community_reports_dfs.append(community_reports_df)
|
|
|
|
# Prepare each index's communities dataframe for merging
|
|
communities_df = communities_list[idx]
|
|
communities_df["community"] = communities_df["community"].astype(int)
|
|
communities_df["parent"] = communities_df["parent"].astype(int)
|
|
for i in communities_df["community"]:
|
|
links["communities"][i + max_vals["communities"] + 1] = {
|
|
"index_name": index_name,
|
|
"id": str(i),
|
|
}
|
|
communities_df["community"] += max_vals["communities"] + 1
|
|
communities_df["parent"] = communities_df["parent"].apply(
|
|
lambda x: x if x == -1 else x + max_vals["communities"] + 1
|
|
)
|
|
communities_df["human_readable_id"] += max_vals["communities"] + 1
|
|
# concat the index name to the entity_ids, since this is used for joining later
|
|
communities_df["entity_ids"] = communities_df["entity_ids"].apply(
|
|
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
|
|
)
|
|
max_vals["communities"] = int(communities_df["community"].max())
|
|
communities_dfs.append(communities_df)
|
|
|
|
# Prepare each index's entities dataframe for merging
|
|
entities_df = entities_list[idx]
|
|
for i in entities_df["human_readable_id"]:
|
|
links["entities"][i + max_vals["entities"] + 1] = {
|
|
"index_name": index_name,
|
|
"id": i,
|
|
}
|
|
entities_df["human_readable_id"] += max_vals["entities"] + 1
|
|
entities_df["title"] = entities_df["title"].apply(
|
|
lambda x, index_name=index_name: x + f"-{index_name}"
|
|
)
|
|
entities_df["text_unit_ids"] = entities_df["text_unit_ids"].apply(
|
|
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
|
|
)
|
|
max_vals["entities"] = int(entities_df["human_readable_id"].max())
|
|
entities_dfs.append(entities_df)
|
|
|
|
# Merge the dataframes
|
|
community_reports_combined = pd.concat(
|
|
community_reports_dfs, axis=0, ignore_index=True, sort=False
|
|
)
|
|
entities_combined = pd.concat(entities_dfs, axis=0, ignore_index=True, sort=False)
|
|
communities_combined = pd.concat(
|
|
communities_dfs, axis=0, ignore_index=True, sort=False
|
|
)
|
|
|
|
logger.debug("Executing multi-index global search query: %s", query)
|
|
result = await global_search(
|
|
config,
|
|
entities=entities_combined,
|
|
communities=communities_combined,
|
|
community_reports=community_reports_combined,
|
|
community_level=community_level,
|
|
dynamic_community_selection=dynamic_community_selection,
|
|
response_type=response_type,
|
|
query=query,
|
|
callbacks=callbacks,
|
|
)
|
|
|
|
# Update the context data by linking index names and community ids
|
|
context = update_context_data(result[1], links)
|
|
|
|
logger.debug("Query response: %s", truncate(result[0], 400)) # type: ignore
|
|
return (result[0], context)
|
|
|
|
|
|
@validate_call(config={"arbitrary_types_allowed": True})
|
|
async def local_search(
|
|
config: GraphRagConfig,
|
|
entities: pd.DataFrame,
|
|
communities: pd.DataFrame,
|
|
community_reports: pd.DataFrame,
|
|
text_units: pd.DataFrame,
|
|
relationships: pd.DataFrame,
|
|
covariates: pd.DataFrame | None,
|
|
community_level: int,
|
|
response_type: str,
|
|
query: str,
|
|
callbacks: list[QueryCallbacks] | None = None,
|
|
verbose: bool = False,
|
|
) -> tuple[
|
|
str | dict[str, Any] | list[dict[str, Any]],
|
|
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
|
|
]:
|
|
"""Perform a local search and return the context data and response.

Parameters
----------
|
|
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
|
|
- entities (pd.DataFrame): A DataFrame containing the final entities (from entities.parquet)
- communities (pd.DataFrame): A DataFrame containing the final communities (from communities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from community_reports.parquet)
|
|
- text_units (pd.DataFrame): A DataFrame containing the final text units (from text_units.parquet)
|
|
- relationships (pd.DataFrame): A DataFrame containing the final relationships (from relationships.parquet)
|
|
- covariates (pd.DataFrame): A DataFrame containing the final covariates (from covariates.parquet)
|
|
- community_level (int): The community level to search at.
|
|
- response_type (str): The response type to return.
|
|
- query (str): The user query to search for.
|
|
|
|
Returns
|
|
-------
|
|
TODO: Document the search response type and format.
|
|
"""
|
|
init_loggers(config=config, verbose=verbose)
|
|
|
|
callbacks = callbacks or []
|
|
full_response = ""
|
|
context_data = {}
|
|
|
|
def on_context(context: Any) -> None:
|
|
nonlocal context_data
|
|
context_data = context
|
|
|
|
local_callbacks = NoopQueryCallbacks()
|
|
local_callbacks.on_context = on_context
|
|
callbacks.append(local_callbacks)
|
|
|
|
logger.debug("Executing local search query: %s", query)
|
|
async for chunk in local_search_streaming(
|
|
config=config,
|
|
entities=entities,
|
|
communities=communities,
|
|
community_reports=community_reports,
|
|
text_units=text_units,
|
|
relationships=relationships,
|
|
covariates=covariates,
|
|
community_level=community_level,
|
|
response_type=response_type,
|
|
query=query,
|
|
callbacks=callbacks,
|
|
):
|
|
full_response += chunk
|
|
logger.debug("Query response: %s", truncate(full_response, 400))
|
|
return full_response, context_data
|
|
|
|
|
|
@validate_call(config={"arbitrary_types_allowed": True})
|
|
def local_search_streaming(
|
|
config: GraphRagConfig,
|
|
entities: pd.DataFrame,
|
|
communities: pd.DataFrame,
|
|
community_reports: pd.DataFrame,
|
|
text_units: pd.DataFrame,
|
|
relationships: pd.DataFrame,
|
|
covariates: pd.DataFrame | None,
|
|
community_level: int,
|
|
response_type: str,
|
|
query: str,
|
|
callbacks: list[QueryCallbacks] | None = None,
|
|
verbose: bool = False,
|
|
) -> AsyncGenerator:
|
|
"""Perform a local search and return the context data and response via a generator.
|
|
|
|
Parameters
|
|
----------
|
|
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
|
|
- entities (pd.DataFrame): A DataFrame containing the final entities (from entities.parquet)
- communities (pd.DataFrame): A DataFrame containing the final communities (from communities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from community_reports.parquet)
|
|
- text_units (pd.DataFrame): A DataFrame containing the final text units (from text_units.parquet)
|
|
- relationships (pd.DataFrame): A DataFrame containing the final relationships (from relationships.parquet)
|
|
- covariates (pd.DataFrame): A DataFrame containing the final covariates (from covariates.parquet)
|
|
- community_level (int): The community level to search at.
|
|
- response_type (str): The response type to return.
|
|
- query (str): The user query to search for.
|
|
|
|
Returns
|
|
-------
|
|
TODO: Document the search response type and format.
|
|
"""
|
|
init_loggers(config=config, verbose=verbose)
|
|
|
|
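# Dump each configured vector store's settings and pass them to the embedding store
# factory; secrets are redacted before the settings are written to the debug log.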
vector_store_args = {}
|
|
for index, store in config.vector_store.items():
|
|
vector_store_args[index] = store.model_dump()
|
|
msg = f"Vector Store Args: {redact(vector_store_args)}"
|
|
logger.debug(msg)
|
|
|
|
description_embedding_store = get_embedding_store(
|
|
config_args=vector_store_args,
|
|
embedding_name=entity_description_embedding,
|
|
)
|
|
|
|
entities_ = read_indexer_entities(entities, communities, community_level)
|
|
covariates_ = read_indexer_covariates(covariates) if covariates is not None else []
|
|
prompt = load_search_prompt(config.root_dir, config.local_search.prompt)
|
|
|
|
logger.debug("Executing streaming local search query: %s", query)
|
|
search_engine = get_local_search_engine(
|
|
config=config,
|
|
reports=read_indexer_reports(community_reports, communities, community_level),
|
|
text_units=read_indexer_text_units(text_units),
|
|
entities=entities_,
|
|
relationships=read_indexer_relationships(relationships),
|
|
covariates={"claims": covariates_},
|
|
description_embedding_store=description_embedding_store,
|
|
response_type=response_type,
|
|
system_prompt=prompt,
|
|
callbacks=callbacks,
|
|
)
|
|
return search_engine.stream_search(query=query)
|
|
|
|
|
|
@validate_call(config={"arbitrary_types_allowed": True})
|
|
async def multi_index_local_search(
|
|
config: GraphRagConfig,
|
|
entities_list: list[pd.DataFrame],
|
|
communities_list: list[pd.DataFrame],
|
|
community_reports_list: list[pd.DataFrame],
|
|
text_units_list: list[pd.DataFrame],
|
|
relationships_list: list[pd.DataFrame],
|
|
covariates_list: list[pd.DataFrame] | None,
|
|
index_names: list[str],
|
|
community_level: int,
|
|
response_type: str,
|
|
streaming: bool,
|
|
query: str,
|
|
callbacks: list[QueryCallbacks] | None = None,
|
|
verbose: bool = False,
|
|
) -> tuple[
|
|
str | dict[str, Any] | list[dict[str, Any]],
|
|
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
|
|
]:
|
|
"""Perform a local search across multiple indexes and return the context data and response.
|
|
|
|
Parameters
|
|
----------
|
|
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
|
|
- entities_list (list[pd.DataFrame]): A list of DataFrames containing the final entities (from entities.parquet)
- communities_list (list[pd.DataFrame]): A list of DataFrames containing the final communities (from communities.parquet)
- community_reports_list (list[pd.DataFrame]): A list of DataFrames containing the final community reports (from community_reports.parquet)
|
|
- text_units_list (list[pd.DataFrame]): A list of DataFrames containing the final text units (from text_units.parquet)
|
|
- relationships_list (list[pd.DataFrame]): A list of DataFrames containing the final relationships (from relationships.parquet)
|
|
- covariates_list (list[pd.DataFrame]): [Optional] A list of DataFrames containing the final covariates (from covariates.parquet)
|
|
- index_names (list[str]): A list of index names.
|
|
- community_level (int): The community level to search at.
|
|
- response_type (str): The response type to return.
|
|
- streaming (bool): Whether to stream the results or not.
|
|
- query (str): The user query to search for.
|
|
|
|
Returns
|
|
-------
|
|
TODO: Document the search response type and format.
|
|
"""
|
|
init_loggers(config=config, verbose=verbose)
|
|
|
|
# Streaming not supported yet
|
|
if streaming:
|
|
message = "Streaming not yet implemented for multi_index_local_search"
|
|
raise NotImplementedError(message)
|
|
|
|
links = {
|
|
"community_reports": {},
|
|
"communities": {},
|
|
"entities": {},
|
|
"text_units": {},
|
|
"relationships": {},
|
|
"covariates": {},
|
|
}
|
|
max_vals = {
|
|
"community_reports": -1,
|
|
"communities": -1,
|
|
"entities": -1,
|
|
"text_units": 0,
|
|
"relationships": -1,
|
|
"covariates": 0,
|
|
}
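# Note the two offset bases used below: id-keyed tables (communities, community_reports,
# entities, relationships) start at -1 and are shifted by max_id + 1, while text_units
# and covariates start at 0 and are shifted by a running row count.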
|
|
community_reports_dfs = []
|
|
communities_dfs = []
|
|
entities_dfs = []
|
|
relationships_dfs = []
|
|
text_units_dfs = []
|
|
covariates_dfs = []
|
|
|
|
for idx, index_name in enumerate(index_names):
|
|
# Prepare each index's communities dataframe for merging
|
|
communities_df = communities_list[idx]
|
|
communities_df["community"] = communities_df["community"].astype(int)
|
|
for i in communities_df["community"]:
|
|
links["communities"][i + max_vals["communities"] + 1] = {
|
|
"index_name": index_name,
|
|
"id": str(i),
|
|
}
|
|
communities_df["community"] += max_vals["communities"] + 1
|
|
communities_df["human_readable_id"] += max_vals["communities"] + 1
|
|
# concat the index name to the entity_ids, since this is used for joining later
|
|
communities_df["entity_ids"] = communities_df["entity_ids"].apply(
|
|
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
|
|
)
|
|
max_vals["communities"] = int(communities_df["community"].max())
|
|
communities_dfs.append(communities_df)
|
|
|
|
# Prepare each index's community reports dataframe for merging
|
|
community_reports_df = community_reports_list[idx]
|
|
community_reports_df["community"] = community_reports_df["community"].astype(
|
|
int
|
|
)
|
|
for i in community_reports_df["community"]:
|
|
links["community_reports"][i + max_vals["community_reports"] + 1] = {
|
|
"index_name": index_name,
|
|
"id": str(i),
|
|
}
|
|
community_reports_df["community"] += max_vals["community_reports"] + 1
|
|
community_reports_df["human_readable_id"] += max_vals["community_reports"] + 1
|
|
max_vals["community_reports"] = int(community_reports_df["community"].max())
|
|
community_reports_dfs.append(community_reports_df)
|
|
|
|
# Prepare each index's entities dataframe for merging
|
|
entities_df = entities_list[idx]
|
|
for i in entities_df["human_readable_id"]:
|
|
links["entities"][i + max_vals["entities"] + 1] = {
|
|
"index_name": index_name,
|
|
"id": i,
|
|
}
|
|
entities_df["human_readable_id"] += max_vals["entities"] + 1
|
|
entities_df["title"] = entities_df["title"].apply(
|
|
lambda x, index_name=index_name: x + f"-{index_name}"
|
|
)
|
|
entities_df["id"] = entities_df["id"].apply(
|
|
lambda x, index_name=index_name: x + f"-{index_name}"
|
|
)
|
|
entities_df["text_unit_ids"] = entities_df["text_unit_ids"].apply(
|
|
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
|
|
)
|
|
max_vals["entities"] = int(entities_df["human_readable_id"].max())
|
|
entities_dfs.append(entities_df)
|
|
|
|
# Prepare each index's relationships dataframe for merging
|
|
relationships_df = relationships_list[idx]
|
|
for i in relationships_df["human_readable_id"].astype(int):
|
|
links["relationships"][i + max_vals["relationships"] + 1] = {
|
|
"index_name": index_name,
|
|
"id": i,
|
|
}
|
|
if max_vals["relationships"] != -1:
|
|
col = (
|
|
relationships_df["human_readable_id"].astype(int)
|
|
+ max_vals["relationships"]
|
|
+ 1
|
|
)
|
|
relationships_df["human_readable_id"] = col.astype(str)
|
|
relationships_df["source"] = relationships_df["source"].apply(
|
|
lambda x, index_name=index_name: x + f"-{index_name}"
|
|
)
|
|
relationships_df["target"] = relationships_df["target"].apply(
|
|
lambda x, index_name=index_name: x + f"-{index_name}"
|
|
)
|
|
relationships_df["text_unit_ids"] = relationships_df["text_unit_ids"].apply(
|
|
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
|
|
)
|
|
max_vals["relationships"] = int(relationships_df["human_readable_id"].max())
|
|
relationships_dfs.append(relationships_df)
|
|
|
|
# Prepare each index's text units dataframe for merging
|
|
text_units_df = text_units_list[idx]
|
|
for i in range(text_units_df.shape[0]):
|
|
links["text_units"][i + max_vals["text_units"]] = {
|
|
"index_name": index_name,
|
|
"id": i,
|
|
}
|
|
text_units_df["id"] = text_units_df["id"].apply(
|
|
lambda x, index_name=index_name: f"{x}-{index_name}"
|
|
)
|
|
text_units_df["human_readable_id"] = (
|
|
text_units_df["human_readable_id"] + max_vals["text_units"]
|
|
)
|
|
max_vals["text_units"] += text_units_df.shape[0]
|
|
text_units_dfs.append(text_units_df)
|
|
|
|
# If present, prepare each index's covariates dataframe for merging
|
|
if covariates_list is not None:
|
|
covariates_df = covariates_list[idx]
|
|
for i in covariates_df["human_readable_id"].astype(int):
|
|
links["covariates"][i + max_vals["covariates"]] = {
|
|
"index_name": index_name,
|
|
"id": i,
|
|
}
|
|
covariates_df["id"] = covariates_df["id"].apply(
|
|
lambda x, index_name=index_name: f"{x}-{index_name}"
|
|
)
|
|
covariates_df["human_readable_id"] = (
|
|
covariates_df["human_readable_id"] + max_vals["covariates"]
|
|
)
|
|
covariates_df["text_unit_id"] = covariates_df["text_unit_id"].apply(
|
|
lambda x, index_name=index_name: x + f"-{index_name}"
|
|
)
|
|
covariates_df["subject_id"] = covariates_df["subject_id"].apply(
|
|
lambda x, index_name=index_name: x + f"-{index_name}"
|
|
)
|
|
max_vals["covariates"] += covariates_df.shape[0]
|
|
covariates_dfs.append(covariates_df)
|
|
|
|
# Merge the dataframes
|
|
communities_combined = pd.concat(
|
|
communities_dfs, axis=0, ignore_index=True, sort=False
|
|
)
|
|
community_reports_combined = pd.concat(
|
|
community_reports_dfs, axis=0, ignore_index=True, sort=False
|
|
)
|
|
entities_combined = pd.concat(entities_dfs, axis=0, ignore_index=True, sort=False)
|
|
relationships_combined = pd.concat(
|
|
relationships_dfs, axis=0, ignore_index=True, sort=False
|
|
)
|
|
text_units_combined = pd.concat(
|
|
text_units_dfs, axis=0, ignore_index=True, sort=False
|
|
)
|
|
covariates_combined = None
|
|
if len(covariates_dfs) > 0:
|
|
covariates_combined = pd.concat(
|
|
covariates_dfs, axis=0, ignore_index=True, sort=False
|
|
)
|
|
logger.debug("Executing multi-index local search query: %s", query)
|
|
result = await local_search(
|
|
config,
|
|
entities=entities_combined,
|
|
communities=communities_combined,
|
|
community_reports=community_reports_combined,
|
|
text_units=text_units_combined,
|
|
relationships=relationships_combined,
|
|
covariates=covariates_combined,
|
|
community_level=community_level,
|
|
response_type=response_type,
|
|
query=query,
|
|
callbacks=callbacks,
|
|
)
|
|
|
|
# Update the context data by linking index names and community ids
|
|
context = update_context_data(result[1], links)
|
|
|
|
logger.debug("Query response: %s", truncate(result[0], 400)) # type: ignore
|
|
return (result[0], context)
|
|
|
|
|
|
@validate_call(config={"arbitrary_types_allowed": True})
|
|
async def drift_search(
|
|
config: GraphRagConfig,
|
|
entities: pd.DataFrame,
|
|
communities: pd.DataFrame,
|
|
community_reports: pd.DataFrame,
|
|
text_units: pd.DataFrame,
|
|
relationships: pd.DataFrame,
|
|
community_level: int,
|
|
response_type: str,
|
|
query: str,
|
|
callbacks: list[QueryCallbacks] | None = None,
|
|
verbose: bool = False,
|
|
) -> tuple[
|
|
str | dict[str, Any] | list[dict[str, Any]],
|
|
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
|
|
]:
|
|
"""Perform a DRIFT search and return the context data and response.
|
|
|
|
Parameters
|
|
----------
|
|
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
|
|
- entities (pd.DataFrame): A DataFrame containing the final entities (from entities.parquet)
- communities (pd.DataFrame): A DataFrame containing the final communities (from communities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from community_reports.parquet)
|
|
- text_units (pd.DataFrame): A DataFrame containing the final text units (from text_units.parquet)
|
|
- relationships (pd.DataFrame): A DataFrame containing the final relationships (from relationships.parquet)
|
|
- community_level (int): The community level to search at.
- response_type (str): The type of response to return.
- query (str): The user query to search for.
|
|
|
|
Returns
|
|
-------
|
|
TODO: Document the search response type and format.
|
|
"""
|
|
init_loggers(config=config, verbose=verbose)
|
|
|
|
callbacks = callbacks or []
|
|
full_response = ""
|
|
context_data = {}
|
|
|
|
def on_context(context: Any) -> None:
|
|
nonlocal context_data
|
|
context_data = context
|
|
|
|
local_callbacks = NoopQueryCallbacks()
|
|
local_callbacks.on_context = on_context
|
|
callbacks.append(local_callbacks)
|
|
|
|
logger.debug("Executing drift search query: %s", query)
|
|
async for chunk in drift_search_streaming(
|
|
config=config,
|
|
entities=entities,
|
|
communities=communities,
|
|
community_reports=community_reports,
|
|
text_units=text_units,
|
|
relationships=relationships,
|
|
community_level=community_level,
|
|
response_type=response_type,
|
|
query=query,
|
|
callbacks=callbacks,
|
|
):
|
|
full_response += chunk
|
|
logger.debug("Query response: %s", truncate(full_response, 400))
|
|
return full_response, context_data
|
|
|
|
|
|
@validate_call(config={"arbitrary_types_allowed": True})
|
|
def drift_search_streaming(
|
|
config: GraphRagConfig,
|
|
entities: pd.DataFrame,
|
|
communities: pd.DataFrame,
|
|
community_reports: pd.DataFrame,
|
|
text_units: pd.DataFrame,
|
|
relationships: pd.DataFrame,
|
|
community_level: int,
|
|
response_type: str,
|
|
query: str,
|
|
callbacks: list[QueryCallbacks] | None = None,
|
|
verbose: bool = False,
|
|
) -> AsyncGenerator:
|
|
"""Perform a DRIFT search and return the context data and response.
|
|
|
|
Parameters
|
|
----------
|
|
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
|
|
- entities (pd.DataFrame): A DataFrame containing the final entities (from entities.parquet)
- communities (pd.DataFrame): A DataFrame containing the final communities (from communities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from community_reports.parquet)
|
|
- text_units (pd.DataFrame): A DataFrame containing the final text units (from text_units.parquet)
|
|
- relationships (pd.DataFrame): A DataFrame containing the final relationships (from relationships.parquet)
|
|
- community_level (int): The community level to search at.
- response_type (str): The type of response to return.
- query (str): The user query to search for.
|
|
|
|
Returns
|
|
-------
|
|
TODO: Document the search response type and format.
|
|
"""
|
|
init_loggers(config=config, verbose=verbose)
|
|
|
|
vector_store_args = {}
|
|
for index, store in config.vector_store.items():
|
|
vector_store_args[index] = store.model_dump()
|
|
msg = f"Vector Store Args: {redact(vector_store_args)}"
|
|
logger.debug(msg)
|
|
|
|
description_embedding_store = get_embedding_store(
|
|
config_args=vector_store_args,
|
|
embedding_name=entity_description_embedding,
|
|
)
|
|
|
|
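# DRIFT search uses two embedding stores: entity descriptions (above) and full community
# report content (below); the report embeddings are attached to the loaded reports via
# read_indexer_report_embeddings before building the search engine.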
full_content_embedding_store = get_embedding_store(
|
|
config_args=vector_store_args,
|
|
embedding_name=community_full_content_embedding,
|
|
)
|
|
|
|
entities_ = read_indexer_entities(entities, communities, community_level)
|
|
reports = read_indexer_reports(community_reports, communities, community_level)
|
|
read_indexer_report_embeddings(reports, full_content_embedding_store)
|
|
prompt = load_search_prompt(config.root_dir, config.drift_search.prompt)
|
|
reduce_prompt = load_search_prompt(
|
|
config.root_dir, config.drift_search.reduce_prompt
|
|
)
|
|
|
|
logger.debug("Executing streaming drift search query: %s", query)
|
|
search_engine = get_drift_search_engine(
|
|
config=config,
|
|
reports=reports,
|
|
text_units=read_indexer_text_units(text_units),
|
|
entities=entities_,
|
|
relationships=read_indexer_relationships(relationships),
|
|
description_embedding_store=description_embedding_store,
|
|
local_system_prompt=prompt,
|
|
reduce_system_prompt=reduce_prompt,
|
|
response_type=response_type,
|
|
callbacks=callbacks,
|
|
)
|
|
return search_engine.stream_search(query=query)
|
|
|
|
|
|
@validate_call(config={"arbitrary_types_allowed": True})
|
|
async def multi_index_drift_search(
|
|
config: GraphRagConfig,
|
|
entities_list: list[pd.DataFrame],
|
|
communities_list: list[pd.DataFrame],
|
|
community_reports_list: list[pd.DataFrame],
|
|
text_units_list: list[pd.DataFrame],
|
|
relationships_list: list[pd.DataFrame],
|
|
index_names: list[str],
|
|
community_level: int,
|
|
response_type: str,
|
|
streaming: bool,
|
|
query: str,
|
|
callbacks: list[QueryCallbacks] | None = None,
|
|
verbose: bool = False,
|
|
) -> tuple[
|
|
str | dict[str, Any] | list[dict[str, Any]],
|
|
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
|
|
]:
|
|
"""Perform a DRIFT search across multiple indexes and return the context data and response.
|
|
|
|
Parameters
|
|
----------
|
|
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
|
|
- entities_list (list[pd.DataFrame]): A list of DataFrames containing the final entities (from entities.parquet)
- communities_list (list[pd.DataFrame]): A list of DataFrames containing the final communities (from communities.parquet)
- community_reports_list (list[pd.DataFrame]): A list of DataFrames containing the final community reports (from community_reports.parquet)
|
|
- text_units_list (list[pd.DataFrame]): A list of DataFrames containing the final text units (from text_units.parquet)
|
|
- relationships_list (list[pd.DataFrame]): A list of DataFrames containing the final relationships (from relationships.parquet)
|
|
- index_names (list[str]): A list of index names.
|
|
- community_level (int): The community level to search at.
|
|
- response_type (str): The response type to return.
|
|
- streaming (bool): Whether to stream the results or not.
|
|
- query (str): The user query to search for.
|
|
|
|
Returns
|
|
-------
|
|
TODO: Document the search response type and format.
|
|
"""
|
|
init_loggers(config=config, verbose=verbose)
|
|
|
|
# Streaming not supported yet
|
|
if streaming:
|
|
message = "Streaming not yet implemented for multi_drift_search"
|
|
raise NotImplementedError(message)
|
|
|
|
links = {
|
|
"community_reports": {},
|
|
"communities": {},
|
|
"entities": {},
|
|
"text_units": {},
|
|
"relationships": {},
|
|
}
|
|
max_vals = {
|
|
"community_reports": -1,
|
|
"communities": -1,
|
|
"entities": -1,
|
|
"text_units": 0,
|
|
"relationships": -1,
|
|
}
|
|
|
|
communities_dfs = []
|
|
community_reports_dfs = []
|
|
entities_dfs = []
|
|
relationships_dfs = []
|
|
text_units_dfs = []
|
|
|
|
for idx, index_name in enumerate(index_names):
|
|
# Prepare each index's communities dataframe for merging
|
|
communities_df = communities_list[idx]
|
|
communities_df["community"] = communities_df["community"].astype(int)
|
|
for i in communities_df["community"]:
|
|
links["communities"][i + max_vals["communities"] + 1] = {
|
|
"index_name": index_name,
|
|
"id": str(i),
|
|
}
|
|
communities_df["community"] += max_vals["communities"] + 1
|
|
communities_df["human_readable_id"] += max_vals["communities"] + 1
|
|
# concat the index name to the entity_ids, since this is used for joining later
|
|
communities_df["entity_ids"] = communities_df["entity_ids"].apply(
|
|
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
|
|
)
|
|
max_vals["communities"] = int(communities_df["community"].max())
|
|
communities_dfs.append(communities_df)
|
|
|
|
# Prepare each index's community reports dataframe for merging
|
|
community_reports_df = community_reports_list[idx]
|
|
community_reports_df["community"] = community_reports_df["community"].astype(
|
|
int
|
|
)
|
|
for i in community_reports_df["community"]:
|
|
links["community_reports"][i + max_vals["community_reports"] + 1] = {
|
|
"index_name": index_name,
|
|
"id": str(i),
|
|
}
|
|
community_reports_df["community"] += max_vals["community_reports"] + 1
|
|
community_reports_df["human_readable_id"] += max_vals["community_reports"] + 1
|
|
community_reports_df["id"] = community_reports_df["id"].apply(
|
|
lambda x, index_name=index_name: x + f"-{index_name}"
|
|
)
|
|
max_vals["community_reports"] = int(community_reports_df["community"].max())
|
|
community_reports_dfs.append(community_reports_df)
|
|
|
|
# Prepare each index's entities dataframe for merging
|
|
entities_df = entities_list[idx]
|
|
for i in entities_df["human_readable_id"]:
|
|
links["entities"][i + max_vals["entities"] + 1] = {
|
|
"index_name": index_name,
|
|
"id": i,
|
|
}
|
|
entities_df["human_readable_id"] += max_vals["entities"] + 1
|
|
entities_df["title"] = entities_df["title"].apply(
|
|
lambda x, index_name=index_name: x + f"-{index_name}"
|
|
)
|
|
entities_df["id"] = entities_df["id"].apply(
|
|
lambda x, index_name=index_name: x + f"-{index_name}"
|
|
)
|
|
entities_df["text_unit_ids"] = entities_df["text_unit_ids"].apply(
|
|
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
|
|
)
|
|
max_vals["entities"] = int(entities_df["human_readable_id"].max())
|
|
entities_dfs.append(entities_df)
|
|
|
|
# Prepare each index's relationships dataframe for merging
|
|
relationships_df = relationships_list[idx]
|
|
for i in relationships_df["human_readable_id"].astype(int):
|
|
links["relationships"][i + max_vals["relationships"] + 1] = {
|
|
"index_name": index_name,
|
|
"id": i,
|
|
}
|
|
if max_vals["relationships"] != -1:
|
|
col = (
|
|
relationships_df["human_readable_id"].astype(int)
|
|
+ max_vals["relationships"]
|
|
+ 1
|
|
)
|
|
relationships_df["human_readable_id"] = col.astype(str)
|
|
relationships_df["source"] = relationships_df["source"].apply(
|
|
lambda x, index_name=index_name: x + f"-{index_name}"
|
|
)
|
|
relationships_df["target"] = relationships_df["target"].apply(
|
|
lambda x, index_name=index_name: x + f"-{index_name}"
|
|
)
|
|
relationships_df["text_unit_ids"] = relationships_df["text_unit_ids"].apply(
|
|
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
|
|
)
|
|
max_vals["relationships"] = int(
|
|
relationships_df["human_readable_id"].astype(int).max()
|
|
)
|
|
|
|
relationships_dfs.append(relationships_df)
|
|
|
|
# Prepare each index's text units dataframe for merging
|
|
text_units_df = text_units_list[idx]
|
|
for i in range(text_units_df.shape[0]):
|
|
links["text_units"][i + max_vals["text_units"]] = {
|
|
"index_name": index_name,
|
|
"id": i,
|
|
}
|
|
text_units_df["id"] = text_units_df["id"].apply(
|
|
lambda x, index_name=index_name: f"{x}-{index_name}"
|
|
)
|
|
text_units_df["human_readable_id"] = (
|
|
text_units_df["human_readable_id"] + max_vals["text_units"]
|
|
)
|
|
max_vals["text_units"] += text_units_df.shape[0]
|
|
text_units_dfs.append(text_units_df)
|
|
|
|
# Merge the dataframes
|
|
communities_combined = pd.concat(
|
|
communities_dfs, axis=0, ignore_index=True, sort=False
|
|
)
|
|
community_reports_combined = pd.concat(
|
|
community_reports_dfs, axis=0, ignore_index=True, sort=False
|
|
)
|
|
entities_combined = pd.concat(entities_dfs, axis=0, ignore_index=True, sort=False)
|
|
relationships_combined = pd.concat(
|
|
relationships_dfs, axis=0, ignore_index=True, sort=False
|
|
)
|
|
text_units_combined = pd.concat(
|
|
text_units_dfs, axis=0, ignore_index=True, sort=False
|
|
)
|
|
|
|
logger.debug("Executing multi-index drift search query: %s", query)
|
|
result = await drift_search(
|
|
config,
|
|
entities=entities_combined,
|
|
communities=communities_combined,
|
|
community_reports=community_reports_combined,
|
|
text_units=text_units_combined,
|
|
relationships=relationships_combined,
|
|
community_level=community_level,
|
|
response_type=response_type,
|
|
query=query,
|
|
callbacks=callbacks,
|
|
)
|
|
|
|
# Update the context data by linking index names and community ids
|
|
context = {}
|
|
if type(result[1]) is dict:
|
|
for key in result[1]:
|
|
context[key] = update_context_data(result[1][key], links)
|
|
else:
|
|
context = result[1]
|
|
|
|
logger.debug("Query response: %s", truncate(result[0], 400)) # type: ignore
|
|
return (result[0], context)
|
|
|
|
|
|
@validate_call(config={"arbitrary_types_allowed": True})
|
|
async def basic_search(
|
|
config: GraphRagConfig,
|
|
text_units: pd.DataFrame,
|
|
query: str,
|
|
callbacks: list[QueryCallbacks] | None = None,
|
|
verbose: bool = False,
|
|
) -> tuple[
|
|
str | dict[str, Any] | list[dict[str, Any]],
|
|
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
|
|
]:
|
|
"""Perform a basic search and return the context data and response.
|
|
|
|
Parameters
|
|
----------
|
|
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
|
|
- text_units (pd.DataFrame): A DataFrame containing the final text units (from text_units.parquet)
|
|
- query (str): The user query to search for.
|
|
|
|
Returns
|
|
-------
|
|
TODO: Document the search response type and format.
|
|
"""
|
|
init_loggers(config=config, verbose=verbose)
|
|
|
|
callbacks = callbacks or []
|
|
full_response = ""
|
|
context_data = {}
|
|
|
|
def on_context(context: Any) -> None:
|
|
nonlocal context_data
|
|
context_data = context
|
|
|
|
local_callbacks = NoopQueryCallbacks()
|
|
local_callbacks.on_context = on_context
|
|
callbacks.append(local_callbacks)
|
|
|
|
logger.debug("Executing basic search query: %s", query)
|
|
async for chunk in basic_search_streaming(
|
|
config=config,
|
|
text_units=text_units,
|
|
query=query,
|
|
callbacks=callbacks,
|
|
):
|
|
full_response += chunk
|
|
logger.debug("Query response: %s", truncate(full_response, 400))
|
|
return full_response, context_data
|
|
|
|
|
|
@validate_call(config={"arbitrary_types_allowed": True})
|
|
def basic_search_streaming(
|
|
config: GraphRagConfig,
|
|
text_units: pd.DataFrame,
|
|
query: str,
|
|
callbacks: list[QueryCallbacks] | None = None,
|
|
verbose: bool = False,
|
|
) -> AsyncGenerator:
|
|
"""Perform a local search and return the context data and response via a generator.
|
|
|
|
Parameters
|
|
----------
|
|
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
|
|
- text_units (pd.DataFrame): A DataFrame containing the final text units (from text_units.parquet)
|
|
- query (str): The user query to search for.
|
|
|
|
Returns
|
|
-------
|
|
TODO: Document the search response type and format.
|
|
"""
|
|
init_loggers(config=config, verbose=verbose)
|
|
|
|
vector_store_args = {}
|
|
for index, store in config.vector_store.items():
|
|
vector_store_args[index] = store.model_dump()
|
|
msg = f"Vector Store Args: {redact(vector_store_args)}"
|
|
logger.debug(msg)
|
|
|
|
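# Note: despite the variable name, this store holds text-unit text embeddings
# (text_unit_text_embedding), which is what basic search retrieves over.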
description_embedding_store = get_embedding_store(
|
|
config_args=vector_store_args,
|
|
embedding_name=text_unit_text_embedding,
|
|
)
|
|
|
|
prompt = load_search_prompt(config.root_dir, config.basic_search.prompt)
|
|
|
|
logger.debug("Executing streaming basic search query: %s", query)
|
|
search_engine = get_basic_search_engine(
|
|
config=config,
|
|
text_units=read_indexer_text_units(text_units),
|
|
text_unit_embeddings=description_embedding_store,
|
|
system_prompt=prompt,
|
|
callbacks=callbacks,
|
|
)
|
|
return search_engine.stream_search(query=query)
|
|
|
|
|
|
@validate_call(config={"arbitrary_types_allowed": True})
|
|
async def multi_index_basic_search(
|
|
config: GraphRagConfig,
|
|
text_units_list: list[pd.DataFrame],
|
|
index_names: list[str],
|
|
streaming: bool,
|
|
query: str,
|
|
callbacks: list[QueryCallbacks] | None = None,
|
|
verbose: bool = False,
|
|
) -> tuple[
|
|
str | dict[str, Any] | list[dict[str, Any]],
|
|
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
|
|
]:
|
|
"""Perform a basic search across multiple indexes and return the context data and response.
|
|
|
|
Parameters
|
|
----------
|
|
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
|
|
- text_units_list (list[pd.DataFrame]): A list of DataFrames containing the final text units (from text_units.parquet)
|
|
- index_names (list[str]): A list of index names.
|
|
- streaming (bool): Whether to stream the results or not.
|
|
- query (str): The user query to search for.
|
|
|
|
Returns
|
|
-------
|
|
TODO: Document the search response type and format.
|
|
"""
|
|
init_loggers(config=config, verbose=verbose)
|
|
|
|
# Streaming not supported yet
|
|
if streaming:
|
|
message = "Streaming not yet implemented for multi_basic_search"
|
|
raise NotImplementedError(message)
|
|
|
|
links = {
|
|
"text_units": {},
|
|
}
|
|
max_vals = {
|
|
"text_units": 0,
|
|
}
|
|
|
|
text_units_dfs = []
|
|
|
|
for idx, index_name in enumerate(index_names):
|
|
# Prepare each index's text units dataframe for merging
|
|
text_units_df = text_units_list[idx]
|
|
for i in range(text_units_df.shape[0]):
|
|
links["text_units"][i + max_vals["text_units"]] = {
|
|
"index_name": index_name,
|
|
"id": i,
|
|
}
|
|
text_units_df["id"] = text_units_df["id"].apply(
|
|
lambda x, index_name=index_name: f"{x}-{index_name}"
|
|
)
|
|
text_units_df["human_readable_id"] = (
|
|
text_units_df["human_readable_id"] + max_vals["text_units"]
|
|
)
|
|
max_vals["text_units"] += text_units_df.shape[0]
|
|
text_units_dfs.append(text_units_df)
|
|
|
|
# Merge the dataframes
|
|
text_units_combined = pd.concat(
|
|
text_units_dfs, axis=0, ignore_index=True, sort=False
|
|
)
|
|
|
|
logger.debug("Executing multi-index basic search query: %s", query)
|
|
return await basic_search(
|
|
config,
|
|
text_units=text_units_combined,
|
|
query=query,
|
|
callbacks=callbacks,
|
|
)
|