Reduce Drift Response and Streaming endpoint (#1624)

* Add basic wrappers for reduce in drift

* Add response_type parameter to run_drift_search and enhance reduce response functionality

* Add streaming endpoint

* Semver

* Spellcheck

* Ruff checks

* Count tokens on reduce

* Use a list comprehension and remove the llm_params map in favor of plain kwargs
Alonso Guevara 2025-01-15 14:23:25 -06:00 committed by GitHub
parent 4637270e9a
commit 3defab2ea4
16 changed files with 810 additions and 582 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Add Drift Reduce response and streaming endpoint"
}

View File

@ -13,6 +13,7 @@ from graphrag.api.query import (
basic_search,
basic_search_streaming,
drift_search,
drift_search_streaming,
global_search,
global_search_streaming,
local_search,
@ -29,6 +30,7 @@ __all__ = [ # noqa: RUF022
"local_search",
"local_search_streaming",
"drift_search",
"drift_search_streaming",
"basic_search",
"basic_search_streaming",
# prompt tuning API

View File

@ -348,6 +348,87 @@ async def local_search_streaming(
yield stream_chunk
@validate_call(config={"arbitrary_types_allowed": True})
async def drift_search_streaming(
config: GraphRagConfig,
nodes: pd.DataFrame,
entities: pd.DataFrame,
community_reports: pd.DataFrame,
text_units: pd.DataFrame,
relationships: pd.DataFrame,
community_level: int,
response_type: str,
query: str,
) -> AsyncGenerator:
"""Perform a DRIFT search and return the context data and response.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
- entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
- text_units (pd.DataFrame): A DataFrame containing the final text units (from create_final_text_units.parquet)
- relationships (pd.DataFrame): A DataFrame containing the final relationships (from create_final_relationships.parquet)
- community_level (int): The community level to search at.
- response_type (str): The type of response to return (e.g. "Multiple Paragraphs").
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
"""
vector_store_args = config.embeddings.vector_store
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
description_embedding_store = _get_embedding_store(
config_args=vector_store_args, # type: ignore
embedding_name=entity_description_embedding,
)
full_content_embedding_store = _get_embedding_store(
config_args=vector_store_args, # type: ignore
embedding_name=community_full_content_embedding,
)
entities_ = read_indexer_entities(nodes, entities, community_level)
reports = read_indexer_reports(community_reports, nodes, community_level)
read_indexer_report_embeddings(reports, full_content_embedding_store)
prompt = _load_search_prompt(config.root_dir, config.drift_search.prompt)
reduce_prompt = _load_search_prompt(
config.root_dir, config.drift_search.reduce_prompt
)
search_engine = get_drift_search_engine(
config=config,
reports=reports,
text_units=read_indexer_text_units(text_units),
entities=entities_,
relationships=read_indexer_relationships(relationships),
description_embedding_store=description_embedding_store, # type: ignore
local_system_prompt=prompt,
reduce_system_prompt=reduce_prompt,
response_type=response_type,
)
search_result = search_engine.astream_search(query=query)
# when streaming results, a context data object is returned as the first result
# and the query response is streamed in subsequent chunks
context_data = None
get_context_data = True
async for stream_chunk in search_result:
if get_context_data:
context_data = _reformat_context_data(stream_chunk) # type: ignore
yield context_data
get_context_data = False
else:
yield stream_chunk
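For downstream callers, here is a minimal consumption sketch for the new streaming endpoint. It assumes the standard parquet outputs and an already-loaded `GraphRagConfig`; the `ragtest` paths, community level, and query are illustrative.

```python
import pandas as pd

from graphrag.api import drift_search_streaming


async def demo(config) -> None:
    # Assumption: index outputs live under ragtest/output (adjust to your project).
    out = "ragtest/output"
    frames = {
        name: pd.read_parquet(f"{out}/{name}.parquet")
        for name in (
            "create_final_nodes",
            "create_final_entities",
            "create_final_community_reports",
            "create_final_text_units",
            "create_final_relationships",
        )
    }
    first = True
    async for chunk in drift_search_streaming(
        config=config,
        nodes=frames["create_final_nodes"],
        entities=frames["create_final_entities"],
        community_reports=frames["create_final_community_reports"],
        text_units=frames["create_final_text_units"],
        relationships=frames["create_final_relationships"],
        community_level=2,
        response_type="Multiple Paragraphs",
        query="What are the main themes?",
    ):
        if first:
            context_data = chunk  # first chunk is the context data, per the contract above
            first = False
        else:
            print(chunk, end="", flush=True)  # remaining chunks are answer tokens

# run with: asyncio.run(demo(my_config))
```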
@validate_call(config={"arbitrary_types_allowed": True})
async def drift_search(
config: GraphRagConfig,
@ -357,6 +438,7 @@ async def drift_search(
text_units: pd.DataFrame,
relationships: pd.DataFrame,
community_level: int,
response_type: str,
query: str,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
@ -400,6 +482,10 @@ async def drift_search(
reports = read_indexer_reports(community_reports, nodes, community_level)
read_indexer_report_embeddings(reports, full_content_embedding_store)
prompt = _load_search_prompt(config.root_dir, config.drift_search.prompt)
reduce_prompt = _load_search_prompt(
config.root_dir, config.drift_search.reduce_prompt
)
search_engine = get_drift_search_engine(
config=config,
reports=reports,
@ -408,21 +494,15 @@ async def drift_search(
relationships=read_indexer_relationships(relationships),
description_embedding_store=description_embedding_store, # type: ignore
local_system_prompt=prompt,
reduce_system_prompt=reduce_prompt,
response_type=response_type,
)
result: SearchResult = await search_engine.asearch(query=query)
response = result.response
context_data = _reformat_context_data(result.context_data) # type: ignore
# TODO: Map/reduce the response to a single string with a comprehensive answer including all follow-ups
# For the time being, return highest scoring response (position 0) and context data
match response:
case dict():
return response["nodes"][0]["answer"], context_data # type: ignore
case str():
return response, context_data
case list():
return response, context_data
return response, context_data
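The non-streaming endpoint now hands back the reduced answer as a single string instead of the raw node list. A companion sketch, reusing `config` and `frames` from the streaming example above:

```python
from graphrag.api import drift_search

# Inside an async context: returns (reduced_answer, context_data) in one await.
response, context_data = await drift_search(
    config=config,
    nodes=frames["create_final_nodes"],
    entities=frames["create_final_entities"],
    community_reports=frames["create_final_community_reports"],
    text_units=frames["create_final_text_units"],
    relationships=frames["create_final_relationships"],
    community_level=2,
    response_type="Multiple Paragraphs",
    query="What are the main themes?",
)
print(response)  # the single comprehensive answer produced by the reduce step
```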
@validate_call(config={"arbitrary_types_allowed": True})

View File

@ -14,7 +14,10 @@ from graphrag.prompts.index.community_report import (
from graphrag.prompts.index.entity_extraction import GRAPH_EXTRACTION_PROMPT
from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT
from graphrag.prompts.query.basic_search_system_prompt import BASIC_SEARCH_SYSTEM_PROMPT
from graphrag.prompts.query.drift_search_system_prompt import DRIFT_LOCAL_SYSTEM_PROMPT
from graphrag.prompts.query.drift_search_system_prompt import (
DRIFT_LOCAL_SYSTEM_PROMPT,
DRIFT_REDUCE_PROMPT,
)
from graphrag.prompts.query.global_search_knowledge_system_prompt import (
GENERAL_KNOWLEDGE_INSTRUCTION,
)
@ -57,6 +60,7 @@ def initialize_project_at(path: Path) -> None:
"claim_extraction": CLAIM_EXTRACTION_PROMPT,
"community_report": COMMUNITY_REPORT_PROMPT,
"drift_search_system_prompt": DRIFT_LOCAL_SYSTEM_PROMPT,
"drift_reduce_prompt": DRIFT_REDUCE_PROMPT,
"global_search_map_system_prompt": MAP_SYSTEM_PROMPT,
"global_search_reduce_system_prompt": REDUCE_SYSTEM_PROMPT,
"global_search_knowledge_system_prompt": GENERAL_KNOWLEDGE_INSTRUCTION,

View File

@ -460,7 +460,8 @@ def _query_cli(
data_dir=data,
root_dir=root,
community_level=community_level,
streaming=False, # Drift search does not support streaming (yet)
streaming=streaming,
response_type=response_type,
query=query,
)
case SearchType.BASIC:

View File

@ -202,6 +202,7 @@ def run_drift_search(
data_dir: Path | None,
root_dir: Path,
community_level: int,
response_type: str,
streaming: bool,
query: str,
):
@ -234,8 +235,33 @@ def run_drift_search(
# call the Query API
if streaming:
error_msg = "Streaming is not supported yet for DRIFT search."
raise NotImplementedError(error_msg)
async def run_streaming_search():
full_response = ""
context_data = None
get_context_data = True
async for stream_chunk in api.drift_search_streaming(
config=config,
nodes=final_nodes,
entities=final_entities,
community_reports=final_community_reports,
text_units=final_text_units,
relationships=final_relationships,
community_level=community_level,
response_type=response_type,
query=query,
):
if get_context_data:
context_data = stream_chunk
get_context_data = False
else:
full_response += stream_chunk
print(stream_chunk, end="") # noqa: T201
sys.stdout.flush() # flush output buffer to display text immediately
print() # noqa: T201
return full_response, context_data
return asyncio.run(run_streaming_search())
# not streaming
response, context_data = asyncio.run(
@ -247,6 +273,7 @@ def run_drift_search(
text_units=final_text_units,
relationships=final_relationships,
community_level=community_level,
response_type=response_type,
query=query,
)
)
@ -281,8 +308,6 @@ def run_basic_search(
)
final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]
print(streaming) # noqa: T201
# # call the Query API
if streaming:

View File

@ -589,6 +589,7 @@ def create_graphrag_config(
):
drift_search_model = DRIFTSearchConfig(
prompt=reader.str("prompt") or None,
reduce_prompt=reader.str("reduce_prompt") or None,
temperature=reader.float("llm_temperature")
or defs.DRIFT_SEARCH_LLM_TEMPERATURE,
top_p=reader.float("llm_top_p") or defs.DRIFT_SEARCH_LLM_TOP_P,
@ -597,6 +598,10 @@ def create_graphrag_config(
or defs.DRIFT_SEARCH_MAX_TOKENS,
data_max_tokens=reader.int("data_max_tokens")
or defs.DRIFT_SEARCH_DATA_MAX_TOKENS,
reduce_max_tokens=reader.int("reduce_max_tokens")
or defs.DRIFT_SEARCH_REDUCE_MAX_TOKENS,
reduce_temperature=reader.float("reduce_temperature")
or defs.DRIFT_SEARCH_REDUCE_LLM_TEMPERATURE,
concurrency=reader.int("concurrency") or defs.DRIFT_SEARCH_CONCURRENCY,
drift_k_followups=reader.int("drift_k_followups")
or defs.DRIFT_SEARCH_K_FOLLOW_UPS,

View File

@ -149,6 +149,9 @@ DRIFT_SEARCH_K_FOLLOW_UPS = 20
DRIFT_SEARCH_PRIMER_FOLDS = 5
DRIFT_SEARCH_PRIMER_MAX_TOKENS = 12_000
DRIFT_SEARCH_REDUCE_LLM_TEMPERATURE = 0
DRIFT_SEARCH_REDUCE_MAX_TOKENS = 2_000
DRIFT_LOCAL_SEARCH_TEXT_UNIT_PROP = 0.9
DRIFT_LOCAL_SEARCH_COMMUNITY_PROP = 0.1
DRIFT_LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES = 10

View File

@ -132,6 +132,7 @@ global_search:
drift_search:
prompt: "prompts/drift_search_system_prompt.txt"
reduce_prompt: "prompts/drift_search_reduce_prompt.txt"
basic_search:
prompt: "prompts/basic_search_system_prompt.txt"

View File

@ -14,6 +14,9 @@ class DRIFTSearchConfig(BaseModel):
prompt: str | None = Field(
description="The drift search prompt to use.", default=None
)
reduce_prompt: str | None = Field(
description="The drift search reduce prompt to use.", default=None
)
temperature: float = Field(
description="The temperature to use for token generation.",
default=defs.DRIFT_SEARCH_LLM_TEMPERATURE,
@ -35,6 +38,16 @@ class DRIFTSearchConfig(BaseModel):
default=defs.DRIFT_SEARCH_DATA_MAX_TOKENS,
)
reduce_max_tokens: int = Field(
description="The reduce llm maximum tokens response to produce.",
default=defs.DRIFT_SEARCH_REDUCE_MAX_TOKENS,
)
reduce_temperature: float = Field(
description="The temperature to use for token generation in reduce.",
default=defs.DRIFT_SEARCH_REDUCE_LLM_TEMPERATURE,
)
concurrency: int = Field(
description="The number of concurrent requests.",
default=defs.DRIFT_SEARCH_CONCURRENCY,
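For illustration, the new reduce knobs can also be set programmatically; a sketch, with example values (defaults come from graphrag.config.defaults):

```python
from graphrag.config.models.drift_search_config import DRIFTSearchConfig

# Override only the reduce-phase settings; everything else keeps its defaults.
cfg = DRIFTSearchConfig(
    reduce_prompt="prompts/drift_search_reduce_prompt.txt",
    reduce_max_tokens=1_500,  # default: DRIFT_SEARCH_REDUCE_MAX_TOKENS (2_000)
    reduce_temperature=0.0,   # default: DRIFT_SEARCH_REDUCE_LLM_TEMPERATURE (0)
)
```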

View File

@ -106,7 +106,7 @@ If you decide to use general knowledge, you should add a delimiter stating that
---Target response length and format---
Multiple paragraphs
{response_type}
---Goal---
@ -133,8 +133,6 @@ If you decide to use general knowledge, you should add a delimiter stating that
Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. Now answer the following query using the data above:
{query}
"""

View File

@ -173,7 +173,9 @@ def get_drift_search_engine(
entities: list[Entity],
relationships: list[Relationship],
description_embedding_store: BaseVectorStore,
response_type: str,
local_system_prompt: str | None = None,
reduce_system_prompt: str | None = None,
) -> DRIFTSearch:
"""Create a local search engine based on data + configuration."""
llm = get_llm(config)
@ -191,7 +193,9 @@ def get_drift_search_engine(
entity_text_embeddings=description_embedding_store,
text_units=text_units,
local_system_prompt=local_system_prompt,
reduce_system_prompt=reduce_system_prompt,
config=config.drift_search,
response_type=response_type,
),
token_encoder=token_encoder,
)

View File

@ -19,6 +19,7 @@ from graphrag.model.relationship import Relationship
from graphrag.model.text_unit import TextUnit
from graphrag.prompts.query.drift_search_system_prompt import (
DRIFT_LOCAL_SYSTEM_PROMPT,
DRIFT_REDUCE_PROMPT,
)
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.llm.base import BaseTextEmbedding
@ -51,6 +52,8 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
config: DRIFTSearchConfig | None = None,
local_system_prompt: str | None = None,
local_mixed_context: LocalSearchMixedContext | None = None,
reduce_system_prompt: str | None = None,
response_type: str | None = None,
):
"""Initialize the DRIFT search context builder with necessary components."""
self.config = config or DRIFTSearchConfig()
@ -58,6 +61,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
self.text_embedder = text_embedder
self.token_encoder = token_encoder
self.local_system_prompt = local_system_prompt or DRIFT_LOCAL_SYSTEM_PROMPT
self.reduce_system_prompt = reduce_system_prompt or DRIFT_REDUCE_PROMPT
self.entities = entities
self.entity_text_embeddings = entity_text_embeddings
@ -67,6 +71,8 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
self.covariates = covariates
self.embedding_vectorstore_key = embedding_vectorstore_key
self.response_type = response_type
self.local_mixed_context = (
local_mixed_context or self.init_local_context_builder()
)

View File

@ -15,6 +15,7 @@ from graphrag.config.models.drift_search_config import DRIFTSearchConfig
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.structured_search.base import BaseSearch, SearchResult
from graphrag.query.structured_search.drift_search.action import DriftAction
from graphrag.query.structured_search.drift_search.drift_context import (
@ -169,6 +170,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
self,
query: str,
conversation_history: Any = None,
reduce: bool = True,
**kwargs,
) -> SearchResult:
"""
@ -177,6 +179,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
Args:
query (str): The query to search for.
conversation_history (Any, optional): The conversation history, if any.
reduce (bool, optional): Whether to reduce the intermediate answers to a single comprehensive response.
Returns
-------
@ -248,8 +251,21 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
include_context=True
)
reduced_response = response_state
if reduce:
# Reduce response_state to a single comprehensive response
reduced_response = await self._reduce_response(
responses=response_state,
query=query,
llm_calls=llm_calls,
prompt_tokens=prompt_tokens,
output_tokens=output_tokens,
max_tokens=self.config.reduce_max_tokens,
temperature=self.config.reduce_temperature,
)
return SearchResult(
response=response_state,
response=reduced_response,
context_data=context_data,
context_text=context_text,
completion_time=t_elapsed,
@ -281,7 +297,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
error_msg = "Synchronous DRIFT is not implemented."
raise NotImplementedError(error_msg)
def astream_search(
async def astream_search(
self, query: str, conversation_history: ConversationHistory | None = None
) -> AsyncGenerator[str, None]:
"""
@ -290,10 +306,127 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
Args:
query (str): The query to search for.
conversation_history (ConversationHistory, optional): The conversation history.
Raises
------
NotImplementedError: Streaming DRIFT search is not implemented.
"""
error_msg = "Streaming DRIFT search is not implemented."
raise NotImplementedError(error_msg)
result = await self.asearch(
query=query, conversation_history=conversation_history, reduce=False
)
# yield the context data first so API consumers can split it from the answer tokens
yield result.context_data
if isinstance(result.response, list):
result.response = result.response[0]
async for resp in self._reduce_response_streaming(
responses=result.response,
query=query,
max_tokens=self.config.reduce_max_tokens,
temperature=self.config.reduce_temperature,
):
yield resp
async def _reduce_response(
self,
responses: str | dict[str, Any],
query: str,
llm_calls: dict[str, int],
prompt_tokens: dict[str, int],
output_tokens: dict[str, int],
**llm_kwargs,
) -> str:
"""Reduce the response to a single comprehensive response.
Parameters
----------
responses : str|dict[str, Any]
The responses to reduce.
query : str
The original query.
llm_kwargs : dict[str, Any]
Additional keyword arguments to pass to the LLM.
Returns
-------
str
The reduced response.
"""
reduce_responses = []
if isinstance(responses, str):
reduce_responses = [responses]
else:
reduce_responses = [
response["answer"]
for response in responses.get("nodes", [])
if response.get("answer")
]
search_prompt = self.context_builder.reduce_system_prompt.format(
context_data=reduce_responses,
response_type=self.context_builder.response_type,
)
search_messages = [
{"role": "system", "content": search_prompt},
{"role": "user", "content": query},
]
reduced_response = self.llm.generate(
messages=search_messages,
streaming=False,
callbacks=None,
**llm_kwargs,
)
llm_calls["reduce"] = 1
prompt_tokens["reduce"] = num_tokens(
search_prompt, self.token_encoder
) + num_tokens(query, self.token_encoder)
output_tokens["reduce"] = num_tokens(reduced_response, self.token_encoder)
return reduced_response
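For clarity, the dict form both reduce methods accept mirrors the DRIFT node output; the shape below is inferred from the comprehension above, and the sample values are illustrative.

```python
# Assumed input shape for _reduce_response / _reduce_response_streaming.
responses = {
    "nodes": [
        {"answer": "Partial answer from the primer query ..."},
        {"answer": "Partial answer from follow-up 1 ..."},
        {"answer": None},  # nodes without an answer are filtered out
    ]
}
# The comprehension flattens this to:
# ["Partial answer from the primer query ...", "Partial answer from follow-up 1 ..."]
```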
async def _reduce_response_streaming(
self,
responses: str | dict[str, Any],
query: str,
**llm_kwargs,
) -> AsyncGenerator[str, None]:
"""Reduce the response to a single comprehensive response.
Parameters
----------
responses : str|dict[str, Any]
The responses to reduce.
query : str
The original query.
llm_kwargs : dict[str, Any]
Additional keyword arguments to pass to the LLM.
Returns
-------
str
The reduced response.
"""
reduce_responses = []
if isinstance(responses, str):
reduce_responses = [responses]
else:
reduce_responses = [
response["answer"]
for response in responses.get("nodes", [])
if response.get("answer")
]
search_prompt = self.context_builder.reduce_system_prompt.format(
context_data=reduce_responses,
response_type=self.context_builder.response_type,
)
search_messages = [
{"role": "system", "content": search_prompt},
{"role": "user", "content": query},
]
async for resp in self.llm.astream_generate(
search_messages,
callbacks=None,
**llm_kwargs,
):
yield resp

poetry.lock (generated): 1059 changes; file diff suppressed because it is too large.

View File

@ -65,13 +65,11 @@ numpy = "^1.25.2"
graspologic = "^3.4.1"
networkx = "^3.4.2"
pandas = "^2.2.3"
matplotlib = "^3.9.3"
pyarrow = "^15.0.0"
umap-learn = "^0.5.6"
# Configuration
pyyaml = "^6.0.2"
pyaml-env = "^1.2.1"
python-dotenv = "^1.0.1"
pydantic = "^2.10.3"