mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 00:57:23 +08:00
Reduce Drift Response and Streaming endpoint (#1624)
* Adding basic wrappes for reduce in drift * Add response_type parameter to run_drift_search and enhance reduce response functionality * Add streaming endpoint * Semver * Spellcheck * Ruff checks * Count tokens on reduce * Use list comprehension and remove llm_params map in favor of just using kwargs
This commit is contained in:
parent
4637270e9a
commit
3defab2ea4
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "Add Drift Reduce response and streaming endpoint"
|
||||
}
|
||||
@ -13,6 +13,7 @@ from graphrag.api.query import (
|
||||
basic_search,
|
||||
basic_search_streaming,
|
||||
drift_search,
|
||||
drift_search_streaming,
|
||||
global_search,
|
||||
global_search_streaming,
|
||||
local_search,
|
||||
@ -29,6 +30,7 @@ __all__ = [ # noqa: RUF022
|
||||
"local_search",
|
||||
"local_search_streaming",
|
||||
"drift_search",
|
||||
"drift_search_streaming",
|
||||
"basic_search",
|
||||
"basic_search_streaming",
|
||||
# prompt tuning API
|
||||
|
||||
@ -348,6 +348,87 @@ async def local_search_streaming(
|
||||
yield stream_chunk
|
||||
|
||||
|
||||
@validate_call(config={"arbitrary_types_allowed": True})
|
||||
async def drift_search_streaming(
|
||||
config: GraphRagConfig,
|
||||
nodes: pd.DataFrame,
|
||||
entities: pd.DataFrame,
|
||||
community_reports: pd.DataFrame,
|
||||
text_units: pd.DataFrame,
|
||||
relationships: pd.DataFrame,
|
||||
community_level: int,
|
||||
response_type: str,
|
||||
query: str,
|
||||
) -> AsyncGenerator:
|
||||
"""Perform a DRIFT search and return the context data and response.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
|
||||
- nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
|
||||
- entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
|
||||
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
|
||||
- text_units (pd.DataFrame): A DataFrame containing the final text units (from create_final_text_units.parquet)
|
||||
- relationships (pd.DataFrame): A DataFrame containing the final relationships (from create_final_relationships.parquet)
|
||||
- community_level (int): The community level to search at.
|
||||
- query (str): The user query to search for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
|
||||
Raises
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
vector_store_args = config.embeddings.vector_store
|
||||
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
|
||||
|
||||
description_embedding_store = _get_embedding_store(
|
||||
config_args=vector_store_args, # type: ignore
|
||||
embedding_name=entity_description_embedding,
|
||||
)
|
||||
|
||||
full_content_embedding_store = _get_embedding_store(
|
||||
config_args=vector_store_args, # type: ignore
|
||||
embedding_name=community_full_content_embedding,
|
||||
)
|
||||
|
||||
entities_ = read_indexer_entities(nodes, entities, community_level)
|
||||
reports = read_indexer_reports(community_reports, nodes, community_level)
|
||||
read_indexer_report_embeddings(reports, full_content_embedding_store)
|
||||
prompt = _load_search_prompt(config.root_dir, config.drift_search.prompt)
|
||||
reduce_prompt = _load_search_prompt(
|
||||
config.root_dir, config.drift_search.reduce_prompt
|
||||
)
|
||||
|
||||
search_engine = get_drift_search_engine(
|
||||
config=config,
|
||||
reports=reports,
|
||||
text_units=read_indexer_text_units(text_units),
|
||||
entities=entities_,
|
||||
relationships=read_indexer_relationships(relationships),
|
||||
description_embedding_store=description_embedding_store, # type: ignore
|
||||
local_system_prompt=prompt,
|
||||
reduce_system_prompt=reduce_prompt,
|
||||
response_type=response_type,
|
||||
)
|
||||
|
||||
search_result = search_engine.astream_search(query=query)
|
||||
|
||||
# when streaming results, a context data object is returned as the first result
|
||||
# and the query response in subsequent tokens
|
||||
context_data = None
|
||||
get_context_data = True
|
||||
async for stream_chunk in search_result:
|
||||
if get_context_data:
|
||||
context_data = _reformat_context_data(stream_chunk) # type: ignore
|
||||
yield context_data
|
||||
get_context_data = False
|
||||
else:
|
||||
yield stream_chunk
|
||||
|
||||
|
||||
@validate_call(config={"arbitrary_types_allowed": True})
|
||||
async def drift_search(
|
||||
config: GraphRagConfig,
|
||||
@ -357,6 +438,7 @@ async def drift_search(
|
||||
text_units: pd.DataFrame,
|
||||
relationships: pd.DataFrame,
|
||||
community_level: int,
|
||||
response_type: str,
|
||||
query: str,
|
||||
) -> tuple[
|
||||
str | dict[str, Any] | list[dict[str, Any]],
|
||||
@ -400,6 +482,10 @@ async def drift_search(
|
||||
reports = read_indexer_reports(community_reports, nodes, community_level)
|
||||
read_indexer_report_embeddings(reports, full_content_embedding_store)
|
||||
prompt = _load_search_prompt(config.root_dir, config.drift_search.prompt)
|
||||
reduce_prompt = _load_search_prompt(
|
||||
config.root_dir, config.drift_search.reduce_prompt
|
||||
)
|
||||
|
||||
search_engine = get_drift_search_engine(
|
||||
config=config,
|
||||
reports=reports,
|
||||
@ -408,21 +494,15 @@ async def drift_search(
|
||||
relationships=read_indexer_relationships(relationships),
|
||||
description_embedding_store=description_embedding_store, # type: ignore
|
||||
local_system_prompt=prompt,
|
||||
reduce_system_prompt=reduce_prompt,
|
||||
response_type=response_type,
|
||||
)
|
||||
|
||||
result: SearchResult = await search_engine.asearch(query=query)
|
||||
response = result.response
|
||||
context_data = _reformat_context_data(result.context_data) # type: ignore
|
||||
|
||||
# TODO: Map/reduce the response to a single string with a comprehensive answer including all follow-ups
|
||||
# For the time being, return highest scoring response (position 0) and context data
|
||||
match response:
|
||||
case dict():
|
||||
return response["nodes"][0]["answer"], context_data # type: ignore
|
||||
case str():
|
||||
return response, context_data
|
||||
case list():
|
||||
return response, context_data
|
||||
return response, context_data
|
||||
|
||||
|
||||
@validate_call(config={"arbitrary_types_allowed": True})
|
||||
|
||||
@ -14,7 +14,10 @@ from graphrag.prompts.index.community_report import (
|
||||
from graphrag.prompts.index.entity_extraction import GRAPH_EXTRACTION_PROMPT
|
||||
from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT
|
||||
from graphrag.prompts.query.basic_search_system_prompt import BASIC_SEARCH_SYSTEM_PROMPT
|
||||
from graphrag.prompts.query.drift_search_system_prompt import DRIFT_LOCAL_SYSTEM_PROMPT
|
||||
from graphrag.prompts.query.drift_search_system_prompt import (
|
||||
DRIFT_LOCAL_SYSTEM_PROMPT,
|
||||
DRIFT_REDUCE_PROMPT,
|
||||
)
|
||||
from graphrag.prompts.query.global_search_knowledge_system_prompt import (
|
||||
GENERAL_KNOWLEDGE_INSTRUCTION,
|
||||
)
|
||||
@ -57,6 +60,7 @@ def initialize_project_at(path: Path) -> None:
|
||||
"claim_extraction": CLAIM_EXTRACTION_PROMPT,
|
||||
"community_report": COMMUNITY_REPORT_PROMPT,
|
||||
"drift_search_system_prompt": DRIFT_LOCAL_SYSTEM_PROMPT,
|
||||
"drift_reduce_prompt": DRIFT_REDUCE_PROMPT,
|
||||
"global_search_map_system_prompt": MAP_SYSTEM_PROMPT,
|
||||
"global_search_reduce_system_prompt": REDUCE_SYSTEM_PROMPT,
|
||||
"global_search_knowledge_system_prompt": GENERAL_KNOWLEDGE_INSTRUCTION,
|
||||
|
||||
@ -460,7 +460,8 @@ def _query_cli(
|
||||
data_dir=data,
|
||||
root_dir=root,
|
||||
community_level=community_level,
|
||||
streaming=False, # Drift search does not support streaming (yet)
|
||||
streaming=streaming,
|
||||
response_type=response_type,
|
||||
query=query,
|
||||
)
|
||||
case SearchType.BASIC:
|
||||
|
||||
@ -202,6 +202,7 @@ def run_drift_search(
|
||||
data_dir: Path | None,
|
||||
root_dir: Path,
|
||||
community_level: int,
|
||||
response_type: str,
|
||||
streaming: bool,
|
||||
query: str,
|
||||
):
|
||||
@ -234,8 +235,33 @@ def run_drift_search(
|
||||
|
||||
# call the Query API
|
||||
if streaming:
|
||||
error_msg = "Streaming is not supported yet for DRIFT search."
|
||||
raise NotImplementedError(error_msg)
|
||||
|
||||
async def run_streaming_search():
|
||||
full_response = ""
|
||||
context_data = None
|
||||
get_context_data = True
|
||||
async for stream_chunk in api.drift_search_streaming(
|
||||
config=config,
|
||||
nodes=final_nodes,
|
||||
entities=final_entities,
|
||||
community_reports=final_community_reports,
|
||||
text_units=final_text_units,
|
||||
relationships=final_relationships,
|
||||
community_level=community_level,
|
||||
response_type=response_type,
|
||||
query=query,
|
||||
):
|
||||
if get_context_data:
|
||||
context_data = stream_chunk
|
||||
get_context_data = False
|
||||
else:
|
||||
full_response += stream_chunk
|
||||
print(stream_chunk, end="") # noqa: T201
|
||||
sys.stdout.flush() # flush output buffer to display text immediately
|
||||
print() # noqa: T201
|
||||
return full_response, context_data
|
||||
|
||||
return asyncio.run(run_streaming_search())
|
||||
|
||||
# not streaming
|
||||
response, context_data = asyncio.run(
|
||||
@ -247,6 +273,7 @@ def run_drift_search(
|
||||
text_units=final_text_units,
|
||||
relationships=final_relationships,
|
||||
community_level=community_level,
|
||||
response_type=response_type,
|
||||
query=query,
|
||||
)
|
||||
)
|
||||
@ -281,8 +308,6 @@ def run_basic_search(
|
||||
)
|
||||
final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]
|
||||
|
||||
print(streaming) # noqa: T201
|
||||
|
||||
# # call the Query API
|
||||
if streaming:
|
||||
|
||||
|
||||
@ -589,6 +589,7 @@ def create_graphrag_config(
|
||||
):
|
||||
drift_search_model = DRIFTSearchConfig(
|
||||
prompt=reader.str("prompt") or None,
|
||||
reduce_prompt=reader.str("reduce_prompt") or None,
|
||||
temperature=reader.float("llm_temperature")
|
||||
or defs.DRIFT_SEARCH_LLM_TEMPERATURE,
|
||||
top_p=reader.float("llm_top_p") or defs.DRIFT_SEARCH_LLM_TOP_P,
|
||||
@ -597,6 +598,10 @@ def create_graphrag_config(
|
||||
or defs.DRIFT_SEARCH_MAX_TOKENS,
|
||||
data_max_tokens=reader.int("data_max_tokens")
|
||||
or defs.DRIFT_SEARCH_DATA_MAX_TOKENS,
|
||||
reduce_max_tokens=reader.int("reduce_max_tokens")
|
||||
or defs.DRIFT_SEARCH_REDUCE_MAX_TOKENS,
|
||||
reduce_temperature=reader.float("reduce_temperature")
|
||||
or defs.DRIFT_SEARCH_REDUCE_LLM_TEMPERATURE,
|
||||
concurrency=reader.int("concurrency") or defs.DRIFT_SEARCH_CONCURRENCY,
|
||||
drift_k_followups=reader.int("drift_k_followups")
|
||||
or defs.DRIFT_SEARCH_K_FOLLOW_UPS,
|
||||
|
||||
@ -149,6 +149,9 @@ DRIFT_SEARCH_K_FOLLOW_UPS = 20
|
||||
DRIFT_SEARCH_PRIMER_FOLDS = 5
|
||||
DRIFT_SEARCH_PRIMER_MAX_TOKENS = 12_000
|
||||
|
||||
DRIFT_SEARCH_REDUCE_LLM_TEMPERATURE = 0
|
||||
DRIFT_SEARCH_REDUCE_MAX_TOKENS = 2_000
|
||||
|
||||
DRIFT_LOCAL_SEARCH_TEXT_UNIT_PROP = 0.9
|
||||
DRIFT_LOCAL_SEARCH_COMMUNITY_PROP = 0.1
|
||||
DRIFT_LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES = 10
|
||||
|
||||
@ -132,6 +132,7 @@ global_search:
|
||||
|
||||
drift_search:
|
||||
prompt: "prompts/drift_search_system_prompt.txt"
|
||||
reduce_prompt: "prompts/drift_search_reduce_prompt.txt"
|
||||
|
||||
basic_search:
|
||||
prompt: "prompts/basic_search_system_prompt.txt"
|
||||
|
||||
@ -14,6 +14,9 @@ class DRIFTSearchConfig(BaseModel):
|
||||
prompt: str | None = Field(
|
||||
description="The drift search prompt to use.", default=None
|
||||
)
|
||||
reduce_prompt: str | None = Field(
|
||||
description="The drift search reduce prompt to use.", default=None
|
||||
)
|
||||
temperature: float = Field(
|
||||
description="The temperature to use for token generation.",
|
||||
default=defs.DRIFT_SEARCH_LLM_TEMPERATURE,
|
||||
@ -35,6 +38,16 @@ class DRIFTSearchConfig(BaseModel):
|
||||
default=defs.DRIFT_SEARCH_DATA_MAX_TOKENS,
|
||||
)
|
||||
|
||||
reduce_max_tokens: int = Field(
|
||||
description="The reduce llm maximum tokens response to produce.",
|
||||
default=defs.DRIFT_SEARCH_REDUCE_MAX_TOKENS,
|
||||
)
|
||||
|
||||
reduce_temperature: float = Field(
|
||||
description="The temperature to use for token generation in reduce.",
|
||||
default=defs.DRIFT_SEARCH_REDUCE_LLM_TEMPERATURE,
|
||||
)
|
||||
|
||||
concurrency: int = Field(
|
||||
description="The number of concurrent requests.",
|
||||
default=defs.DRIFT_SEARCH_CONCURRENCY,
|
||||
|
||||
@ -106,7 +106,7 @@ If you decide to use general knowledge, you should add a delimiter stating that
|
||||
|
||||
---Target response length and format---
|
||||
|
||||
Multiple paragraphs
|
||||
{response_type}
|
||||
|
||||
|
||||
---Goal---
|
||||
@ -133,8 +133,6 @@ If you decide to use general knowledge, you should add a delimiter stating that
|
||||
|
||||
Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. Now answer the following query using the data above:
|
||||
|
||||
{query}
|
||||
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@ -173,7 +173,9 @@ def get_drift_search_engine(
|
||||
entities: list[Entity],
|
||||
relationships: list[Relationship],
|
||||
description_embedding_store: BaseVectorStore,
|
||||
response_type: str,
|
||||
local_system_prompt: str | None = None,
|
||||
reduce_system_prompt: str | None = None,
|
||||
) -> DRIFTSearch:
|
||||
"""Create a local search engine based on data + configuration."""
|
||||
llm = get_llm(config)
|
||||
@ -191,7 +193,9 @@ def get_drift_search_engine(
|
||||
entity_text_embeddings=description_embedding_store,
|
||||
text_units=text_units,
|
||||
local_system_prompt=local_system_prompt,
|
||||
reduce_system_prompt=reduce_system_prompt,
|
||||
config=config.drift_search,
|
||||
response_type=response_type,
|
||||
),
|
||||
token_encoder=token_encoder,
|
||||
)
|
||||
|
||||
@ -19,6 +19,7 @@ from graphrag.model.relationship import Relationship
|
||||
from graphrag.model.text_unit import TextUnit
|
||||
from graphrag.prompts.query.drift_search_system_prompt import (
|
||||
DRIFT_LOCAL_SYSTEM_PROMPT,
|
||||
DRIFT_REDUCE_PROMPT,
|
||||
)
|
||||
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
|
||||
from graphrag.query.llm.base import BaseTextEmbedding
|
||||
@ -51,6 +52,8 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
config: DRIFTSearchConfig | None = None,
|
||||
local_system_prompt: str | None = None,
|
||||
local_mixed_context: LocalSearchMixedContext | None = None,
|
||||
reduce_system_prompt: str | None = None,
|
||||
response_type: str | None = None,
|
||||
):
|
||||
"""Initialize the DRIFT search context builder with necessary components."""
|
||||
self.config = config or DRIFTSearchConfig()
|
||||
@ -58,6 +61,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
self.text_embedder = text_embedder
|
||||
self.token_encoder = token_encoder
|
||||
self.local_system_prompt = local_system_prompt or DRIFT_LOCAL_SYSTEM_PROMPT
|
||||
self.reduce_system_prompt = reduce_system_prompt or DRIFT_REDUCE_PROMPT
|
||||
|
||||
self.entities = entities
|
||||
self.entity_text_embeddings = entity_text_embeddings
|
||||
@ -67,6 +71,8 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
self.covariates = covariates
|
||||
self.embedding_vectorstore_key = embedding_vectorstore_key
|
||||
|
||||
self.response_type = response_type
|
||||
|
||||
self.local_mixed_context = (
|
||||
local_mixed_context or self.init_local_context_builder()
|
||||
)
|
||||
|
||||
@ -15,6 +15,7 @@ from graphrag.config.models.drift_search_config import DRIFTSearchConfig
|
||||
from graphrag.query.context_builder.conversation_history import ConversationHistory
|
||||
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
|
||||
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.query.structured_search.base import BaseSearch, SearchResult
|
||||
from graphrag.query.structured_search.drift_search.action import DriftAction
|
||||
from graphrag.query.structured_search.drift_search.drift_context import (
|
||||
@ -169,6 +170,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
self,
|
||||
query: str,
|
||||
conversation_history: Any = None,
|
||||
reduce: bool = True,
|
||||
**kwargs,
|
||||
) -> SearchResult:
|
||||
"""
|
||||
@ -177,6 +179,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
Args:
|
||||
query (str): The query to search for.
|
||||
conversation_history (Any, optional): The conversation history, if any.
|
||||
reduce (bool, optional): Whether to reduce the response to a single comprehensive response.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@ -248,8 +251,21 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
include_context=True
|
||||
)
|
||||
|
||||
reduced_response = response_state
|
||||
if reduce:
|
||||
# Reduce response_state to a single comprehensive response
|
||||
reduced_response = await self._reduce_response(
|
||||
responses=response_state,
|
||||
query=query,
|
||||
llm_calls=llm_calls,
|
||||
prompt_tokens=prompt_tokens,
|
||||
output_tokens=output_tokens,
|
||||
max_tokens=self.config.reduce_max_tokens,
|
||||
temperature=self.config.reduce_temperature,
|
||||
)
|
||||
|
||||
return SearchResult(
|
||||
response=response_state,
|
||||
response=reduced_response,
|
||||
context_data=context_data,
|
||||
context_text=context_text,
|
||||
completion_time=t_elapsed,
|
||||
@ -281,7 +297,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
error_msg = "Synchronous DRIFT is not implemented."
|
||||
raise NotImplementedError(error_msg)
|
||||
|
||||
def astream_search(
|
||||
async def astream_search(
|
||||
self, query: str, conversation_history: ConversationHistory | None = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
@ -290,10 +306,127 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
Args:
|
||||
query (str): The query to search for.
|
||||
conversation_history (ConversationHistory, optional): The conversation history.
|
||||
|
||||
Raises
|
||||
------
|
||||
NotImplementedError: Streaming DRIFT search is not implemented.
|
||||
"""
|
||||
error_msg = "Streaming DRIFT search is not implemented."
|
||||
raise NotImplementedError(error_msg)
|
||||
result = await self.asearch(
|
||||
query=query, conversation_history=conversation_history, reduce=False
|
||||
)
|
||||
|
||||
if isinstance(result.response, list):
|
||||
result.response = result.response[0]
|
||||
|
||||
async for resp in self._reduce_response_streaming(
|
||||
responses=result.response,
|
||||
query=query,
|
||||
max_tokens=self.config.reduce_max_tokens,
|
||||
temperature=self.config.reduce_temperature,
|
||||
):
|
||||
yield resp
|
||||
|
||||
async def _reduce_response(
|
||||
self,
|
||||
responses: str | dict[str, Any],
|
||||
query: str,
|
||||
llm_calls: dict[str, int],
|
||||
prompt_tokens: dict[str, int],
|
||||
output_tokens: dict[str, int],
|
||||
**llm_kwargs,
|
||||
) -> str:
|
||||
"""Reduce the response to a single comprehensive response.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
responses : str|dict[str, Any]
|
||||
The responses to reduce.
|
||||
query : str
|
||||
The original query.
|
||||
llm_kwargs : dict[str, Any]
|
||||
Additional keyword arguments to pass to the LLM.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The reduced response.
|
||||
"""
|
||||
reduce_responses = []
|
||||
|
||||
if isinstance(responses, str):
|
||||
reduce_responses = [responses]
|
||||
else:
|
||||
reduce_responses = [
|
||||
response["answer"]
|
||||
for response in responses.get("nodes", [])
|
||||
if response.get("answer")
|
||||
]
|
||||
|
||||
search_prompt = self.context_builder.reduce_system_prompt.format(
|
||||
context_data=reduce_responses,
|
||||
response_type=self.context_builder.response_type,
|
||||
)
|
||||
search_messages = [
|
||||
{"role": "system", "content": search_prompt},
|
||||
{"role": "user", "content": query},
|
||||
]
|
||||
|
||||
reduced_response = self.llm.generate(
|
||||
messages=search_messages,
|
||||
streaming=False,
|
||||
callbacks=None,
|
||||
**llm_kwargs,
|
||||
)
|
||||
|
||||
llm_calls["reduce"] = 1
|
||||
prompt_tokens["reduce"] = num_tokens(
|
||||
search_prompt, self.token_encoder
|
||||
) + num_tokens(query, self.token_encoder)
|
||||
output_tokens["reduce"] = num_tokens(reduced_response, self.token_encoder)
|
||||
|
||||
return reduced_response
|
||||
|
||||
async def _reduce_response_streaming(
|
||||
self,
|
||||
responses: str | dict[str, Any],
|
||||
query: str,
|
||||
**llm_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Reduce the response to a single comprehensive response.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
responses : str|dict[str, Any]
|
||||
The responses to reduce.
|
||||
query : str
|
||||
The original query.
|
||||
llm_kwargs : dict[str, Any]
|
||||
Additional keyword arguments to pass to the LLM.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The reduced response.
|
||||
"""
|
||||
reduce_responses = []
|
||||
|
||||
if isinstance(responses, str):
|
||||
reduce_responses = [responses]
|
||||
else:
|
||||
reduce_responses = [
|
||||
response["answer"]
|
||||
for response in responses.get("nodes", [])
|
||||
if response.get("answer")
|
||||
]
|
||||
|
||||
search_prompt = self.context_builder.reduce_system_prompt.format(
|
||||
context_data=reduce_responses,
|
||||
response_type=self.context_builder.response_type,
|
||||
)
|
||||
search_messages = [
|
||||
{"role": "system", "content": search_prompt},
|
||||
{"role": "user", "content": query},
|
||||
]
|
||||
|
||||
async for resp in self.llm.astream_generate(
|
||||
search_messages,
|
||||
callbacks=None,
|
||||
**llm_kwargs,
|
||||
):
|
||||
yield resp
|
||||
|
||||
1059
poetry.lock
generated
1059
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -65,13 +65,11 @@ numpy = "^1.25.2"
|
||||
graspologic = "^3.4.1"
|
||||
networkx = "^3.4.2"
|
||||
pandas = "^2.2.3"
|
||||
matplotlib = "^3.9.3"
|
||||
pyarrow = "^15.0.0"
|
||||
umap-learn = "^0.5.6"
|
||||
|
||||
# Configuration
|
||||
pyyaml = "^6.0.2"
|
||||
pyaml-env = "^1.2.1"
|
||||
python-dotenv = "^1.0.1"
|
||||
|
||||
pydantic = "^2.10.3"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user