mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 00:57:23 +08:00
Drift Search CLI, API, Docs and Example Notebook (#1348)
* Drift CLI and backwards compat
* Adding DRIFT Cli, Docs and example notebook
* Update tests and fix ruff
* Format
* Small cleanup
* Fix smoke tests
* Update notebook
* Oopsie fix
* Delete duplicate img
parent
68dfceef21
commit
d9f985ae52
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "Add DRIFT search cli and example notebook"
|
||||
}
|
||||
@ -1,27 +1,35 @@
|
||||
|
||||
<div class="grid cards" markdown>
|
||||
|
||||
- [:octicons-arrow-right-24: __GraphRAG: Unlocking LLM discovery on narrative private data__](https://www.microsoft.com/en-us/research/blog/graphrag-unlocking-llm-discovery-on-narrative-private-data/)
|
||||
- [:octicons-arrow-right-24: **GraphRAG: Unlocking LLM discovery on narrative private data**](https://www.microsoft.com/en-us/research/blog/graphrag-unlocking-llm-discovery-on-narrative-private-data/)
|
||||
|
||||
---
|
||||
<h6>Published February 13, 2024
|
||||
***
|
||||
|
||||
By [Jonathan Larson](https://www.microsoft.com/en-us/research/people/jolarso/), Senior Principal Data Architect; [Steven Truitt](https://www.microsoft.com/en-us/research/people/steventruitt/), Principal Program Manager</h6>
|
||||
|
||||
<h6>Published February 13, 2024
|
||||
|
||||
- [:octicons-arrow-right-24: __GraphRAG: New tool for complex data discovery now on GitHub__](https://www.microsoft.com/en-us/research/blog/graphrag-new-tool-for-complex-data-discovery-now-on-github/)
|
||||
By [Jonathan Larson](https://www.microsoft.com/en-us/research/people/jolarso/), Senior Principal Data Architect; [Steven Truitt](https://www.microsoft.com/en-us/research/people/steventruitt/), Principal Program Manager</h6>
|
||||
|
||||
---
|
||||
<h6>Published July 2, 2024
|
||||
- [:octicons-arrow-right-24: **GraphRAG: New tool for complex data discovery now on GitHub**](https://www.microsoft.com/en-us/research/blog/graphrag-new-tool-for-complex-data-discovery-now-on-github/)
|
||||
|
||||
By [Darren Edge](https://www.microsoft.com/en-us/research/people/daedge/), Senior Director; [Ha Trinh](https://www.microsoft.com/en-us/research/people/trinhha/), Senior Data Scientist; [Steven Truitt](https://www.microsoft.com/en-us/research/people/steventruitt/), Principal Program Manager; [Jonathan Larson](https://www.microsoft.com/en-us/research/people/jolarso/), Senior Principal Data Architect</h6>
|
||||
***
|
||||
|
||||
<h6>Published July 2, 2024
|
||||
|
||||
- [:octicons-arrow-right-24: __GraphRAG auto-tuning provides rapid adaptation to new domains__](https://www.microsoft.com/en-us/research/blog/graphrag-auto-tuning-provides-rapid-adaptation-to-new-domains/)
|
||||
By [Darren Edge](https://www.microsoft.com/en-us/research/people/daedge/), Senior Director; [Ha Trinh](https://www.microsoft.com/en-us/research/people/trinhha/), Senior Data Scientist; [Steven Truitt](https://www.microsoft.com/en-us/research/people/steventruitt/), Principal Program Manager; [Jonathan Larson](https://www.microsoft.com/en-us/research/people/jolarso/), Senior Principal Data Architect</h6>
|
||||
|
||||
---
|
||||
<h6>Published September 9, 2024
|
||||
- [:octicons-arrow-right-24: **GraphRAG auto-tuning provides rapid adaptation to new domains**](https://www.microsoft.com/en-us/research/blog/graphrag-auto-tuning-provides-rapid-adaptation-to-new-domains/)
|
||||
|
||||
By [Alonso Guevara Fernández](https://www.microsoft.com/en-us/research/people/alonsog/), Sr. Software Engineer; Katy Smith, Data Scientist II; [Joshua Bradley](https://www.microsoft.com/en-us/research/people/joshbradley/), Senior Data Scientist; [Darren Edge](https://www.microsoft.com/en-us/research/people/daedge/), Senior Director; [Ha Trinh](https://www.microsoft.com/en-us/research/people/trinhha/), Senior Data Scientist; [Sarah Smith](https://www.microsoft.com/en-us/research/people/smithsarah/), Senior Program Manager; [Ben Cutler](https://www.microsoft.com/en-us/research/people/bcutler/), Senior Director; [Steven Truitt](https://www.microsoft.com/en-us/research/people/steventruitt/), Principal Program Manager; [Jonathan Larson](https://www.microsoft.com/en-us/research/people/jolarso/), Senior Principal Data Architect
|
||||
***
|
||||
|
||||
<h6>Published September 9, 2024
|
||||
|
||||
By [Alonso Guevara Fernández](https://www.microsoft.com/en-us/research/people/alonsog/), Sr. Software Engineer; Katy Smith, Data Scientist II; [Joshua Bradley](https://www.microsoft.com/en-us/research/people/joshbradley/), Senior Data Scientist; [Darren Edge](https://www.microsoft.com/en-us/research/people/daedge/), Senior Director; [Ha Trinh](https://www.microsoft.com/en-us/research/people/trinhha/), Senior Data Scientist; [Sarah Smith](https://www.microsoft.com/en-us/research/people/smithsarah/), Senior Program Manager; [Ben Cutler](https://www.microsoft.com/en-us/research/people/bcutler/), Senior Director; [Steven Truitt](https://www.microsoft.com/en-us/research/people/steventruitt/), Principal Program Manager; [Jonathan Larson](https://www.microsoft.com/en-us/research/people/jolarso/), Senior Principal Data Architect
|
||||
|
||||
- [:octicons-arrow-right-24: **Introducing DRIFT Search: Combining global and local search methods to improve quality and efficiency**](https://www.microsoft.com/en-us/research/blog/introducing-drift-search-combining-global-and-local-search-methods-to-improve-quality-and-efficiency/)
|
||||
|
||||
***
|
||||
|
||||
<h6>Published October 31, 2024
|
||||
|
||||
By Julian Whiting, Senior Machine Learning Engineer; Zachary Hills, Senior Software Engineer; [Alonso Guevara Fernández](https://www.microsoft.com/en-us/research/people/alonsog/), Sr. Software Engineer; [Ha Trinh](https://www.microsoft.com/en-us/research/people/trinhha/), Senior Data Scientist; Adam Bradley, Managing Partner, Strategic Research; [Jonathan Larson](https://www.microsoft.com/en-us/research/people/jolarso/), Senior Principal Data Architect
|
||||
|
||||
</div>
|
||||
|
||||
3194
docs/examples_notebooks/drift_search.ipynb
Normal file
File diff suppressed because one or more lines are too long
BIN
docs/img/drift-search-diagram.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 137 KiB |
@ -51,6 +51,7 @@ At query time, these structures are used to provide materials for the LLM contex
|
||||
|
||||
- [_Global Search_](query/global_search.md) for reasoning about holistic questions about the corpus by leveraging the community summaries.
|
||||
- [_Local Search_](query/local_search.md) for reasoning about specific entities by fanning-out to their neighbors and associated concepts.
|
||||
- [_DRIFT Search_](query/drift_search.md) for reasoning about specific entities by fanning-out to their neighbors and associated concepts, but with the added context of community information.
|
||||
|
||||
### Prompt Tuning
|
||||
|
||||
|
||||
34
docs/query/drift_search.md
Normal file
@ -0,0 +1,34 @@
|
||||
# DRIFT Search 🔎
|
||||
|
||||
## Combining Local and Global Search
|
||||
|
||||
GraphRAG is a technique that uses large language models (LLMs) to create knowledge graphs and summaries from unstructured text documents and leverages them to improve retrieval-augmented generation (RAG) operations on private datasets. It offers comprehensive global overviews of large, private troves of unstructured text documents while also enabling exploration of detailed, localized information. By using LLMs to create comprehensive knowledge graphs that connect and describe entities and relationships contained in those documents, GraphRAG leverages semantic structuring of the data to generate responses to a wide variety of complex user queries.
|
||||
|
||||
DRIFT search (Dynamic Reasoning and Inference with Flexible Traversal) builds upon Microsoft’s GraphRAG technique. It combines characteristics of both global and local search to generate detailed responses while balancing computational cost against answer quality. The implementation lives in the [drift search](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/) module.
|
||||
|
||||
## Methodology
|
||||
|
||||
<p align="center">
|
||||
<img src="img/drift-search-diagram.png" alt="Figure 1. An entire DRIFT search hierarchy highlighting the three core phases of the DRIFT search process. A (Primer): DRIFT compares the user’s query with the top K most semantically relevant community reports, generating a broad initial answer and follow-up questions to steer further exploration. B (Follow-Up): DRIFT uses local search to refine queries, producing additional intermediate answers and follow-up questions that enhance specificity, guiding the engine towards context-rich information. A glyph on each node in the diagram shows the confidence the algorithm has to continue the query expansion step. C (Output Hierarchy): The final output is a hierarchical structure of questions and answers ranked by relevance, reflecting a balanced mix of global insights and local refinements, making the results adaptable and comprehensive." width="450" align="center" />
|
||||
</p>
|
||||
<p align="center">
|
||||
|
||||
DRIFT Search introduces a new approach to local search queries by including community information in the search process. This greatly expands the breadth of the query’s starting point and leads to retrieval and usage of a far higher variety of facts in the final answer. This addition expands the GraphRAG query engine by providing a more comprehensive option for local search, which uses community insights to refine a query into detailed follow-up questions.
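
The primer phase described in Figure 1 can be illustrated with a small, self-contained sketch. This is not the GraphRAG implementation: the `top_k_reports` helper, the toy two-dimensional embeddings, and the choice of plain cosine similarity are all assumptions made only to show how a query might be matched against the K most similar community reports.

```python
from __future__ import annotations

import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_k_reports(
    query_embedding: list[float],
    report_embeddings: dict[str, list[float]],
    k: int,
) -> list[str]:
    """Return the ids of the k reports most similar to the query (hypothetical helper)."""
    ranked = sorted(
        report_embeddings,
        key=lambda rid: cosine_similarity(query_embedding, report_embeddings[rid]),
        reverse=True,
    )
    return ranked[:k]


# Toy data: three "community reports" with made-up embeddings.
reports = {
    "report-a": [1.0, 0.0],
    "report-b": [0.7, 0.7],
    "report-c": [0.0, 1.0],
}
selected = top_k_reports([1.0, 0.1], reports, k=2)
print(selected)
```

In the real engine the selected reports would then seed the broad initial answer and the follow-up questions; the sketch stops at the selection step.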
|
||||
|
||||
## Configuration
|
||||
|
||||
Below are the key parameters of the [DRIFTSearch class](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/search.py):
|
||||
|
||||
- `llm`: OpenAI model object to be used for response generation
|
||||
- `context_builder`: [context builder](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/drift_context.py) object to be used for preparing context data from community reports and query information
|
||||
- `config`: the [DRIFT Config model](https://github.com/microsoft/graphrag/blob/main/graphrag/config/models/drift_config.py) that defines the DRIFT Search hyperparameters
|
||||
- `token_encoder`: token encoder for tracking the budget for the algorithm.
|
||||
- `query_state`: a state object, as defined in [Query State](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/state.py), that tracks the execution of a DRIFT Search instance together with its follow-ups and [DRIFT actions](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/action.py)
|
||||
|
||||
## How to Use
|
||||
|
||||
An example of a DRIFT search scenario can be found in the following [notebook](../examples_notebooks/drift_search.ipynb).
|
||||
|
||||
## Learn More
|
||||
|
||||
For a more in-depth look at the DRIFT search method, please refer to our [DRIFT Search blog post](https://www.microsoft.com/en-us/research/blog/introducing-drift-search-combining-global-and-local-search-methods-to-improve-quality-and-efficiency/).
|
||||
@ -4,5 +4,6 @@ For examples about running Query please refer to the following notebooks:
|
||||
|
||||
- [Global Search Notebook](../../examples_notebooks/global_search.ipynb)
|
||||
- [Local Search Notebook](../../examples_notebooks/local_search.ipynb)
|
||||
- [DRIFT Search Notebook](../../examples_notebooks/drift_search.ipynb)
|
||||
|
||||
The test dataset for these notebooks can be found in [dataset.zip](../../data/operation_dulce/dataset.zip){:download}.
|
||||
The test dataset for these notebooks can be found in [dataset.zip](../../data/operation_dulce/dataset.zip){:download}.
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
# Query Engine 🔎
|
||||
# Query Engine 🔎
|
||||
|
||||
The Query Engine is the retrieval module of the Graph RAG Library. It is one of the two main components of the Graph RAG library, the other being the Indexing Pipeline (see [Indexing Pipeline](../index/overview.md)).
|
||||
It is responsible for the following tasks:
|
||||
|
||||
- [Local Search](#local-search)
|
||||
- [Global Search](#global-search)
|
||||
- [DRIFT Search](#drift-search)
|
||||
- [Question Generation](#question-generation)
|
||||
|
||||
## Local Search
|
||||
@ -19,6 +20,12 @@ Global search method generates answers by searching over all AI-generated commun
|
||||
|
||||
More about this can be checked at the [Global Search](global_search.md) documentation.
|
||||
|
||||
## DRIFT Search
|
||||
|
||||
DRIFT Search introduces a new approach to local search queries by including community information in the search process. This greatly expands the breadth of the query’s starting point and leads to retrieval and usage of a far higher variety of facts in the final answer. This addition expands the GraphRAG query engine by providing a more comprehensive option for local search, which uses community insights to refine a query into detailed follow-up questions.
|
||||
|
||||
To learn more about DRIFT Search, please refer to the [DRIFT Search](drift_search.md) documentation.
|
||||
|
||||
## Question Generation
|
||||
|
||||
This functionality takes a list of user queries and generates the next candidate questions. This is useful for generating follow-up questions in a conversation or for generating a list of questions for the investigator to dive deeper into the dataset.
|
||||
|
||||
@ -10,6 +10,7 @@ Backwards compatibility is not guaranteed at this time.
|
||||
from graphrag.api.index import build_index
|
||||
from graphrag.api.prompt_tune import DocSelectionType, generate_indexing_prompts
|
||||
from graphrag.api.query import (
|
||||
drift_search,
|
||||
global_search,
|
||||
global_search_streaming,
|
||||
local_search,
|
||||
@ -24,6 +25,7 @@ __all__ = [ # noqa: RUF022
|
||||
"global_search_streaming",
|
||||
"local_search",
|
||||
"local_search_streaming",
|
||||
"drift_search",
|
||||
# prompt tuning API
|
||||
"DocSelectionType",
|
||||
"generate_indexing_prompts",
|
||||
|
||||
@ -26,17 +26,23 @@ from pydantic import validate_call
|
||||
|
||||
from graphrag.config import GraphRagConfig
|
||||
from graphrag.logging import PrintProgressReporter
|
||||
from graphrag.query.factories import get_global_search_engine, get_local_search_engine
|
||||
from graphrag.query.factories import (
|
||||
get_drift_search_engine,
|
||||
get_global_search_engine,
|
||||
get_local_search_engine,
|
||||
)
|
||||
from graphrag.query.indexer_adapters import (
|
||||
read_indexer_covariates,
|
||||
read_indexer_entities,
|
||||
read_indexer_relationships,
|
||||
read_indexer_report_embeddings,
|
||||
read_indexer_reports,
|
||||
read_indexer_text_units,
|
||||
)
|
||||
from graphrag.query.structured_search.base import SearchResult # noqa: TCH001
|
||||
from graphrag.utils.cli import redact
|
||||
from graphrag.vector_stores import VectorStoreFactory, VectorStoreType
|
||||
from graphrag.vector_stores.base import BaseVectorStore
|
||||
|
||||
reporter = PrintProgressReporter("")
|
||||
|
||||
@ -195,8 +201,9 @@ async def local_search(
|
||||
|
||||
reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore
|
||||
|
||||
description_embedding_store = _get_embedding_description_store(
|
||||
description_embedding_store = _get_embedding_store(
|
||||
config_args=vector_store_args, # type: ignore
|
||||
container_suffix="entity-description",
|
||||
)
|
||||
|
||||
_entities = read_indexer_entities(nodes, entities, community_level)
|
||||
@ -268,8 +275,9 @@ async def local_search_streaming(
|
||||
|
||||
reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore
|
||||
|
||||
description_embedding_store = _get_embedding_description_store(
|
||||
description_embedding_store = _get_embedding_store(
|
||||
config_args=vector_store_args, # type: ignore
|
||||
container_suffix="entity-description",
|
||||
)
|
||||
|
||||
_entities = read_indexer_entities(nodes, entities, community_level)
|
||||
@ -300,11 +308,100 @@ async def local_search_streaming(
|
||||
yield stream_chunk
|
||||
|
||||
|
||||
@validate_call(config={"arbitrary_types_allowed": True})
|
||||
async def drift_search(
|
||||
config: GraphRagConfig,
|
||||
nodes: pd.DataFrame,
|
||||
entities: pd.DataFrame,
|
||||
community_reports: pd.DataFrame,
|
||||
text_units: pd.DataFrame,
|
||||
relationships: pd.DataFrame,
|
||||
community_level: int,
|
||||
query: str,
|
||||
) -> tuple[
|
||||
str | dict[str, Any] | list[dict[str, Any]],
|
||||
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
|
||||
]:
|
||||
"""Perform a DRIFT search and return the context data and response.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
|
||||
- nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
|
||||
- entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
|
||||
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
|
||||
- text_units (pd.DataFrame): A DataFrame containing the final text units (from create_final_text_units.parquet)
|
||||
- relationships (pd.DataFrame): A DataFrame containing the final relationships (from create_final_relationships.parquet)
|
||||
- community_level (int): The community level to search at.
|
||||
- query (str): The user query to search for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
|
||||
Raises
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
config = _patch_vector_store(
|
||||
config, nodes, entities, community_level, with_reports=community_reports
|
||||
)
|
||||
|
||||
# TODO: update filepath of lancedb (if used) until the new config engine has been implemented
|
||||
# TODO: remove the type ignore annotations below once the new config engine has been refactored
|
||||
vector_store_type = config.embeddings.vector_store.get("type") # type: ignore
|
||||
vector_store_args = config.embeddings.vector_store
|
||||
if vector_store_type == VectorStoreType.LanceDB:
|
||||
db_uri = config.embeddings.vector_store["db_uri"] # type: ignore
|
||||
lancedb_dir = Path(config.root_dir).resolve() / db_uri
|
||||
vector_store_args["db_uri"] = str(lancedb_dir) # type: ignore
|
||||
|
||||
reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore
|
||||
|
||||
description_embedding_store = _get_embedding_store(
|
||||
config_args=vector_store_args, # type: ignore
|
||||
container_suffix="entity-description",
|
||||
)
|
||||
|
||||
full_content_embedding_store = _get_embedding_store(
|
||||
config_args=vector_store_args, # type: ignore
|
||||
container_suffix="community-full-content",
|
||||
)
|
||||
|
||||
_entities = read_indexer_entities(nodes, entities, community_level)
|
||||
_reports = read_indexer_reports(community_reports, nodes, community_level)
|
||||
read_indexer_report_embeddings(_reports, full_content_embedding_store)
|
||||
|
||||
search_engine = get_drift_search_engine(
|
||||
config=config,
|
||||
reports=_reports,
|
||||
text_units=read_indexer_text_units(text_units),
|
||||
entities=_entities,
|
||||
relationships=read_indexer_relationships(relationships),
|
||||
description_embedding_store=description_embedding_store, # type: ignore
|
||||
)
|
||||
|
||||
result: SearchResult = await search_engine.asearch(query=query)
|
||||
response = result.response
|
||||
context_data = _reformat_context_data(result.context_data) # type: ignore
|
||||
|
||||
# TODO: Map/reduce the response to a single string with a comprehensive answer including all follow-ups
|
||||
# For the time being, return highest scoring response (position 0) and context data
|
||||
match response:
|
||||
case dict():
|
||||
return response["nodes"][0]["answer"], context_data # type: ignore
|
||||
case str():
|
||||
return response, context_data
|
||||
case list():
|
||||
return response, context_data
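
The branches above return only the top-ranked answer when the engine produces a dictionary, and pass strings and lists through unchanged. A minimal stand-alone sketch of that selection follows. `pick_top_answer` is a hypothetical helper name, the sample payloads are invented, and `isinstance` checks stand in for the `match` statement so the snippet also runs on Python versions older than 3.10.

```python
from __future__ import annotations

from typing import Any


def pick_top_answer(response: str | dict[str, Any] | list[dict[str, Any]]):
    """Return the answer at position 0 for dict responses; pass str/list through."""
    if isinstance(response, dict):
        # Assumes the dict carries a ranked "nodes" list, as in the diff above.
        return response["nodes"][0]["answer"]
    if isinstance(response, (str, list)):
        return response
    # Any other type falls through, as it would in the match statement.
    return None


sample = {"nodes": [{"answer": "first"}, {"answer": "second"}]}
print(pick_top_answer(sample))
print(pick_top_answer("plain text"))
```

Note that the follow-up answers in the remaining nodes are dropped, which is what the TODO about map/reduce in the diff refers to.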
|
||||
|
||||
|
||||
def _patch_vector_store(
|
||||
config: GraphRagConfig,
|
||||
nodes: pd.DataFrame,
|
||||
entities: pd.DataFrame,
|
||||
community_level: int,
|
||||
with_reports: pd.DataFrame | None = None,
|
||||
) -> GraphRagConfig:
|
||||
# TODO: remove the following patch that checks for a vector_store prior to v1 release
|
||||
# TODO: this is a backwards compatibility patch that injects the default vector_store settings into the config if it is not present
|
||||
@ -315,7 +412,9 @@ def _patch_vector_store(
|
||||
# 3. create lancedb vector_store instance
|
||||
# 4. upload vector embeddings from the input dataframes to the vector_store
|
||||
if not config.embeddings.vector_store:
|
||||
from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
|
||||
from graphrag.query.input.loaders.dfs import (
|
||||
store_entity_semantic_embeddings,
|
||||
)
|
||||
from graphrag.vector_stores.lancedb import LanceDBVectorStore
|
||||
|
||||
config.embeddings.vector_store = {
|
||||
@ -337,21 +436,59 @@ def _patch_vector_store(
|
||||
store_entity_semantic_embeddings(
|
||||
entities=_entities, vectorstore=description_embedding_store
|
||||
)
|
||||
|
||||
if with_reports is not None:
|
||||
from graphrag.query.input.loaders.dfs import (
|
||||
store_reports_semantic_embeddings,
|
||||
)
|
||||
from graphrag.vector_stores.lancedb import LanceDBVectorStore
|
||||
|
||||
community_reports = with_reports
|
||||
collection_name = (
|
||||
config.embeddings.vector_store.get("container_name", "default")
|
||||
if config.embeddings.vector_store
|
||||
else "default"
|
||||
)
|
||||
# Store report embeddings
|
||||
_reports = read_indexer_reports(
|
||||
community_reports,
|
||||
nodes,
|
||||
community_level,
|
||||
content_embedding_col="full_content_embedding",
|
||||
config=config,
|
||||
)
|
||||
|
||||
full_content_embedding_store = LanceDBVectorStore(
|
||||
db_uri=config.embeddings.vector_store["db_uri"],
|
||||
collection_name=f"{collection_name}-community-full-content",
|
||||
overwrite=config.embeddings.vector_store["overwrite"],
|
||||
)
|
||||
full_content_embedding_store.connect(
|
||||
db_uri=config.embeddings.vector_store["db_uri"]
|
||||
)
|
||||
# dump embeddings from the reports list to the full_content_embedding_store
|
||||
store_reports_semantic_embeddings(
|
||||
reports=_reports, vectorstore=full_content_embedding_store
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _get_embedding_description_store(
|
||||
def _get_embedding_store(
|
||||
config_args: dict,
|
||||
):
|
||||
container_suffix: str,
|
||||
) -> BaseVectorStore:
|
||||
"""Get the embedding store for the given container suffix."""
|
||||
vector_store_type = config_args["type"]
|
||||
collection_name = f"{config_args['container_name']}-entity-description"
|
||||
description_embedding_store = VectorStoreFactory.get_vector_store(
|
||||
collection_name = (
|
||||
f"{config_args.get('container_name', 'default')}-{container_suffix}"
|
||||
)
|
||||
embedding_store = VectorStoreFactory.get_vector_store(
|
||||
vector_store_type=vector_store_type,
|
||||
kwargs={**config_args, "collection_name": collection_name},
|
||||
)
|
||||
description_embedding_store.connect(**config_args)
|
||||
return description_embedding_store
|
||||
embedding_store.connect(**config_args)
|
||||
return embedding_store
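
The refactored helper derives one collection name per embedding type from a shared container name. The naming rule can be sketched in isolation as below; `build_collection_name` is a hypothetical name used for illustration only, and no vector store is created or connected.

```python
def build_collection_name(config_args: dict, container_suffix: str) -> str:
    """Join the container name (or "default" when absent) with the suffix."""
    return f"{config_args.get('container_name', 'default')}-{container_suffix}"


# One container name, two suffixes, two distinct collections.
entity_collection = build_collection_name(
    {"container_name": "my-index"}, "entity-description"
)
report_collection = build_collection_name({}, "community-full-content")
print(entity_collection)
print(report_collection)
```

Using `.get(..., 'default')` instead of direct indexing means a configuration without `container_name` no longer raises a `KeyError`.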
|
||||
|
||||
|
||||
def _reformat_context_data(context_data: dict) -> dict:
|
||||
@ -374,7 +511,11 @@ def _reformat_context_data(context_data: dict) -> dict:
|
||||
"sources": [],
|
||||
}
|
||||
for key in context_data:
|
||||
records = context_data[key].to_dict(orient="records")
|
||||
records = (
|
||||
context_data[key].to_dict(orient="records")
|
||||
if context_data[key] is not None and not isinstance(context_data[key], dict)
|
||||
else context_data[key]
|
||||
)
|
||||
if len(records) < 1:
|
||||
continue
|
||||
final_format[key] = records
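
The new conditional lets context entries that are already dictionaries bypass `to_dict`. The sketch below shows the same idea without pandas. `FakeFrame` is a hypothetical stand-in exposing only `to_dict(orient=...)`, and `to_records` and `reformat` are illustrative names. Treating `None` as an empty list is an assumption of this sketch: in the conditional shown in the diff, a `None` value is passed on to `len(records)` unchanged. The sketch also starts from an empty result rather than the preset keys used by the real function.

```python
class FakeFrame:
    """Hypothetical stand-in for a DataFrame; supports only orient="records"."""

    def __init__(self, rows):
        self._rows = rows

    def to_dict(self, orient="records"):
        if orient != "records":
            raise ValueError("only orient='records' is supported in this sketch")
        return list(self._rows)


def to_records(value):
    """Normalize one context entry to something that supports len()."""
    if value is None:
        return []  # assumption of this sketch, see lead-in
    if isinstance(value, dict):
        return value  # already plain data, keep as-is
    return value.to_dict(orient="records")


def reformat(context_data: dict) -> dict:
    """Keep only the non-empty entries, converted to plain records."""
    final_format = {}
    for key, value in context_data.items():
        records = to_records(value)
        if len(records) < 1:
            continue
        final_format[key] = records
    return final_format


result = reformat(
    {
        "entities": FakeFrame([{"id": 1}]),
        "claims": None,
        "sources": {},
        "extra": {"a": 1},
    }
)
print(result)
```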
|
||||
|
||||
@ -19,7 +19,7 @@ from graphrag.prompt_tune.loader import MIN_CHUNK_SIZE
|
||||
from .index import index_cli
|
||||
from .initialize import initialize_project_at
|
||||
from .prompt_tune import prompt_tune
|
||||
from .query import run_global_search, run_local_search
|
||||
from .query import run_drift_search, run_global_search, run_local_search
|
||||
|
||||
INVALID_METHOD_ERROR = "Invalid method"
|
||||
|
||||
@ -34,6 +34,7 @@ class SearchType(Enum):
|
||||
|
||||
LOCAL = "local"
|
||||
GLOBAL = "global"
|
||||
DRIFT = "drift"
|
||||
|
||||
def __str__(self):
|
||||
"""Return the string representation of the enum value."""
|
||||
@ -293,5 +294,14 @@ def _query_cli(
|
||||
streaming=streaming,
|
||||
query=query,
|
||||
)
|
||||
case SearchType.DRIFT:
|
||||
run_drift_search(
|
||||
config_filepath=config,
|
||||
data_dir=data,
|
||||
root_dir=root,
|
||||
community_level=community_level,
|
||||
streaming=False, # Drift search does not support streaming (yet)
|
||||
query=query,
|
||||
)
|
||||
case _:
|
||||
raise ValueError(INVALID_METHOD_ERROR)
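
The dispatch above can be sketched in isolation to show how the DRIFT branch always disables streaming. `plan_query` is a hypothetical helper that only reports which runner would be called. It does not invoke any GraphRAG code, and plain `if` checks stand in for the `match` statement.

```python
from enum import Enum


class SearchType(Enum):
    """Search methods selectable from the CLI, mirroring the enum in the diff."""

    LOCAL = "local"
    GLOBAL = "global"
    DRIFT = "drift"

    def __str__(self):
        return self.value


def plan_query(method: SearchType, streaming: bool) -> dict:
    """Return the runner name and the streaming flag it would receive."""
    if method is SearchType.LOCAL:
        return {"runner": "run_local_search", "streaming": streaming}
    if method is SearchType.GLOBAL:
        return {"runner": "run_global_search", "streaming": streaming}
    if method is SearchType.DRIFT:
        # DRIFT does not support streaming yet, so the flag is forced off.
        return {"runner": "run_drift_search", "streaming": False}
    raise ValueError("Invalid method")


plan = plan_query(SearchType("drift"), streaming=True)
print(plan)
```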
|
||||
|
||||
@ -190,11 +190,72 @@ def run_local_search(
|
||||
return response, context_data
|
||||
|
||||
|
||||
def run_drift_search(
|
||||
config_filepath: Path | None,
|
||||
data_dir: Path | None,
|
||||
root_dir: Path,
|
||||
community_level: int,
|
||||
streaming: bool,
|
||||
query: str,
|
||||
):
|
||||
"""Perform a DRIFT search with a given query.
|
||||
|
||||
Loads index files required for DRIFT search and calls the Query API.
|
||||
"""
|
||||
root = root_dir.resolve()
|
||||
config = load_config(root, config_filepath)
|
||||
config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
|
||||
resolve_paths(config)
|
||||
|
||||
dataframe_dict = _resolve_parquet_files(
|
||||
root_dir=root_dir,
|
||||
config=config,
|
||||
parquet_list=[
|
||||
"create_final_nodes.parquet",
|
||||
"create_final_community_reports.parquet",
|
||||
"create_final_text_units.parquet",
|
||||
"create_final_relationships.parquet",
|
||||
"create_final_entities.parquet",
|
||||
],
|
||||
)
|
||||
final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
|
||||
final_community_reports: pd.DataFrame = dataframe_dict[
|
||||
"create_final_community_reports"
|
||||
]
|
||||
final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]
|
||||
final_relationships: pd.DataFrame = dataframe_dict["create_final_relationships"]
|
||||
final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
|
||||
|
||||
# call the Query API
|
||||
if streaming:
|
||||
error_msg = "Streaming is not supported yet for DRIFT search."
|
||||
raise NotImplementedError(error_msg)
|
||||
|
||||
# not streaming
|
||||
response, context_data = asyncio.run(
|
||||
api.drift_search(
|
||||
config=config,
|
||||
nodes=final_nodes,
|
||||
entities=final_entities,
|
||||
community_reports=final_community_reports,
|
||||
text_units=final_text_units,
|
||||
relationships=final_relationships,
|
||||
community_level=community_level,
|
||||
query=query,
|
||||
)
|
||||
)
|
||||
reporter.success(f"DRIFT Search Response:\n{response}")
|
||||
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
|
||||
# External users should use the API directly to get the response and context data.
|
||||
# TODO: Map/Reduce Drift Search answer to a single response
|
||||
return response, context_data
|
||||
|
||||
|
||||
def _resolve_parquet_files(
|
||||
root_dir: Path,
|
||||
config: GraphRagConfig,
|
||||
parquet_list: list[str],
|
||||
optional_list: list[str],
|
||||
optional_list: list[str] | None = None,
|
||||
) -> dict[str, pd.DataFrame]:
|
||||
"""Read parquet files to a dataframe dict."""
|
||||
dataframe_dict = {}
|
||||
@ -208,15 +269,16 @@ def _resolve_parquet_files(
|
||||
dataframe_dict[df_key] = df_value
|
||||
|
||||
# for optional parquet files, set the dict entry to None instead of erroring out if it does not exist
|
||||
for optional_file in optional_list:
|
||||
file_exists = asyncio.run(storage_obj.has(optional_file))
|
||||
df_key = optional_file.split(".")[0]
|
||||
if file_exists:
|
||||
df_value = asyncio.run(
|
||||
_load_table_from_storage(name=optional_file, storage=storage_obj)
|
||||
)
|
||||
dataframe_dict[df_key] = df_value
|
||||
else:
|
||||
dataframe_dict[df_key] = None
|
||||
if optional_list:
|
||||
for optional_file in optional_list:
|
||||
file_exists = asyncio.run(storage_obj.has(optional_file))
|
||||
df_key = optional_file.split(".")[0]
|
||||
if file_exists:
|
||||
df_value = asyncio.run(
|
||||
_load_table_from_storage(name=optional_file, storage=storage_obj)
|
||||
)
|
||||
dataframe_dict[df_key] = df_value
|
||||
else:
|
||||
dataframe_dict[df_key] = None
|
||||
|
||||
return dataframe_dict
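
Making `optional_list` default to `None` lets callers such as `run_drift_search` omit it entirely. The pattern can be sketched with a plain dictionary standing in for the storage backend. `resolve_tables` and the in-memory `storage` dict are hypothetical, and raising `KeyError` for a missing required file is an assumption of this sketch, since that part of the real function is not shown in the diff.

```python
from __future__ import annotations


def resolve_tables(
    storage: dict[str, object],
    required: list[str],
    optional: list[str] | None = None,
) -> dict[str, object]:
    """Map file names to loaded tables, keyed by the name before the first dot."""
    tables: dict[str, object] = {}
    for name in required:
        tables[name.split(".")[0]] = storage[name]  # KeyError if missing
    if optional:
        for name in optional:
            # Missing optional files become None instead of raising.
            tables[name.split(".")[0]] = storage.get(name)
    return tables


storage = {"create_final_nodes.parquet": "nodes-table"}
only_required = resolve_tables(storage, ["create_final_nodes.parquet"])
with_optional = resolve_tables(
    storage,
    ["create_final_nodes.parquet"],
    optional=["create_final_covariates.parquet"],
)
print(only_required)
print(with_optional)
```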
|
||||
|
||||
@ -22,4 +22,8 @@ all_embeddings: set[str] = {
|
||||
community_full_content_embedding,
|
||||
text_unit_text_embedding,
|
||||
}
|
||||
required_embeddings: set[str] = {entity_description_embedding}
|
||||
required_embeddings: set[str] = {
|
||||
entity_description_embedding,
|
||||
community_full_content_embedding,
|
||||
text_unit_text_embedding,
|
||||
}
|
||||
|
||||
@ -21,6 +21,10 @@ from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKe
|
||||
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
|
||||
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
|
||||
from graphrag.query.llm.oai.typing import OpenaiApiType
|
||||
from graphrag.query.structured_search.drift_search.drift_context import (
|
||||
DRIFTSearchContextBuilder,
|
||||
)
|
||||
from graphrag.query.structured_search.drift_search.search import DRIFTSearch
|
||||
from graphrag.query.structured_search.global_search.community_context import (
|
||||
GlobalCommunityContext,
|
||||
)
|
||||
@ -198,3 +202,31 @@ def get_global_search_engine(
|
||||
concurrent_coroutines=gs_config.concurrency,
|
||||
response_type=response_type,
|
||||
)
|
||||
|
||||
|
||||
def get_drift_search_engine(
|
||||
config: GraphRagConfig,
|
||||
reports: list[CommunityReport],
|
||||
text_units: list[TextUnit],
|
||||
entities: list[Entity],
|
||||
relationships: list[Relationship],
|
||||
description_embedding_store: BaseVectorStore,
|
||||
) -> DRIFTSearch:
|
||||
"""Create a DRIFT search engine based on data + configuration."""
|
||||
llm = get_llm(config)
|
||||
text_embedder = get_text_embedder(config)
|
||||
token_encoder = tiktoken.get_encoding(config.encoding_model)
|
||||
|
||||
return DRIFTSearch(
|
||||
llm=llm,
|
||||
context_builder=DRIFTSearchContextBuilder(
|
||||
chat_llm=llm,
|
||||
text_embedder=text_embedder,
|
||||
entities=entities,
|
||||
relationships=relationships,
|
||||
reports=reports,
|
||||
entity_text_embeddings=description_embedding_store,
|
||||
text_units=text_units,
|
||||
),
|
||||
token_encoder=token_encoder,
|
||||
)
|
||||
|
||||
@ -10,7 +10,9 @@ from typing import cast
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.model import CommunityReport, Covariate, Entity, Relationship, TextUnit
|
||||
from graphrag.query.factories import get_text_embedder
|
||||
from graphrag.query.input.loaders.dfs import (
|
||||
read_community_reports,
|
||||
read_covariates,
|
||||
@ -18,6 +20,8 @@ from graphrag.query.input.loaders.dfs import (
|
||||
read_relationships,
|
||||
read_text_units,
|
||||
)
|
||||
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
|
||||
from graphrag.vector_stores.base import BaseVectorStore
|
||||
|
||||
|
||||
def read_indexer_text_units(final_text_units: pd.DataFrame) -> list[TextUnit]:
|
||||
@ -63,7 +67,8 @@ def read_indexer_reports(
|
||||
final_community_reports: pd.DataFrame,
|
||||
final_nodes: pd.DataFrame,
|
||||
community_level: int,
|
||||
content_embedding_col: str | None = None,
|
||||
content_embedding_col: str = "full_content_embedding",
|
||||
config: GraphRagConfig | None = None,
|
||||
) -> list[CommunityReport]:
|
||||
"""Read in the Community Reports from the raw indexing outputs."""
|
||||
report_df = final_community_reports
|
||||
@ -78,6 +83,14 @@ def read_indexer_reports(
|
||||
|
||||
report_df = _filter_under_community_level(report_df, community_level)
|
||||
report_df = report_df.merge(filtered_community_df, on="community", how="inner")
|
||||
if config and (
|
||||
content_embedding_col not in report_df.columns
|
||||
or report_df.loc[:, content_embedding_col].isna().any()
|
||||
):
|
||||
embedder = get_text_embedder(config)
|
||||
report_df = embed_community_reports(
|
||||
report_df, embedder, embedding_col=content_embedding_col
|
||||
)
|
||||
|
||||
return read_community_reports(
|
||||
df=report_df,
|
||||
@ -88,6 +101,15 @@ def read_indexer_reports(
|
||||
)
|
||||
|
||||
|
||||
def read_indexer_report_embeddings(
|
||||
community_reports: list[CommunityReport],
|
||||
embeddings_store: BaseVectorStore,
|
||||
):
|
||||
"""Read each community report's full-content embedding from the embeddings store."""
|
||||
for report in community_reports:
|
||||
report.full_content_embedding = embeddings_store.search_by_id(report.id).vector
|
||||
|
||||
|
||||
def read_indexer_entities(
|
||||
final_nodes: pd.DataFrame,
|
||||
final_entities: pd.DataFrame,
|
||||
@ -133,6 +155,25 @@ def read_indexer_entities(
|
||||
)
|
||||
|
||||
|
||||
def embed_community_reports(
|
||||
reports_df: pd.DataFrame,
|
||||
embedder: OpenAIEmbedding,
|
||||
source_col: str = "full_content",
|
||||
embedding_col: str = "full_content_embedding",
|
||||
) -> pd.DataFrame:
|
||||
"""Embed a source column of the reports dataframe using the given embedder."""
|
||||
if source_col not in reports_df.columns:
|
||||
error_msg = f"Reports missing {source_col} column"
|
||||
raise ValueError(error_msg)
|
||||
|
||||
if embedding_col not in reports_df.columns:
|
||||
reports_df[embedding_col] = reports_df.loc[:, source_col].apply(
|
||||
lambda x: embedder.embed(x)
|
||||
)
|
||||
|
||||
return reports_df
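
As far as these hunks show, `read_indexer_reports` calls `embed_community_reports` both when the embedding column is absent and when it exists but contains missing values, while the helper itself only embeds when the column is absent. The following is a minimal sketch of filling only the missing rows. It uses plain dicts instead of a DataFrame so that it runs with the standard library alone, and `fill_missing_embeddings` and `fake_embed` are hypothetical names, not part of graphrag.

```python
from typing import Callable


def fill_missing_embeddings(
    rows: list[dict],
    embed: Callable[[str], list[float]],
    source_col: str = "full_content",
    embedding_col: str = "full_content_embedding",
) -> list[dict]:
    """Embed source_col for every row whose embedding_col is absent or None."""
    for row in rows:
        if source_col not in row:
            error_msg = f"Reports missing {source_col} column"
            raise ValueError(error_msg)
        # Existing embeddings are kept; only the gaps are computed.
        if row.get(embedding_col) is None:
            row[embedding_col] = embed(row[source_col])
    return rows


def fake_embed(text: str) -> list[float]:
    # Hypothetical stand-in for a real embedder: one number per text.
    return [float(len(text))]


rows = [
    {"full_content": "alpha", "full_content_embedding": [0.5]},
    {"full_content": "beta", "full_content_embedding": None},
    {"full_content": "gamma"},
]
filled = fill_missing_embeddings(rows, fake_embed)
```

The first row keeps its stored vector, and the other two are embedded from `full_content`.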
|
||||
|
||||
|
||||
def _filter_under_community_level(
|
||||
df: pd.DataFrame, community_level: int
|
||||
) -> pd.DataFrame:
|
||||
|
||||
@ -114,6 +114,28 @@ def store_entity_behavior_embeddings(
|
||||
return vectorstore
|
||||
|
||||
|
||||
def store_reports_semantic_embeddings(
|
||||
reports: list[CommunityReport],
|
||||
vectorstore: BaseVectorStore,
|
||||
) -> BaseVectorStore:
|
||||
"""Store community report semantic embeddings in a vectorstore."""
|
||||
documents = [
|
||||
VectorStoreDocument(
|
||||
id=report.id,
|
||||
text=report.full_content,
|
||||
vector=report.full_content_embedding,
|
||||
attributes=(
|
||||
{"title": report.title, **report.attributes}
|
||||
if report.attributes
|
||||
else {"title": report.title}
|
||||
),
|
||||
)
|
||||
for report in reports
|
||||
]
|
||||
vectorstore.load_documents(documents=documents)
|
||||
return vectorstore
|
||||
|
||||
|
||||
def read_relationships(
|
||||
df: pd.DataFrame,
|
||||
id_col: str = "id",
|
||||
|
||||
@ -114,7 +114,9 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
"""
|
||||
report_df = pd.DataFrame([asdict(report) for report in reports])
|
||||
missing_content_error = "Some reports are missing full content."
|
||||
missing_embedding_error = "Some reports are missing full content embeddings."
|
||||
missing_embedding_error = (
|
||||
"Some reports are missing full content embeddings. {missing} out of {total}"
|
||||
)
|
||||
|
||||
if (
|
||||
"full_content" not in report_df.columns
|
||||
@ -126,7 +128,12 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
"full_content_embedding" not in report_df.columns
|
||||
or report_df["full_content_embedding"].isna().sum() > 0
|
||||
):
|
||||
raise ValueError(missing_embedding_error)
|
||||
raise ValueError(
|
||||
missing_embedding_error.format(
|
||||
missing=report_df["full_content_embedding"].isna().sum(),
|
||||
total=len(report_df),
|
||||
)
|
||||
)
|
||||
return report_df
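
The check above now reports how many reports lack an embedding. A minimal sketch of the same counting and message formatting, using plain dicts rather than a DataFrame (an assumption made to keep it dependency-free), with the hypothetical helper names `count_missing_embeddings` and `check_embeddings`. Unlike the hunk, the sketch also counts a report whose key is absent altogether, so formatting the message cannot fail on a missing column.

```python
MISSING_EMBEDDING_ERROR = (
    "Some reports are missing full content embeddings. {missing} out of {total}"
)


def count_missing_embeddings(
    reports: list[dict], key: str = "full_content_embedding"
) -> int:
    """Count reports whose embedding is absent or None."""
    return sum(1 for report in reports if report.get(key) is None)


def check_embeddings(reports: list[dict]) -> None:
    """Raise ValueError with a missing/total count if any embedding is missing."""
    missing = count_missing_embeddings(reports)
    if missing > 0:
        raise ValueError(
            MISSING_EMBEDDING_ERROR.format(missing=missing, total=len(reports))
        )


reports = [
    {"full_content_embedding": [0.1, 0.2]},
    {"full_content_embedding": None},
    {},  # key absent altogether
]

try:
    check_embeddings(reports)
    message = ""
except ValueError as err:
    message = str(err)
```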
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -194,3 +194,13 @@ class AzureAISearch(BaseVectorStore):
|
||||
query_embedding=query_embedding, k=k
|
||||
)
|
||||
return []
|
||||
|
||||
def search_by_id(self, id: str) -> VectorStoreDocument:
|
||||
"""Search for a document by id."""
|
||||
response = self.db_connection.get_document(id)
|
||||
return VectorStoreDocument(
|
||||
id=response.get("id", ""),
|
||||
text=response.get("text", ""),
|
||||
vector=response.get("vector", []),
|
||||
attributes=(json.loads(response.get("attributes", "{}"))),
|
||||
)
|
||||
|
||||
@ -79,3 +79,7 @@ class BaseVectorStore(ABC):
|
||||
@abstractmethod
|
||||
def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
|
||||
"""Build a query filter to filter documents by id."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_id(self, id: str) -> VectorStoreDocument:
|
||||
"""Search for a document by id."""
|
||||
|
||||
@ -135,3 +135,19 @@ class LanceDBVectorStore(BaseVectorStore):
|
||||
if query_embedding:
|
||||
return self.similarity_search_by_vector(query_embedding, k)
|
||||
return []
|
||||
|
||||
def search_by_id(self, id: str) -> VectorStoreDocument:
|
||||
"""Search for a document by id."""
|
||||
doc = (
|
||||
self.document_collection.search()
|
||||
.where(f"id == '{id}'", prefilter=True)
|
||||
.to_list()
|
||||
)
|
||||
if doc:
|
||||
return VectorStoreDocument(
|
||||
id=doc[0]["id"],
|
||||
text=doc[0]["text"],
|
||||
vector=doc[0]["vector"],
|
||||
attributes=json.loads(doc[0]["attributes"]),
|
||||
)
|
||||
return VectorStoreDocument(id=id, text=None, vector=None)
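
The LanceDB variant above returns a placeholder document with `text=None` and `vector=None` when no row matches the id. A minimal in-memory sketch of that hit-or-placeholder contract follows. `Doc` and `InMemoryStore` are hypothetical stand-ins for `VectorStoreDocument` and a `BaseVectorStore` subclass, not graphrag classes.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


@dataclass
class Doc:
    """Hypothetical stand-in for VectorStoreDocument."""

    id: str
    text: str | None
    vector: list[float] | None
    attributes: dict[str, Any] = field(default_factory=dict)


class InMemoryStore:
    """Hypothetical dict-backed store illustrating search_by_id."""

    def __init__(self) -> None:
        self._docs: dict[str, Doc] = {}

    def load_documents(self, documents: list[Doc]) -> None:
        for doc in documents:
            self._docs[doc.id] = doc

    def search_by_id(self, id: str) -> Doc:
        # Return the stored document, or an empty placeholder instead of raising.
        found = self._docs.get(id)
        if found is not None:
            return found
        return Doc(id=id, text=None, vector=None)


store = InMemoryStore()
store.load_documents([Doc(id="r1", text="report one", vector=[0.1, 0.2])])
hit = store.search_by_id("r1")
miss = store.search_by_id("r2")
```

A caller that copies `.vector` from the result, as `read_indexer_report_embeddings` does, would presumably store `None` on a miss, which the embedding check in `DRIFTSearchContextBuilder` then reports as a missing embedding.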
|
||||
|
||||
@ -42,13 +42,15 @@ nav:
|
||||
- Manual Tuning: "prompt_tuning/manual_prompt_tuning.md"
|
||||
- Query:
|
||||
- Overview: "query/overview.md"
|
||||
- Local Search: "query/local_search.md"
|
||||
- Question Generation: "query/question_generation.md"
|
||||
- Global Search: "query/global_search.md"
|
||||
- Local Search: "query/local_search.md"
|
||||
- DRIFT Search: "query/drift_search.md"
|
||||
- Question Generation: "query/question_generation.md"
|
||||
- Notebooks:
|
||||
- Overview: "query/notebooks/overview.md"
|
||||
- Global Search: "examples_notebooks/global_search.ipynb"
|
||||
- Local Search: "examples_notebooks/local_search.ipynb"
|
||||
- DRIFT Search: "examples_notebooks/drift_search.ipynb"
|
||||
- Microsoft Research Blog: "blog_posts.md"
|
||||
- Extras:
|
||||
- CLI: "cli.md"
|
||||
|
||||
@ -197,8 +197,10 @@ class TestIndexer:
|
||||
|
||||
# check that the number of workflows matches the number of artifacts
|
||||
assert (
|
||||
len(artifact_files) == (expected_artifacts + 1)
|
||||
), f"Expected {len(expected_workflows) + 1} artifacts, found: {len(artifact_files)}"
|
||||
len(artifact_files) == (expected_artifacts + 3)
|
||||
), (
|
||||
f"Expected {expected_artifacts + 3} artifacts, found: {len(artifact_files)}"
|
||||
) # Embeddings add to the count
|
||||
|
||||
for artifact in artifact_files:
|
||||
if artifact.endswith(".parquet"):
|
||||
|
||||
@ -54,6 +54,11 @@ class MockBaseVectorStore(BaseVectorStore):
|
||||
def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
|
||||
return [document for document in self.documents if document.id in include_ids]
|
||||
|
||||
def search_by_id(self, id: str) -> VectorStoreDocument:
|
||||
result = self.documents[0]
|
||||
result.id = id
|
||||
return result
|
||||
|
||||
|
||||
class MockBaseTextEmbedding(BaseTextEmbedding):
|
||||
def embed(self, text: str, **kwargs: Any) -> list[float]:
|
||||
|
||||
@ -1,15 +1,36 @@
|
||||
# Config Breaking Changes
|
||||
|
||||
## New required Embeddings
|
||||
|
||||
### Change
|
||||
|
||||
- Added new required embeddings for `DRIFTSearch` and base RAG capabilities.
|
||||
|
||||
### Migration
|
||||
|
||||
- Run a new index, leveraging existing cache.
|
||||
|
||||
## Vector Store required by default
|
||||
|
||||
### Change
|
||||
|
||||
- Vector store is now required by default for all search methods.
|
||||
|
||||
### Migration
|
||||
|
||||
- Run the `graphrag init` command to generate a new `settings.yaml` file with the vector store configuration.
|
||||
- Run a new index, leveraging existing cache.
|
||||
|
||||
## Deprecate timestamp paths
|
||||
|
||||
### Change
|
||||
|
||||
- Remove support for timestamp paths (those that use `${timestamp}` directory nesting).
|
||||
- Use the same directory for storage output and reporting output.
|
||||
|
||||
### Migration
|
||||
|
||||
- Ensure output directories no longer use `${timestamp}` directory nesting.
|
||||
|
||||
**Using Environment Variables**
|
||||
|
||||
@ -33,4 +54,4 @@ reporting:
|
||||
base_dir: "output" # changed from "output/${timestamp}/reports"
|
||||
```
|
||||
|
||||
[Full docs on using JSON or YAML files for configuration](https://microsoft.github.io/graphrag/config/json_yaml/).
|
||||
|
||||