mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 00:57:23 +08:00
fix changes
This commit is contained in:
parent
4877ba675b
commit
5bf8ce6a56
@ -17,11 +17,13 @@ from graphrag.query.context_builder.conversation_history import (
|
||||
class ContextBuilderResult:
|
||||
"""A class to hold the results of the build_context."""
|
||||
|
||||
context_chunks: str | list[str]
|
||||
context_records: dict[str, pd.DataFrame]
|
||||
llm_calls: int = 0
|
||||
prompt_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMParameters:
|
||||
"""A class to hold LLM call parameters."""
|
||||
@ -115,4 +117,3 @@ class BasicContextBuilder(ABC):
|
||||
@abstractmethod
|
||||
def get_llm_values(self) -> LLMParameters:
|
||||
"""Get the LLM call values."""
|
||||
|
||||
|
||||
@ -42,6 +42,7 @@ class SearchResult:
|
||||
prompt_tokens_categories: dict[str, int] | None = None
|
||||
output_tokens_categories: dict[str, int] | None = None
|
||||
|
||||
|
||||
T = TypeVar(
|
||||
"T",
|
||||
GlobalContextBuilder,
|
||||
@ -89,7 +90,7 @@ class BaseSearch(ABC, Generic[T]):
|
||||
yield "" # This makes it an async generator.
|
||||
msg = "Subclasses must implement this method"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def format_records(
|
||||
self,
|
||||
|
||||
@ -95,7 +95,7 @@ class BasicSearchContext(BasicContextBuilder):
|
||||
)
|
||||
|
||||
return {context_name.lower(): final_text_df}
|
||||
|
||||
|
||||
def get_llm_values(self) -> LLMParameters:
|
||||
"""Get the LLM call values."""
|
||||
return LLMParameters(
|
||||
|
||||
@ -52,15 +52,15 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
|
||||
self.callbacks = callbacks or []
|
||||
self.response_type = response_type
|
||||
|
||||
async def format_records(self, records, column_delimiter = "|") -> str | list[str]:
|
||||
async def format_records(self, records, column_delimiter="|") -> str | list[str]:
|
||||
"""Format context records into a string representation."""
|
||||
if len(records) == 1:
|
||||
_, context_records_df = next(iter(records.items()))
|
||||
_, context_records_df = next(iter(records.items()))
|
||||
|
||||
if context_records_df is not None:
|
||||
return context_records_df.to_csv(
|
||||
index=False, escapechar="\\", sep=column_delimiter
|
||||
)
|
||||
if context_records_df is not None:
|
||||
return context_records_df.to_csv(
|
||||
index=False, escapechar="\\", sep=column_delimiter
|
||||
)
|
||||
return ""
|
||||
|
||||
async def search(
|
||||
@ -94,7 +94,7 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
|
||||
records=context_records,
|
||||
column_delimiter=column_delimiter,
|
||||
)
|
||||
|
||||
|
||||
search_prompt = self.system_prompt.format(
|
||||
context_data=context_chunks,
|
||||
response_type=self.response_type,
|
||||
@ -167,8 +167,9 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
|
||||
logger.debug("GENERATE ANSWER: %s. QUERY: %s", start_time, query)
|
||||
|
||||
context_chunks = await self.format_records(
|
||||
records=context_records,
|
||||
column_delimiter=column_delimiter,)
|
||||
records=context_records,
|
||||
column_delimiter=column_delimiter,
|
||||
)
|
||||
|
||||
search_prompt = self.system_prompt.format(
|
||||
context_data=context_chunks, response_type=self.response_type
|
||||
|
||||
@ -300,6 +300,17 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
output_tokens_categories=output_tokens,
|
||||
)
|
||||
|
||||
async def format_records(self, records, column_delimiter="|") -> str | list[str]:
|
||||
"""Format context records into a string representation."""
|
||||
if len(records) == 1:
|
||||
_, context_records_df = next(iter(records.items()))
|
||||
|
||||
if context_records_df is not None:
|
||||
return context_records_df.to_csv(
|
||||
index=False, escapechar="\\", sep=column_delimiter
|
||||
)
|
||||
return ""
|
||||
|
||||
async def stream_search(
|
||||
self, query: str, conversation_history: ConversationHistory | None = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
|
||||
@ -96,6 +96,17 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
|
||||
|
||||
self.semaphore = asyncio.Semaphore(concurrent_coroutines)
|
||||
|
||||
async def format_records(self, records, column_delimiter="|") -> str | list[str]:
|
||||
"""Format context records into a string representation."""
|
||||
if len(records) == 1:
|
||||
_, context_records_df = next(iter(records.items()))
|
||||
|
||||
if context_records_df is not None:
|
||||
return context_records_df.to_csv(
|
||||
index=False, escapechar="\\", sep=column_delimiter
|
||||
)
|
||||
return ""
|
||||
|
||||
async def stream_search(
|
||||
self,
|
||||
query: str,
|
||||
|
||||
@ -48,6 +48,17 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
|
||||
self.callbacks = callbacks or []
|
||||
self.response_type = response_type
|
||||
|
||||
async def format_records(self, records, column_delimiter="|") -> str | list[str]:
|
||||
"""Format context records into a string representation."""
|
||||
if len(records) == 1:
|
||||
_, context_records_df = next(iter(records.items()))
|
||||
|
||||
if context_records_df is not None:
|
||||
return context_records_df.to_csv(
|
||||
index=False, escapechar="\\", sep=column_delimiter
|
||||
)
|
||||
return ""
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user