fix changes

This commit is contained in:
Gaudy Blanco 2025-12-05 19:20:50 -06:00
parent 4877ba675b
commit 5bf8ce6a56
7 changed files with 48 additions and 12 deletions

View File

@ -17,11 +17,13 @@ from graphrag.query.context_builder.conversation_history import (
class ContextBuilderResult:
"""A class to hold the results of the build_context."""
context_chunks: str | list[str]
context_records: dict[str, pd.DataFrame]
llm_calls: int = 0
prompt_tokens: int = 0
output_tokens: int = 0
@dataclass
class LLMParameters:
"""A class to hold LLM call parameters."""
@ -115,4 +117,3 @@ class BasicContextBuilder(ABC):
@abstractmethod
def get_llm_values(self) -> LLMParameters:
"""Get the LLM call values."""

View File

@ -42,6 +42,7 @@ class SearchResult:
prompt_tokens_categories: dict[str, int] | None = None
output_tokens_categories: dict[str, int] | None = None
T = TypeVar(
"T",
GlobalContextBuilder,
@ -89,7 +90,7 @@ class BaseSearch(ABC, Generic[T]):
yield "" # This makes it an async generator.
msg = "Subclasses must implement this method"
raise NotImplementedError(msg)
@abstractmethod
async def format_records(
self,

View File

@ -95,7 +95,7 @@ class BasicSearchContext(BasicContextBuilder):
)
return {context_name.lower(): final_text_df}
def get_llm_values(self) -> LLMParameters:
"""Get the LLM call values."""
return LLMParameters(

View File

@ -52,15 +52,15 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
self.callbacks = callbacks or []
self.response_type = response_type
async def format_records(self, records, column_delimiter = "|") -> str | list[str]:
async def format_records(self, records, column_delimiter="|") -> str | list[str]:
"""Format context records into a string representation."""
if len(records) == 1:
_, context_records_df = next(iter(records.items()))
_, context_records_df = next(iter(records.items()))
if context_records_df is not None:
return context_records_df.to_csv(
index=False, escapechar="\\", sep=column_delimiter
)
if context_records_df is not None:
return context_records_df.to_csv(
index=False, escapechar="\\", sep=column_delimiter
)
return ""
async def search(
@ -94,7 +94,7 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
records=context_records,
column_delimiter=column_delimiter,
)
search_prompt = self.system_prompt.format(
context_data=context_chunks,
response_type=self.response_type,
@ -167,8 +167,9 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
logger.debug("GENERATE ANSWER: %s. QUERY: %s", start_time, query)
context_chunks = await self.format_records(
records=context_records,
column_delimiter=column_delimiter,)
records=context_records,
column_delimiter=column_delimiter,
)
search_prompt = self.system_prompt.format(
context_data=context_chunks, response_type=self.response_type

View File

@ -300,6 +300,17 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
output_tokens_categories=output_tokens,
)
async def format_records(self, records, column_delimiter="|") -> str | list[str]:
"""Format context records into a string representation."""
if len(records) == 1:
_, context_records_df = next(iter(records.items()))
if context_records_df is not None:
return context_records_df.to_csv(
index=False, escapechar="\\", sep=column_delimiter
)
return ""
async def stream_search(
self, query: str, conversation_history: ConversationHistory | None = None
) -> AsyncGenerator[str, None]:

View File

@ -96,6 +96,17 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
self.semaphore = asyncio.Semaphore(concurrent_coroutines)
async def format_records(self, records, column_delimiter="|") -> str | list[str]:
"""Format context records into a string representation."""
if len(records) == 1:
_, context_records_df = next(iter(records.items()))
if context_records_df is not None:
return context_records_df.to_csv(
index=False, escapechar="\\", sep=column_delimiter
)
return ""
async def stream_search(
self,
query: str,

View File

@ -48,6 +48,17 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
self.callbacks = callbacks or []
self.response_type = response_type
async def format_records(self, records, column_delimiter="|") -> str | list[str]:
"""Format context records into a string representation."""
if len(records) == 1:
_, context_records_df = next(iter(records.items()))
if context_records_df is not None:
return context_records_df.to_csv(
index=False, escapechar="\\", sep=column_delimiter
)
return ""
async def search(
self,
query: str,