formatting

This commit is contained in:
Julian Whiting 2024-10-11 16:24:16 -04:00
parent 2df0567f5c
commit d24e0bd3cc
15 changed files with 709 additions and 342 deletions

View File

@@ -114,7 +114,7 @@ DRIFT_SEARCH_MAX_TOKENS = 12_000
DRIFT_SEARCH_DATA_MAX_TOKENS = 12_000
DRIFT_SEARCH_CONCURRENCY = 32
DRIFT_SEARCH_PRIMER_K = 20
DRIFT_SEARCH_K_FOLLOW_UPS = 20
DRIFT_SEARCH_PRIMER_FOLDS = 5
DRIFT_SEARCH_PRIMER_MAX_TOKENS = 12_000
@@ -128,4 +128,4 @@ DRIFT_LOCAL_SEARCH_LLM_TOP_P = 1
DRIFT_LOCAL_SEARCH_LLM_N = 1
DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS = 2000
DRIFT_SEARCH_STEPS = 3
DRIFT_N_DEPTH = 3

View File

@@ -31,7 +31,7 @@ __all__ = [
"ClaimExtractionConfig",
"ClusterGraphConfig",
"CommunityReportsConfig",
"DRIFTSearchConfig",
"DRIFTSearchConfig",
"EmbedGraphConfig",
"EntityExtractionConfig",
"GlobalSearchConfig",

View File

@@ -4,6 +4,7 @@
"""Parameterization settings for the default configuration."""
from pydantic import BaseModel, Field
import graphrag.config.defaults as defs
@@ -36,9 +37,9 @@ class DRIFTSearchConfig(BaseModel):
default=defs.DRIFT_SEARCH_CONCURRENCY,
)
search_primer_k: int = Field(
drift_k_followups: int = Field(
description="The number of top global results to retrieve.",
default=defs.DRIFT_SEARCH_PRIMER_K,
default=defs.DRIFT_SEARCH_K_FOLLOW_UPS,
)
primer_folds: int = Field(
@@ -51,9 +52,9 @@ class DRIFTSearchConfig(BaseModel):
default=defs.DRIFT_SEARCH_PRIMER_MAX_TOKENS,
)
steps: int = Field(
n_depth: int = Field(
description="The number of drift search steps to take.",
default=defs.DRIFT_SEARCH_STEPS,
default=defs.DRIFT_N_DEPTH,
)
local_search_text_unit_prop: float = Field(
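Taken together with the renamed defaults in the file above (DRIFT_SEARCH_PRIMER_K becomes DRIFT_SEARCH_K_FOLLOW_UPS, DRIFT_SEARCH_STEPS becomes DRIFT_N_DEPTH), the field renames can be exercised as below. A minimal sketch assuming only pydantic; the constants are inlined here rather than imported from graphrag.config.defaults:

from pydantic import BaseModel, Field

# Inlined stand-ins for graphrag.config.defaults after this commit.
DRIFT_SEARCH_K_FOLLOW_UPS = 20  # was DRIFT_SEARCH_PRIMER_K
DRIFT_N_DEPTH = 3               # was DRIFT_SEARCH_STEPS

class DRIFTSearchConfig(BaseModel):
    """Subset of the fields renamed by this commit."""

    drift_k_followups: int = Field(
        description="The number of top global results to retrieve.",
        default=DRIFT_SEARCH_K_FOLLOW_UPS,
    )
    n_depth: int = Field(
        description="The number of drift search steps to take.",
        default=DRIFT_N_DEPTH,
    )

print(DRIFTSearchConfig(n_depth=2))  # drift_k_followups=20 n_depth=2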

View File

@@ -4,7 +4,7 @@
"""Base classes for global and local context builders."""
from abc import ABC, abstractmethod
from typing import Any
import pandas as pd
from graphrag.query.context_builder.conversation_history import (
@@ -39,9 +39,9 @@ class DRIFTContextBuilder(ABC):
"""Base class for DRIFT-search context builders."""
@abstractmethod
def build_primer_context(
def build_context(
self,
query: str,
**kwargs,
) -> pd.DataFrame:
"""Build the context for the primer search actions"""
"""Build the context for the primer search actions."""

View File

@@ -17,11 +17,11 @@ from graphrag.model import (
Relationship,
TextUnit,
)
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType
from graphrag.query.structured_search.base import BaseSearch
from graphrag.query.structured_search.global_search.community_context import (
GlobalCommunityContext,
)
@@ -110,7 +110,7 @@ def get_local_search_engine(
covariates: dict[str, list[Covariate]],
response_type: str,
description_embedding_store: BaseVectorStore,
) -> BaseSearch:
) -> LocalSearch:
"""Create a local search engine based on data + configuration."""
llm = get_llm(config)
text_embedder = get_text_embedder(config)
@@ -161,7 +161,7 @@ def get_global_search_engine(
reports: list[CommunityReport],
entities: list[Entity],
response_type: str,
) -> BaseSearch:
) -> GlobalSearch:
"""Create a global search engine based on data + configuration."""
token_encoder = tiktoken.get_encoding(config.encoding_model)
gs_config = config.global_search
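Tightening the factory annotations from BaseSearch to the concrete LocalSearch/GlobalSearch engines lets callers use engine-specific members without casts. A minimal sketch of the pattern with toy classes, not graphrag's:

class BaseSearch:
    """Toy stand-in for the shared search base."""

class LocalSearch(BaseSearch):
    def local_only(self) -> str:
        return "visible to the type checker"

def get_local_search_engine() -> LocalSearch:  # previously annotated -> BaseSearch
    return LocalSearch()

engine = get_local_search_engine()
print(engine.local_only())  # no cast or type: ignore needed under mypy/pyright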

View File

@@ -6,15 +6,15 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any
from typing import Any, Generic, TypeVar
import pandas as pd
import tiktoken
from graphrag.query.context_builder.builders import (
DRIFTContextBuilder,
GlobalContextBuilder,
LocalContextBuilder,
DRIFTContextBuilder,
)
from graphrag.query.context_builder.conversation_history import (
ConversationHistory,
@@ -35,13 +35,16 @@ class SearchResult:
prompt_tokens: int
class BaseSearch(ABC):
T = TypeVar("T", GlobalContextBuilder, LocalContextBuilder, DRIFTContextBuilder)
class BaseSearch(ABC, Generic[T]):
"""The Base Search implementation."""
def __init__(
self,
llm: BaseLLM,
context_builder: GlobalContextBuilder | LocalContextBuilder | DRIFTContextBuilder,
context_builder: T,
token_encoder: tiktoken.Encoding | None = None,
llm_params: dict[str, Any] | None = None,
context_builder_params: dict[str, Any] | None = None,
@@ -75,5 +78,5 @@ class BaseSearch(ABC):
self,
query: str,
conversation_history: ConversationHistory | None = None,
) -> AsyncGenerator[str, None]:
) -> AsyncGenerator[str, None] | None:
"""Stream search for the given query."""

View File

@@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""DriftSearch module."""

View File

@@ -1,101 +1,206 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""DRIFT Search Query State."""
import json
import logging
from typing import Optional, Dict, Any, List
from typing import Any
log = logging.getLogger(__name__)
class DriftAction:
"""
Represents an action containing query, answer, score, and follow-up actions.
We want to be able to encapsulate the action strings being produced by the LLM in a clean way.
Represent an action containing a query, answer, score, and follow-up actions.
This class encapsulates action strings produced by the LLM in a structured way.
"""
def __init__(self, query: str, answer: str | None = None, follow_ups: List['DriftAction'] | List[str] = []):
def __init__(
self,
query: str,
answer: str | None = None,
follow_ups: list["DriftAction"] | None = None,
):
"""
Initialize the DriftAction with a query, optional answer, and follow-up actions.
Args:
query (str): The query for the action.
answer (Optional[str]): The answer to the query, if available.
follow_ups (Optional[list[DriftAction]]): A list of follow-up actions.
"""
self.query = query
self.answer: str | None = answer # corresponds to an 'intermediate_answer'
self.score: Optional[float] = None
self.follow_ups: List['DriftAction'] | List[str] = follow_ups
self.metadata: Dict[str, Any] = {}
# Should contain metadata explaining how to execute the action. Will not always be local search in the future.
self.answer: str | None = answer # Corresponds to an 'intermediate_answer'
self.score: float | None = None
self.follow_ups: list[DriftAction] = (
follow_ups if follow_ups is not None else []
)
self.metadata: dict[str, Any] = {}
@property
def is_complete(self) -> bool:
"""Check if the action is complete (i.e., an answer is available)."""
return self.answer is not None
def search(self, search_engine: Any, scorer: Any = None):
raise NotImplementedError("Search method not implemented for DriftAction. Use asearch instead.")
async def asearch(self, search_engine: Any, global_query: str, scorer: Any = None):
"""
Execute an asynchronous search using the search engine, and update the action with the results.
If a scorer is provided, compute the score for the action.
Args:
search_engine (Any): The search engine to execute the query.
global_query (str): The global query string.
scorer (Any, optional): Scorer to compute scores for the action.
Returns
-------
self : DriftAction
Updated action with search results.
"""
if self.is_complete:
log.warning("Action already complete. Skipping search.")
return self
search_result = await search_engine.asearch(
drift_query=global_query, query=self.query
)
async def asearch(self, search_engine: Any, global_query:str, scorer: Any = None):
# TODO: test that graph stores actions as reference... This SHOULD update the graph object.
search_result = await search_engine.asearch(drift_query=global_query, query=self.query)
try:
response = json.loads(search_result.response)
except json.JSONDecodeError:
raise ValueError(f"Failed to parse response: {search_result.response}. Ensure it is JSON serializable.")
self.answer = response.pop('response', None)
self.score = response.pop('score', float('-inf'))
self.metadata.update({'context_data':search_result.context_data})
self.follow_ups = response.pop('follow_up_queries', [])
if self.follow_ups == []:
except json.JSONDecodeError as e:
error_message = "Failed to parse search response"
log.exception("%s: %s", error_message, search_result.response)
raise ValueError(error_message) from e
self.answer = response.pop("response", None)
self.score = response.pop("score", float("-inf"))
self.metadata.update({"context_data": search_result.context_data})
self.metadata.update({"token_ct": search_result.token_ct})
self.follow_ups = response.pop("follow_up_queries", [])
if not self.follow_ups:
log.warning("No follow-up actions found for response: %s", response)
if scorer:
self.compute_score(scorer)
return self
def compute_score(self, scorer: Any):
score = scorer.compute_score(self.query, self.answer)
self.score = score if score is not None else float('-inf') # use -inf to help with sorting later
def serialize(self, include_follow_ups: bool = True) -> Dict[str, Any]:
"""
Serializes the action to a dictionary.
Compute the score for the action using the provided scorer.
Args:
scorer (Any): The scorer to compute the score.
"""
score = scorer.compute_score(self.query, self.answer)
self.score = score if score is not None else float("-inf") # Default to -inf for sorting
def serialize(self, include_follow_ups: bool = True) -> dict[str, Any]:
"""
Serialize the action to a dictionary.
Args:
include_follow_ups (bool): Whether to include follow-up actions in the serialization.
Returns
-------
dict[str, Any]
Serialized action as a dictionary.
"""
data = {
'query': self.query,
'answer': self.answer,
'score': self.score,
'metadata': self.metadata,
"query": self.query,
"answer": self.answer,
"score": self.score,
"metadata": self.metadata,
}
if include_follow_ups:
data['follow_ups'] = [action.serialize() for action in self.follow_ups] # TODO: handle leftover string followups
data["follow_ups"] = [action.serialize() for action in self.follow_ups]
return data
@classmethod
def deserialize(cls, data: Dict[str, Any]) -> 'DriftAction':
def deserialize(cls, data: dict[str, Any]) -> "DriftAction":
"""
Deserializes the action from a dictionary.
Deserialize the action from a dictionary.
Args:
data (dict[str, Any]): Serialized action data.
Returns
-------
DriftAction
A deserialized instance of DriftAction.
"""
action = cls(data['query'])
action.answer = data.get('answer')
action.score = data.get('score')
action.metadata = data.get('metadata', {})
if 'follow_ups' in data:
action.follow_ups = [cls.deserialize(fu_data) for fu_data in data.get('follow_up_queries', [])]
else:
action.follow_ups = []
action = cls(data["query"])
action.answer = data.get("answer")
action.score = data.get("score")
action.metadata = data.get("metadata", {})
action.follow_ups = (
[cls.deserialize(fu_data) for fu_data in data.get("follow_up_queries", [])]
if "follow_ups" in data
else []
)
return action
@classmethod
def from_primer_response(cls, query: str, response: str | Dict[str, Any] | List[Dict[str, Any]]) -> 'DriftAction':
def from_primer_response(
cls, query: str, response: str | dict[str, Any] | list[dict[str, Any]]
) -> "DriftAction":
"""
Creates a DriftAction from the DRIFTPrimer response.
Create a DriftAction from a DRIFTPrimer response.
Args:
query (str): The query string.
response (str | dict[str, Any] | list[dict[str, Any]]): Primer response data.
Returns
-------
DriftAction
A new instance of DriftAction based on the response.
Raises
------
ValueError
If the response is not a dictionary or expected format.
"""
if isinstance(response, dict):
action = cls(query, follow_ups=response.get('follow_up_queries', []), answer=response.get('intermediate_answer'))
action.answer = response.get('intermediate_answer')
action.score = response.get('score')
action = cls(
query,
follow_ups=response.get("follow_up_queries", []),
answer=response.get("intermediate_answer"),
)
action.score = response.get("score")
return action
else:
raise ValueError(f'Response must be a dictionary. Found: {type(response)}'
f' with content: {response}')
error_message = "Response must be a dictionary"
raise ValueError(error_message)
def __hash__(self):
# Necessary for storing in networkx.MultiDiGraph. Assumes unique queries.
"""
Allow DriftAction objects to be hashable for use in networkx.MultiDiGraph.
Assumes queries are unique.
Returns
-------
int
Hash based on the query.
"""
return hash(self.query)
def __eq__(self, other):
"""
Check equality based on the query string.
Args:
other (Any): Another object to compare with.
Returns
-------
bool
True if the other object is a DriftAction with the same query, False otherwise.
"""
return isinstance(other, DriftAction) and self.query == other.query
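Usage of the tightened DriftAction API might look like the sketch below (hedged: it assumes graphrag is installed with the module path introduced by this commit). One wrinkle is visible in the diff itself: serialize() writes children under "follow_ups" while deserialize() reads "follow_up_queries", so follow-ups do not survive a round trip in this revision; equality still holds because __eq__ compares only the query.

from graphrag.query.structured_search.drift_search.action import DriftAction

root = DriftAction(query="What ports does the gateway expose?")
root.answer = "The gateway exposes 8080 and 9443."
root.score = 87.0
root.follow_ups = [DriftAction(query="Which service listens on 9443?")]

data = root.serialize()          # keys: query, answer, score, metadata, follow_ups
restored = DriftAction.deserialize(data)

print(restored == root)          # True: equality is by query only
print(len(restored.follow_ups))  # 0: "follow_ups" vs "follow_up_queries" mismatch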

View File

@@ -1,5 +1,10 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""DRIFT Context Builder implementation."""
import logging
from typing import Any, List
from typing import Any
import numpy as np
import pandas as pd
@@ -17,9 +22,7 @@ from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.structured_search.base import DRIFTContextBuilder
from graphrag.query.structured_search.drift_search.primer import (
PrimerQueryProcessor,
)
from graphrag.query.structured_search.drift_search.primer import PrimerQueryProcessor
from graphrag.query.structured_search.drift_search.system_prompt import (
DRIFT_LOCAL_SYSTEM_PROMPT,
)
@@ -32,30 +35,31 @@ log = logging.getLogger(__name__)
class DRIFTSearchContextBuilder(DRIFTContextBuilder):
"""Class representing the DRIFT Search Context Builder."""
def __init__(
self,
chat_llm: ChatOpenAI,
text_embedder: BaseTextEmbedding,
entities: List[Entity],
entities: list[Entity],
entity_text_embeddings: BaseVectorStore,
text_units: List[TextUnit] | None = None,
reports: List[CommunityReport] | None = None,
relationships: List[Relationship] | None = None,
covariates: dict[str, List[Covariate]] | None = None,
text_units: list[TextUnit] | None = None,
reports: list[CommunityReport] | None = None,
relationships: list[Relationship] | None = None,
covariates: dict[str, list[Covariate]] | None = None,
token_encoder: tiktoken.Encoding | None = None,
embedding_vectorstore_key: str = EntityVectorStoreKey.ID,
config: DRIFTSearchConfig | None = None,
local_system_prompt: str = DRIFT_LOCAL_SYSTEM_PROMPT,
local_mixed_context: LocalSearchMixedContext | None = None,
):
"""Initialize the DRIFT search context builder with necessary components."""
self.config = config or DRIFTSearchConfig()
self.chat_llm = chat_llm
self.text_embedder = text_embedder
self.token_encoder = token_encoder
self.local_system_prompt = local_system_prompt
self.drift_primer_context = None
self.entities = entities
self.entity_text_embeddings = entity_text_embeddings
self.reports = reports
@@ -64,11 +68,19 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
self.covariates = covariates
self.embedding_vectorstore_key = embedding_vectorstore_key
self.llm_tokens = 0
self.local_mixed_context = (
local_mixed_context or self.init_local_context_builder()
)
def init_local_context_builder(self) -> LocalSearchMixedContext:
"""
Initialize the local search mixed context builder.
Returns
-------
LocalSearchMixedContext: Initialized local context.
"""
return LocalSearchMixedContext(
community_reports=self.reports,
text_units=self.text_units,
@@ -82,45 +94,79 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
)
@staticmethod
def convert_reports_to_df(
reports: List[CommunityReport],
) -> pd.DataFrame:
def convert_reports_to_df(reports: list[CommunityReport]) -> pd.DataFrame:
"""
Converts a list of CommunityReport objects to a DataFrame.
"""
df = pd.DataFrame([report.__dict__ for report in reports])
if df["full_content"].isnull().any():
raise ValueError("Some reports are missing full content.")
Convert a list of CommunityReport objects to a pandas DataFrame.
Args
----
reports : list[CommunityReport]
List of CommunityReport objects.
Returns
-------
pd.DataFrame: DataFrame with report data.
Raises
------
ValueError: If some reports are missing full content or full content embeddings.
"""
report_df = pd.DataFrame([report.__dict__ for report in reports])
missing_content_error = "Some reports are missing full content."
missing_embedding_error = "Some reports are missing full content embeddings."
if report_df["full_content"].isna().sum() > 0:
raise ValueError(missing_content_error)
if (
"full_content_embedding" not in df.columns
or df["full_content_embedding"].isnull().any()
"full_content_embedding" not in report_df.columns
or report_df["full_content_embedding"].isna().sum() > 0
):
raise ValueError("Some reports are missing full content embeddings.")
return df
raise ValueError(missing_embedding_error)
return report_df
@staticmethod
def check_query_doc_encodings(
query_embedding: Any, embedding: Any
) -> bool:
def check_query_doc_encodings(query_embedding: Any, embedding: Any) -> bool:
"""
Checks if the embeddings have the same type, length, are not empty,
and have matching element types.
Check if the embeddings are compatible.
Args
----
query_embedding : Any
Embedding of the query.
embedding : Any
Embedding to compare against.
Returns
-------
bool: True if embeddings match, otherwise False.
"""
if not (
return (
isinstance(query_embedding, type(embedding))
and len(query_embedding) == len(embedding)
and isinstance(query_embedding[0], type(embedding[0]))
):
return False
return True
)
def build_primer_context(self, query: str, **kwargs) -> pd.DataFrame:
if not self.reports:
raise ValueError(
"No community reports available... Please provide a list of reports."
)
def build_context(self, query: str, **kwargs) -> pd.DataFrame:
"""
Build DRIFT search context.
Args
----
query : str
Search query string.
Returns
-------
pd.DataFrame: Top-k most similar documents.
Raises
------
ValueError: If no community reports are available, or embeddings
are incompatible.
"""
if self.reports is None:
missing_reports_error = "No community reports available. Please provide a list of reports."
raise ValueError(missing_reports_error)
query_processor = PrimerQueryProcessor(
chat_llm=self.chat_llm,
@@ -129,29 +175,26 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
reports=self.reports,
)
query_embedding = query_processor(query)
query_embedding, token_ct = query_processor(query)
self.llm_tokens += token_ct
report_df = self.convert_reports_to_df(self.reports)
if self.check_query_doc_encodings(
query_embedding, report_df["full_content_embedding"].iloc[0]
):
report_df["similarity"] = report_df[
"full_content_embedding"
].apply(
report_df["similarity"] = report_df["full_content_embedding"].apply(
lambda x: np.dot(x, query_embedding)
/ (np.linalg.norm(x) * np.linalg.norm(query_embedding))
)
top_k = (
report_df.sort_values("similarity", ascending=False)
.head(self.config.search_primer_k)
top_k = report_df.sort_values("similarity", ascending=False).head(
self.config.drift_k_followups
)
else:
raise ValueError(
incompatible_embeddings_error = (
"Query and document embeddings are not compatible. "
"Please ensure that the embeddings are of the same type and length."
f" Query: {query_embedding}, Document: {report_df['full_content_embedding'].iloc[0]}"
)
raise ValueError(incompatible_embeddings_error)
return top_k[['short_id', 'community_id', 'full_content']]
return top_k.loc[:, ["short_id", "community_id", "full_content"]]
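The ranking step inside build_context is plain cosine similarity between the HyDE query embedding and each report's full-content embedding. A self-contained sketch with toy two-dimensional embeddings and k reduced to 1 for brevity:

import numpy as np
import pandas as pd

report_df = pd.DataFrame({
    "short_id": ["r1", "r2"],
    "community_id": ["c1", "c2"],
    "full_content": ["summary one", "summary two"],
    "full_content_embedding": [np.array([1.0, 0.0]), np.array([0.6, 0.8])],
})
query_embedding = np.array([0.8, 0.6])

# Same formula as the diff: dot product over the product of norms.
report_df["similarity"] = report_df["full_content_embedding"].apply(
    lambda x: np.dot(x, query_embedding)
    / (np.linalg.norm(x) * np.linalg.norm(query_embedding))
)
top_k = report_df.sort_values("similarity", ascending=False).head(1)
print(top_k.loc[:, ["short_id", "community_id", "full_content"]])  # r2 wins, 0.96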

View File

@@ -1,93 +1,148 @@
import logging
import json
import pandas as pd
import random
import time
from typing import Any, List, Dict, Tuple
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Primer for DRIFT search."""
import json
import logging
import secrets
import time
import pandas as pd
import tiktoken
from tqdm.asyncio import tqdm_asyncio
from graphrag.config.models.drift_config import DRIFTSearchConfig
from graphrag.model import CommunityReport
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.structured_search.base import SearchResult
from graphrag.query.structured_search.drift_search.system_prompt import DRIFT_PRIMER_PROMPT
from graphrag.query.structured_search.drift_search.system_prompt import (
DRIFT_PRIMER_PROMPT,
)
log = logging.getLogger(__name__)
class PrimerQueryProcessor:
"""Process the query by expanding it using community reports and generate follow-up actions."""
def __init__(self, chat_llm: ChatOpenAI, text_embedder: BaseTextEmbedding, reports: List[CommunityReport], token_encoder: tiktoken.Encoding | None = None,):
def __init__(
self,
chat_llm: ChatOpenAI,
text_embedder: BaseTextEmbedding,
reports: list[CommunityReport],
token_encoder: tiktoken.Encoding | None = None,
):
"""
Initialize the PrimerQueryProcessor.
Args:
chat_llm (ChatOpenAI): The language model used to process the query.
text_embedder (BaseTextEmbedding): The text embedding model.
reports (list[CommunityReport]): List of community reports.
token_encoder (tiktoken.Encoding, optional): Token encoder for token counting.
"""
self.chat_llm = chat_llm
self.text_embedder = text_embedder
self.token_encoder = token_encoder
self.reports = reports
def expand_query(self, query: str) -> Tuple[str, int]:
def expand_query(self, query: str) -> tuple[str, int]:
"""
Expand the query using a random community report template.
Args:
query (str): The original search query.
Returns
-------
tuple[str, int]: Expanded query text and the number of tokens used.
"""
token_ct = 0
template = secrets.choice(self.reports).full_content # nosec S311
template = random.choice(self.reports).full_content
prompt = f"""Create a hypothetical answer to the following query: {query}\n\n
Format it to follow the structure of the template below:\n\n
{template}\n"
Ensure that the hypothetical answer does not reference new named entities that are not present in the original query."""
prompt = (
f"Create a hypothetical answer to the following query: {query}\n\n"
"Format it to follow the structure of the template below:\n\n"
f"{template}\n"
"Ensure that the hypothetical answer does not reference new named entities that are not present in the original query."
)
messages = [
{"role": "user", "content": prompt}
]
messages = [{"role": "user", "content": prompt}]
text = self.chat_llm.generate(messages)
if self.token_encoder:
token_ct = len(self.token_encoder.encode(text)) + len(self.token_encoder.encode(query))
token_ct = num_tokens(text + query)
if text == "":
log.warning("Failed to generate expansion for query: %s", query)
return query, token_ct
return text, token_ct
def __call__(self, query: str) -> List[float]:
hyde_query, token_ct = self.expand_query(query) # TODO: implement token counting
def __call__(self, query: str) -> tuple[list[float], int]:
"""
Call method to process the query, expand it, and embed the result.
Args:
query (str): The search query.
Returns
-------
tuple[list[float], int]: List of embeddings for the expanded query and the token count.
"""
hyde_query, token_ct = self.expand_query(query)
log.info("Expanded query: %s", hyde_query)
return self.text_embedder.embed(hyde_query)
return self.text_embedder.embed(hyde_query), token_ct
class DRIFTPrimer:
"""
Performs initial query decomposition using global guidance from information in community reports.
"""
"""Perform initial query decomposition using global guidance from information in community reports."""
def __init__(
self,
config: DRIFTSearchConfig,
chat_llm: ChatOpenAI,
token_encoder: tiktoken.Encoding | None = None, # TODO: implement token counting
token_encoder: tiktoken.Encoding | None = None,
):
"""
Initialize the DRIFTPrimer.
Args:
config (DRIFTSearchConfig): Configuration settings for DRIFT search.
chat_llm (ChatOpenAI): The language model used for searching.
token_encoder (tiktoken.Encoding, optional): Token encoder for managing tokens.
"""
self.llm = chat_llm
self.config = config
self.token_encoder = token_encoder
async def decompose_query(
self, query: str, reports: pd.DataFrame
) -> tuple[dict, int]:
"""
Decompose the query into subqueries based on the fetched global structures.
async def decompose_query(self, query: str, reports: pd.DataFrame) -> Tuple[Dict, int]:
Args:
query (str): The original search query.
reports (pd.DataFrame): DataFrame containing community reports.
Returns
-------
tuple[dict, int]: Parsed response and the number of tokens used.
"""
Decomposes the query into subqueries based on the fetched global structures.
Returns a tuple of the parsed response and the number of tokens used.
"""
community_reports = "\n\n".join(reports['full_content'].tolist())
prompt = DRIFT_PRIMER_PROMPT.format(query=query, community_reports=community_reports)
community_reports = "\n\n".join(reports["full_content"].tolist())
prompt = DRIFT_PRIMER_PROMPT.format(
query=query, community_reports=community_reports
)
messages = [{"role": "user", "content": prompt}]
prompt_tokens = self._count_message_tokens(messages)
response = await self.llm.agenerate(messages, response_format={"type": "json_object"})
response = await self.llm.agenerate(
messages, response_format={"type": "json_object"}
)
parsed_response = json.loads(response)
token_ct = num_tokens(prompt + response, self.token_encoder)
return parsed_response, prompt_tokens
return parsed_response, token_ct
async def asearch(
self,
@@ -96,70 +151,49 @@ class DRIFTPrimer:
) -> SearchResult:
"""
Asynchronous search method that processes the query and returns a SearchResult.
Args:
query (str): The search query.
top_k_reports (pd.DataFrame): DataFrame containing the top-k reports.
Returns
-------
SearchResult: The search result containing the response and context data.
"""
start_time = time.time()
report_folds = self.split_reports(top_k_reports)
tasks = []
prompt_tokens = 0
llm_calls = len(report_folds)
for fold in report_folds:
task = self.decompose_query(query, fold)
tasks.append(task)
tasks = [self.decompose_query(query, fold) for fold in report_folds]
results_with_tokens = await tqdm_asyncio.gather(*tasks)
for result in results_with_tokens:
# result is a tuple: (parsed_response, tokens_used)
_, tokens_used = result
prompt_tokens += tokens_used
completion_time = time.time() - start_time
search_result = SearchResult(
return SearchResult(
response=[response for response, _ in results_with_tokens],
context_data={'top_k_reports': top_k_reports},
context_data={"top_k_reports": top_k_reports},
context_text=str(top_k_reports),
completion_time=completion_time,
llm_calls=llm_calls,
prompt_tokens=prompt_tokens,
llm_calls=2,
prompt_tokens=sum(tokens for _, tokens in results_with_tokens),
)
return search_result
def split_reports(self, reports: pd.DataFrame) -> List[pd.DataFrame]:
def split_reports(self, reports: pd.DataFrame) -> list[pd.DataFrame]:
"""
Splits the reports into folds, allows for parallel processing.
Split the reports into folds, allowing for parallel processing.
Args:
reports (pd.DataFrame): DataFrame of community reports.
Returns
-------
list[pd.DataFrame]: List of report folds.
"""
folds = []
num_reports = len(reports)
primer_folds = self.config.primer_folds or 1 # Ensure at least one fold
for i in range(primer_folds):
start_idx = i * num_reports // primer_folds
if i == primer_folds - 1:
end_idx = num_reports
else:
end_idx = (i + 1) * num_reports // primer_folds
end_idx = num_reports if i == primer_folds - 1 else (i + 1) * num_reports // primer_folds
fold = reports.iloc[start_idx:end_idx]
folds.append(fold)
return folds
def _count_text_tokens(self, text: str) -> int:
"""
Counts the number of tokens in a given text using the token encoder.
"""
if self.token_encoder is None:
raise ValueError("Token encoder is not initialized.")
return len(self.token_encoder.encode(text))
def _count_message_tokens(self, messages: List[Dict[str, str]]) -> int:
"""
Counts the number of tokens in a list of messages.
"""
total_tokens = 0
for message in messages:
for _, value in message.items():
total_tokens += self._count_text_tokens(value)
return total_tokens
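The fold-splitting arithmetic in split_reports uses integer boundaries and lets the last fold absorb any remainder. A standalone sketch of the same boundaries (the function name is reused for illustration only):

import pandas as pd

def split_reports(reports: pd.DataFrame, primer_folds: int) -> list[pd.DataFrame]:
    # Integer fold boundaries; the final fold absorbs the remainder.
    num_reports = len(reports)
    folds = []
    for i in range(primer_folds):
        start = i * num_reports // primer_folds
        end = (
            num_reports
            if i == primer_folds - 1
            else (i + 1) * num_reports // primer_folds
        )
        folds.append(reports.iloc[start:end])
    return folds

df = pd.DataFrame({"full_content": [f"report {i}" for i in range(11)]})
print([len(fold) for fold in split_reports(df, 5)])  # [2, 2, 2, 2, 3]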

View File

@@ -1,45 +1,72 @@
from collections.abc import AsyncGenerator
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""DRIFT Search implementation."""
import logging
from typing import Any, List
import time
from collections.abc import AsyncGenerator
from typing import Any
import tiktoken
from tqdm.asyncio import tqdm_asyncio
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.query.structured_search.base import BaseSearch, SearchResult
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.structured_search.local_search.search import LocalSearch
from graphrag.query.structured_search.drift_search.action import DriftAction
from graphrag.query.structured_search.drift_search.state import QueryState
from graphrag.query.structured_search.drift_search.drift_context import DRIFTSearchContextBuilder
from graphrag.query.structured_search.drift_search.primer import DRIFTPrimer
from graphrag.config.models.drift_config import DRIFTSearchConfig
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.structured_search.base import BaseSearch, SearchResult
from graphrag.query.structured_search.drift_search.action import DriftAction
from graphrag.query.structured_search.drift_search.drift_context import (
DRIFTSearchContextBuilder,
)
from graphrag.query.structured_search.drift_search.primer import DRIFTPrimer
from graphrag.query.structured_search.drift_search.state import QueryState
from graphrag.query.structured_search.local_search.search import LocalSearch
log = logging.getLogger(__name__)
class DRIFTSearch(BaseSearch):
class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
"""Class representing a DRIFT Search."""
def __init__(
self,
llm: ChatOpenAI,
context_builder: DRIFTSearchContextBuilder,
config: DRIFTSearchConfig | None = None,
token_encoder: tiktoken.Encoding | None = None, # TODO: implement token counting
query_state: QueryState | None = None,
self,
llm: ChatOpenAI,
context_builder: DRIFTSearchContextBuilder,
config: DRIFTSearchConfig | None = None,
token_encoder: tiktoken.Encoding | None = None,
query_state: QueryState | None = None,
):
"""
Initialize the DRIFTSearch class.
Args:
llm (ChatOpenAI): The language model used for searching.
context_builder (DRIFTSearchContextBuilder): Builder for search context.
config (DRIFTSearchConfig, optional): Configuration settings for DRIFTSearch.
token_encoder (tiktoken.Encoding, optional): Token encoder for managing tokens.
query_state (QueryState, optional): State of the current search query.
"""
super().__init__(llm, context_builder, token_encoder)
self.config = config or DRIFTSearchConfig()
self.context_builder = context_builder
self.token_encoder = token_encoder
self.query_state = query_state or QueryState()
self.primer = DRIFTPrimer(config=self.config, chat_llm=llm, token_encoder=token_encoder)
self.primer = DRIFTPrimer(
config=self.config, chat_llm=llm, token_encoder=token_encoder
)
self.local_search = self.init_local_search()
def init_local_search(self) -> LocalSearch:
"""
Initialize the LocalSearch object with parameters based on the DRIFT search configuration.
def init_local_search(self):
Returns
-------
LocalSearch: An instance of the LocalSearch class with the configured parameters.
"""
local_context_params = {
"text_unit_prop": self.config.local_search_text_unit_prop,
"community_prop": self.config.local_search_community_prop,
@@ -49,17 +76,16 @@ class DRIFTSearch(BaseSearch):
"include_relationship_weight": True,
"include_community_rank": False,
"return_candidate_context": False,
"embedding_vectorstore_key": EntityVectorStoreKey.ID, # set this to EntityVectorStoreKey.TITLE if the vectorstore uses entity title as ids
"max_tokens": self.config.local_search_max_data_tokens,
"embedding_vectorstore_key": EntityVectorStoreKey.ID,
"max_tokens": self.config.local_search_max_data_tokens,
}
llm_params = {
"max_tokens": self.config.local_search_llm_max_gen_tokens,
"temperature": self.config.local_search_temperature,
"response_format": {"type": "json_object"}
"response_format": {"type": "json_object"},
}
return LocalSearch(
llm=self.llm,
system_prompt=self.context_builder.local_system_prompt,
@@ -67,83 +93,203 @@ class DRIFTSearch(BaseSearch):
token_encoder=self.token_encoder,
llm_params=llm_params,
context_builder_params=local_context_params,
response_type="multiple paragraphs" # this has no bearing on the obj returned by OAI, only the format of the response within the obj returned by OAI.
response_type="multiple paragraphs",
)
def _process_primer_results(
self, query: str, search_results: SearchResult
) -> DriftAction:
"""
Process the results from the primer search to extract intermediate answers and follow-up queries.
def _process_primer_results(self, query:str, search_results: SearchResult) -> DriftAction:
Args:
query (str): The original search query.
search_results (SearchResult): The results from the primer search.
Returns
-------
DriftAction: Action generated from the primer response.
Raises
------
RuntimeError: If no intermediate answers or follow-up queries are found in the primer response.
"""
response = search_results.response
if isinstance(response, list) and isinstance(response[0], dict):
intermediate_answer = "\n\n".join([i['intermediate_answer'] for i in response])
follow_ups = [fu for i in response for fu in i['follow_up_queries']]
score = sum([i['score'] for i in response]) / len(response)
response = {'intermediate_answer': intermediate_answer, 'follow_up_queries': follow_ups, 'score': score}
return DriftAction.from_primer_response(query, response)
else:
raise ValueError(f'Response must be a list of dictionaries. Found: {type(response)}')
intermediate_answers = [
i["intermediate_answer"] for i in response if "intermediate_answer" in i
]
if len(intermediate_answers) == 0:
error_msg = "No intermediate answers found in primer response. Ensure that the primer response includes intermediate answers."
raise RuntimeError(error_msg)
# hard coded to use local search, but will be updated to use a general meta search engine
async def asearch_step(self, global_query:str, search_engine: LocalSearch, actions: List[DriftAction]) -> List[DriftAction]:
tasks = [action.asearch(search_engine=search_engine, global_query=global_query) for action in actions]
results = await tqdm_asyncio.gather(*tasks)
return results
intermediate_answer = "\n\n".join([
i["intermediate_answer"] for i in response if "intermediate_answer" in i
])
follow_ups = [
fu
for i in response
for fu in i["follow_up_queries"]
if "follow_up_queries" in i
]
if len(follow_ups) == 0:
error_msg = "No follow-up queries found in primer response. Ensure that the primer response includes follow-up queries."
raise RuntimeError(error_msg)
score = sum(i["score"] for i in response) / len(response)
response_data = {
"intermediate_answer": intermediate_answer,
"follow_up_queries": follow_ups,
"score": score,
}
return DriftAction.from_primer_response(query, response_data)
error_msg = "Response must be a list of dictionaries."
raise ValueError(error_msg)
async def asearch_step(
self, global_query: str, search_engine: LocalSearch, actions: list[DriftAction]
) -> list[DriftAction]:
"""
Perform an asynchronous search step by executing each DriftAction asynchronously.
Args:
global_query (str): The global query for the search.
search_engine (LocalSearch): The local search engine instance.
actions (list[DriftAction]): A list of actions to perform.
Returns
-------
list[DriftAction]: The results from executing the search actions asynchronously.
"""
tasks = [
action.asearch(search_engine=search_engine, global_query=global_query)
for action in actions
]
return await tqdm_asyncio.gather(*tasks)
async def asearch(
self,
query: str,
conversation_history: Any = None,
**kwargs,
self,
query: str,
conversation_history: Any = None,
**kwargs,
) -> SearchResult:
"""
Perform an asynchronous DRIFT search.
Args:
query (str): The query to search for.
conversation_history (Any, optional): The conversation history, if any.
## Check if query state is empty. Are there follow-up actions?
Returns
-------
SearchResult: The search result containing the response and context data.
Raises
------
ValueError: If the query is empty.
"""
if query == "":
error_msg = "DRIFT Search query cannot be empty."
raise ValueError(error_msg)
start_time = time.time()
primer_token_ct = 0
context_token_ct = 0
# Check if query state is empty
if not self.query_state.graph:
# Prime the search with the primer
primer_context = self.context_builder.build_primer_context(query) # pd.DataFrame
primer_response = await self.primer.asearch(query=query, top_k_reports=primer_context)
## Package response into DriftAction
primer_context = self.context_builder.build_context(query)
context_token_ct = self.context_builder.llm_tokens
primer_response = await self.primer.asearch(
query=query, top_k_reports=primer_context
)
primer_token_ct = primer_response.prompt_tokens
# Package response into DriftAction
init_action = self._process_primer_results(query, primer_response)
self.query_state.add_action(init_action)
self.query_state.add_all_follow_ups(init_action, init_action.follow_ups)
## Main loop ##
steps = 0
while steps < self.config.n:
## Rank actions
# Main loop
epochs = 0
llm_call_offset = 0
while epochs < self.config.n_depth:
actions = self.query_state.rank_incomplete_actions()
if not actions:
log.info("No more actions to take. Exiting DRIFT loop..")
if len(actions) == 0:
log.info("No more actions to take. Exiting DRIFT loop.")
break
## Take the top k actions
actions = actions[:self.config.search_primer_k]
## Process actions
results = await self.asearch_step(global_query=query, search_engine=self.local_search, actions=actions) # currently harcoded to local search, but will be updated to use a more general search engine
actions = actions[: self.config.drift_k_followups]
llm_call_offset += len(actions) - self.config.drift_k_followups
# Process actions
results = await self.asearch_step(
global_query=query, search_engine=self.local_search, actions=actions
)
## Update query state
# Update query state
for action in results:
self.query_state.add_action(action)
self.query_state.add_all_follow_ups(action, action.follow_ups)
steps += 1
epochs += 1
t_elapsed = time.time() - start_time
# Calculate token usage
total_tokens = (
primer_token_ct + context_token_ct + self.query_state.action_token_ct()
)
# Package up context data
response_state, context_data, context_text = self.query_state.serialize(
include_context=True
)
return SearchResult(
response=self.query_state.serialize(),
context_data='test',
context_text='testing_only',
completion_time=0,
llm_calls=0,
prompt_tokens=0
response=response_state,
context_data=context_data,
context_text=context_text,
completion_time=t_elapsed,
llm_calls=1
+ self.config.primer_folds
+ (self.config.drift_k_followups - llm_call_offset) * self.config.n_depth,
prompt_tokens=total_tokens,
)
def search(
self,
query: str,
conversation_history: Any = None,
**kwargs,) -> SearchResult:
self,
query: str,
conversation_history: Any = None,
**kwargs,
) -> SearchResult:
"""
Perform a synchronous DRIFT search (Not Implemented).
raise NotImplementedError("Synchronous DRIFT is not implemented.")
Args:
query (str): The query to search for.
conversation_history (Any, optional): The conversation history.
def astream_search(self, query: str, conversation_history: ConversationHistory | None = None) -> AsyncGenerator[str, None]:
raise NotImplementedError("Streaming DRIFT search is not implemented.")
Raises
------
NotImplementedError: Synchronous DRIFT is not implemented.
"""
error_msg = "Synchronous DRIFT is not implemented."
raise NotImplementedError(error_msg)
def astream_search(
self, query: str, conversation_history: ConversationHistory | None = None
) -> AsyncGenerator[str, None]:
"""
Perform a streaming DRIFT search (Not Implemented).
Args:
query (str): The query to search for.
conversation_history (ConversationHistory, optional): The conversation history.
Raises
------
NotImplementedError: Streaming DRIFT search is not implemented.
"""
error_msg = "Streaming DRIFT search is not implemented."
raise NotImplementedError(error_msg)
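End to end, the new engine is driven through asearch. A hedged sketch of a typical invocation (it assumes a ChatOpenAI client and a populated DRIFTSearchContextBuilder are constructed elsewhere, e.g. via the factories touched above):

import asyncio

from graphrag.query.structured_search.drift_search.search import DRIFTSearch

async def run_drift(llm, context_builder):
    # llm: ChatOpenAI, context_builder: DRIFTSearchContextBuilder (built elsewhere)
    search = DRIFTSearch(llm=llm, context_builder=context_builder)
    result = await search.asearch("How do the main characters relate?")
    # response is the serialized action graph; call/token accounting is now real.
    print(result.response["nodes"][0]["answer"])
    print(result.llm_calls, result.prompt_tokens)

# asyncio.run(run_drift(llm, context_builder))  # with real clients wired in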

View File

@@ -1,70 +1,84 @@
import networkx as nx
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Manage the state of the DRIFT query, including a graph of actions."""
import logging
import random
from typing import List, Dict, Optional, Any
from typing import Any
import networkx as nx
from graphrag.query.structured_search.drift_search.action import DriftAction
log = logging.getLogger(__name__)
class QueryState:
"""
Manages the state of the query, including a graph of actions.
"""
"""Manage the state of the query, including a graph of actions."""
def __init__(self):
self.graph = nx.MultiDiGraph()
def add_action(self, action: DriftAction, metadata: Optional[Dict[str, Any]] = None):
"""
Adds an action to the graph with optional metadata.
"""
def add_action(
self, action: DriftAction, metadata: dict[str, Any] | None = None
):
"""Add an action to the graph with optional metadata."""
self.graph.add_node(action, **(metadata or {}))
def relate_actions(self, parent: DriftAction, child: DriftAction, weight: float = 1.0):
"""
Relates two actions in the graph.
"""
def relate_actions(
self, parent: DriftAction, child: DriftAction, weight: float = 1.0
):
"""Relate two actions in the graph."""
self.graph.add_edge(parent, child, weight=weight)
def add_all_follow_ups(self, action: DriftAction, follow_ups: List[DriftAction] | List[str], weight: float = 1.0):
"""
Adds all follow-up actions and links them to the given action.
"""
def add_all_follow_ups(
self,
action: DriftAction,
follow_ups: list[DriftAction] | list[str],
weight: float = 1.0,
):
"""Add all follow-up actions and links them to the given action."""
if len(follow_ups) == 0:
raise ValueError("No follow-up actions to add. Please provide a list of follow-up actions.")
log.warning("No follow-up actions for action: %s", action.query)
for follow_up in follow_ups:
if isinstance(follow_up, str):
follow_up = DriftAction(query=follow_up)
else:
log.warning(
"Follow-up action is not a string, found type: %s", type(follow_up)
)
self.add_action(follow_up)
self.relate_actions(action, follow_up, weight)
def find_incomplete_actions(self) -> List[DriftAction]:
"""
Finds all unanswered actions in the graph.
"""
def find_incomplete_actions(self) -> list[DriftAction]:
"""Find all unanswered actions in the graph."""
return [node for node in self.graph.nodes if not node.is_complete]
def rank_incomplete_actions(self, scorer: Any | None = None) -> List[DriftAction]:
"""
Ranks all unanswered actions in the graph if scorer available.
"""
def rank_incomplete_actions(self, scorer: Any | None = None) -> list[DriftAction]:
"""Rank all unanswered actions in the graph if scorer available."""
unanswered = self.find_incomplete_actions()
if scorer:
for node in unanswered:
node.compute_score(scorer)
return list(sorted(
unanswered,
key=lambda node: node.score if node.score is not None else float('-inf'),
reverse=True
))
else: # shuffle the list if no scorer
random.shuffle(unanswered)
return list(unanswered)
return sorted(
unanswered,
key=lambda node: node.score
if node.score is not None
else float("-inf"),
reverse=True,
)
# shuffle the list if no scorer
random.shuffle(unanswered)
return list(unanswered)
def serialize(self) -> Dict[str, Any]:
"""
Serializes the graph to a dictionary, including nodes and edges.
"""
def serialize(
self, include_context: bool = True
) -> dict[str, Any] | tuple[dict[str, Any], dict[str, Any], str]:
"""Serialize the graph to a dictionary, including nodes and edges."""
# Create a mapping from nodes to unique IDs
node_to_id = {node: idx for idx, node in enumerate(self.graph.nodes())}
@@ -72,7 +86,7 @@ class QueryState:
nodes = []
for node in self.graph.nodes():
node_data = node.serialize(include_follow_ups=False)
node_data['id'] = node_to_id[node]
node_data["id"] = node_to_id[node]
node_attributes = self.graph.nodes[node]
if node_attributes:
node_data.update(node_attributes)
@@ -88,17 +102,25 @@
}
edges.append(edge_info)
if include_context:
context_data = {}
for node in nodes:
if node["metadata"].get("context_data") and node.get("query"):
context_data[node["query"]] = node["metadata"]["context_data"]
context_text = str(context_data)
return {"nodes": nodes, "edges": edges}, context_data, context_text
return {"nodes": nodes, "edges": edges}
def deserialize(self, data: Dict[str, Any]):
"""
Deserializes the dictionary back to a graph.
"""
def deserialize(self, data: dict[str, Any]):
"""Deserialize the dictionary back to a graph."""
self.graph.clear()
id_to_action = {}
for node_data in data.get("nodes", []):
node_id = node_data.pop('id')
node_id = node_data.pop("id")
action = DriftAction.deserialize(node_data)
self.add_action(action)
id_to_action[node_id] = action
@@ -111,3 +133,7 @@ class QueryState:
target_action = id_to_action.get(target_id)
if source_action and target_action:
self.relate_actions(source_action, target_action, weight)
def action_token_ct(self) -> int:
"""Return the token count of the action."""
return sum(action.metadata.get("token_ct", 0) for action in self.graph.nodes)
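A small round trip through the new QueryState API (hedged sketch, assuming the module paths added by this PR):

from graphrag.query.structured_search.drift_search.action import DriftAction
from graphrag.query.structured_search.drift_search.state import QueryState

state = QueryState()
root = DriftAction(query="root question")
state.add_action(root)
# String follow-ups are wrapped in DriftAction nodes and linked to the parent.
state.add_all_follow_ups(root, ["follow-up A", "follow-up B"])

print(len(state.find_incomplete_actions()))  # 3: nothing is answered yet
nodes_and_edges = state.serialize(include_context=False)
print(len(nodes_and_edges["edges"]))         # 2 parent -> child links
print(state.action_token_ct())               # 0 until searches record token_ct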

View File

@@ -1,3 +1,8 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""DRIFT Search prompts."""
DRIFT_LOCAL_SYSTEM_PROMPT = """
---Role---
@@ -58,17 +63,13 @@ Pay close attention specifically to the Sources tables as it contains the most relevant information for the user query.
{response_type}
Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
Add sections and commentary to the response as appropriate for the length and format.
Additionally provide a score for how well the response addresses the overall research question: {global_query}. Based on your response, suggest a few follow-up questions that could be asked to further explore the topic. Do not include scores or follow up questions in the 'response' field of the JSON, add them to the respective 'score' and 'follow_up_queries' keys of the JSON output. Format your response in JSON with the following keys and values:
Additionally provide a score for how well the response addresses the overall research question: {global_query}. Based on your response, suggest a few follow-up questions that could be asked to further explore the topic. Do not include scores or follow up questions in the 'response' field of the JSON, add them to the respective 'score' and 'follow_up_queries' keys of the JSON output. Generate at least five good follow-up queries. Format your response in JSON with the following keys and values:
'response': str, Put your answer, formatted in markdown, here. Do not answer the global query in this section.
{{'response': str, Put your answer, formatted in markdown, here. Do not answer the global query in this section.
'score': int,
'follow_up_queries': List[str]
1. score: How well the intermediate answer addresses the query. A score of 0 indicates a poor, unfocused answer, while a score of 100 indicates a highly focused, relevant answer that addresses the query in its entirety.
2. follow_up_queries: Use the data provided to generate follow-up queries to help refine your search. Do not ask compound questions, for example: "What is the market cap of Apple and Microsoft?". Use your knowledge of the entity distribution to focus on entity types that will be useful for searching a broad area of the knowledge graph.
'follow_up_queries': List[str]}}
"""
@@ -137,18 +138,20 @@ Add sections and commentary to the response as appropriate for the length and format.
"""
DRIFT_PRIMER_PROMPT = """You are a helpful agent designed to reason over a knowledge graph in response to a user query.
This is a unique knowledge graph where edges are freeform text rather than verb operators. You will begin your reasoning looking at a summary of the content of the most relevant community and will provide a score for:
DRIFT_PRIMER_PROMPT = """You are a helpful agent designed to reason over a knowledge graph in response to a user query.
This is a unique knowledge graph where edges are freeform text rather than verb operators. You will begin your reasoning by looking at a summary of the content of the most relevant communities and will provide:
1. score: How well the intermediate answer addresses the query. A score of 0 indicates a poor, unfocused answer, while a score of 100 indicates a highly focused, relevant answer that addresses the query in its entirety.
2. intermediate_answer: This answer should match the level of detail and length found in the community summaries. The intermediate answer should be exactly 2000 characters long. This must be formatted in markdown and must begin with a header that explains how the following text is related to the query.
3. follow_up_queries: A list of follow-up queries that could be asked to further explore the topic. These should be formatted as a list of strings. Generate at least five good follow-up queries.
Use this information to help you decide whether or not you need more information about the entities mentioned in the report. You may also use your general knowledge to think of entities which may help enrich your answer.
You will also provide a full answer from the content you have available. Use the data provided to generate follow-up queries to help refine your search. Do not ask compound questions, for example: "What is the market cap of Apple and Microsoft?". Use your knowledge of the entity distribution to focus on entity types that will be useful for searching a broad area of the knowledge graph.
For the query:
For the query:
{query}
@@ -156,7 +159,7 @@ The top-ranked community summaries:
{community_reports}
Provide the intermediate answer, and all scores in JSON format with the following format:
Provide the intermediate answer and all scores in JSON format as follows:
{{'intermediate_answer': str,
'score': int,

View File

@@ -54,7 +54,7 @@ class GlobalSearchResult(SearchResult):
reduce_context_text: str | list[str] | dict[str, str]
class GlobalSearch(BaseSearch):
class GlobalSearch(BaseSearch[GlobalContextBuilder]):
"""Search orchestration for global search mode."""
def __init__(
@@ -145,6 +145,7 @@ class GlobalSearch(BaseSearch):
- Step 2: Combine the answers from step 2 to generate the final answer
"""
# Step 1: Generate answers for each batch of community short summaries
start_time = time.time()
context_chunks, context_records = self.context_builder.build_context(
conversation_history=conversation_history, **self.context_builder_params

View File

@@ -29,7 +29,7 @@ DEFAULT_LLM_PARAMS = {
log = logging.getLogger(__name__)
class LocalSearch(BaseSearch):
class LocalSearch(BaseSearch[LocalContextBuilder]):
"""Search orchestration for local search mode."""
def __init__(
@@ -58,7 +58,6 @@ class LocalSearch(BaseSearch):
self,
query: str,
conversation_history: ConversationHistory | None = None,
drift_query: str | None = None,
**kwargs,
) -> SearchResult:
"""Build local search context that fits a single context window and generate answer for the user query."""
@@ -73,12 +72,14 @@
)
log.info("GENERATE ANSWER: %s. QUERY: %s", start_time, query)
try:
if drift_query:
if "drift_query" in kwargs:
drift_query = kwargs["drift_query"]
search_prompt = self.system_prompt.format(
context_data=context_text, response_type=self.response_type, global_query=drift_query,
context_data=context_text,
response_type=self.response_type,
global_query=drift_query,
)
else:
search_prompt = self.system_prompt.format(
context_data=context_text, response_type=self.response_type
)
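Finally, moving drift_query out of LocalSearch's signature and into **kwargs keeps the public interface unchanged while DRIFT can still thread its global query through. A standalone sketch of the dispatch (the prompt strings are placeholders, not graphrag's):

SYSTEM_PROMPT = "Context: {context_data}\nStyle: {response_type}"
DRIFT_SYSTEM_PROMPT = SYSTEM_PROMPT + "\nGlobal query: {global_query}"

def build_prompt(context_text: str, response_type: str, **kwargs) -> str:
    # Mirrors LocalSearch.asearch after this commit: the DRIFT-specific
    # argument rides along in **kwargs instead of widening the signature.
    if "drift_query" in kwargs:
        return DRIFT_SYSTEM_PROMPT.format(
            context_data=context_text,
            response_type=response_type,
            global_query=kwargs["drift_query"],
        )
    return SYSTEM_PROMPT.format(
        context_data=context_text, response_type=response_type
    )

print(build_prompt("...", "multiple paragraphs", drift_query="overall question"))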