From e2a448170a36f8f0ace317f1eff8b89949e08f4a Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Fri, 25 Apr 2025 14:12:18 -0700 Subject: [PATCH] Fix/minor query fixes (#1893) * fixed token count for drift search * basic search fixes * updated basic search prompt * fixed text splitting logic * Lint/format * Semver * Fix text splitting tests --------- Co-authored-by: ha2trinh --- .../patch-20250423234829757628.json | 4 + graphrag/config/defaults.py | 1 + graphrag/config/models/basic_search_config.py | 4 + .../index/text_splitting/text_splitting.py | 4 + .../query/basic_search_system_prompt.py | 29 ++++--- graphrag/query/factory.py | 3 + .../basic_search/basic_context.py | 84 +++++++++++++++---- .../structured_search/basic_search/search.py | 6 ++ .../structured_search/drift_search/search.py | 2 +- .../text_splitting/test_text_splitting.py | 2 - 10 files changed, 108 insertions(+), 31 deletions(-) create mode 100644 .semversioner/next-release/patch-20250423234829757628.json diff --git a/.semversioner/next-release/patch-20250423234829757628.json b/.semversioner/next-release/patch-20250423234829757628.json new file mode 100644 index 00000000..7af4894f --- /dev/null +++ b/.semversioner/next-release/patch-20250423234829757628.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Fixes to basic search." +} diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py index cea7ba8e..627544ec 100644 --- a/graphrag/config/defaults.py +++ b/graphrag/config/defaults.py @@ -42,6 +42,7 @@ class BasicSearchDefaults: prompt: None = None k: int = 10 + max_context_tokens: int = 12_000 chat_model_id: str = DEFAULT_CHAT_MODEL_ID embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID diff --git a/graphrag/config/models/basic_search_config.py b/graphrag/config/models/basic_search_config.py index 8221cd3f..66a1e685 100644 --- a/graphrag/config/models/basic_search_config.py +++ b/graphrag/config/models/basic_search_config.py @@ -27,3 +27,7 @@ class BasicSearchConfig(BaseModel): description="The number of text units to include in search context.", default=graphrag_config_defaults.basic_search.k, ) + max_context_tokens: int = Field( + description="The maximum tokens.", + default=graphrag_config_defaults.basic_search.max_context_tokens, + ) diff --git a/graphrag/index/text_splitting/text_splitting.py b/graphrag/index/text_splitting/text_splitting.py index 16329046..57f2f236 100644 --- a/graphrag/index/text_splitting/text_splitting.py +++ b/graphrag/index/text_splitting/text_splitting.py @@ -152,6 +152,8 @@ def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]: while start_idx < len(input_ids): chunk_text = tokenizer.decode(list(chunk_ids)) result.append(chunk_text) # Append chunked text as string + if cur_idx == len(input_ids): + break start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) chunk_ids = input_ids[start_idx:cur_idx] @@ -186,6 +188,8 @@ def split_multiple_texts_on_tokens( chunk_text = tokenizer.decode([id for _, id in chunk_ids]) doc_indices = list({doc_idx for doc_idx, _ in chunk_ids}) result.append(TextChunk(chunk_text, doc_indices, len(chunk_ids))) + if cur_idx == len(input_ids): + break start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) chunk_ids = input_ids[start_idx:cur_idx] diff --git a/graphrag/prompts/query/basic_search_system_prompt.py b/graphrag/prompts/query/basic_search_system_prompt.py index f98ea058..a20fb6ad 100644 --- a/graphrag/prompts/query/basic_search_system_prompt.py +++ b/graphrag/prompts/query/basic_search_system_prompt.py @@ -11,23 +11,25 @@ You are a helpful assistant responding to questions about data in the tables pro ---Goal--- -Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge. +Generate a response of the target length and format that responds to the user's question, summarizing all relevant information in the input data tables appropriate for the response length and format. -If you don't know the answer, just say so. Do not make anything up. +You should use the data provided in the data tables below as the primary context for generating the response. + +If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up. Points supported by data should list their data references as follows: -"This is an example sentence supported by multiple text references [Data: Sources (record ids)]." +"This is an example sentence supported by multiple data references [Data: Sources (record ids)]." Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. For example: -"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16)]." +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Sources (1, 3)]" -where 15 and 16 represent the id (not the index) of the relevant data record. +where 1, 2, 3, 7, 34, 46, and 64 represent the source id taken from the "source_id" column in the provided tables. -Do not include information where the supporting text for it is not provided. +Do not include information where the supporting evidence for it is not provided. ---Target response length and format--- @@ -42,23 +44,26 @@ Do not include information where the supporting text for it is not provided. ---Goal--- -Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge. +Generate a response of the target length and format that responds to the user's question, summarizing all relevant information in the input data appropriate for the response length and format. -If you don't know the answer, just say so. Do not make anything up. +You should use the data provided in the data tables below as the primary context for generating the response. + +If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up. Points supported by data should list their data references as follows: -"This is an example sentence supported by multiple text references [Data: Sources (record ids)]." +"This is an example sentence supported by multiple data references [Data: Sources (record ids)]." Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. For example: -"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16)]." +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Sources (1, 3)]" -where 15 and 16 represent the id (not the index) of the relevant data record. +where 1, 2, 3, 7, 34, 46, and 64 represent the source id taken from the "source_id" column in the provided tables. + +Do not include information where the supporting evidence for it is not provided. -Do not include information where the supporting text for it is not provided. ---Target response length and format--- diff --git a/graphrag/query/factory.py b/graphrag/query/factory.py index 76fa1f43..907c83ca 100644 --- a/graphrag/query/factory.py +++ b/graphrag/query/factory.py @@ -275,6 +275,7 @@ def get_basic_search_engine( text_unit_embeddings: BaseVectorStore, config: GraphRagConfig, system_prompt: str | None = None, + response_type: str = "multiple paragraphs", callbacks: list[QueryCallbacks] | None = None, ) -> BasicSearch: """Create a basic search engine based on data + configuration.""" @@ -312,6 +313,7 @@ def get_basic_search_engine( return BasicSearch( model=chat_model, system_prompt=system_prompt, + response_type=response_type, context_builder=BasicSearchContext( text_embedder=embedding_model, text_unit_embeddings=text_unit_embeddings, @@ -323,6 +325,7 @@ def get_basic_search_engine( context_builder_params={ "embedding_vectorstore_key": "id", "k": bs_config.k, + "max_context_tokens": bs_config.max_context_tokens, }, callbacks=callbacks, ) diff --git a/graphrag/query/structured_search/basic_search/basic_context.py b/graphrag/query/structured_search/basic_search/basic_context.py index 0dd139ea..d08bdc2e 100644 --- a/graphrag/query/structured_search/basic_search/basic_context.py +++ b/graphrag/query/structured_search/basic_search/basic_context.py @@ -3,6 +3,9 @@ """Basic Context Builder implementation.""" +import logging +from typing import cast + import pandas as pd import tiktoken @@ -13,8 +16,11 @@ from graphrag.query.context_builder.builders import ( ContextBuilderResult, ) from graphrag.query.context_builder.conversation_history import ConversationHistory +from graphrag.query.llm.text_utils import num_tokens from graphrag.vector_stores.base import BaseVectorStore +log = logging.getLogger(__name__) + class BasicSearchContext(BasicContextBuilder): """Class representing the Basic Search Context Builder.""" @@ -32,30 +38,76 @@ class BasicSearchContext(BasicContextBuilder): self.text_units = text_units self.text_unit_embeddings = text_unit_embeddings self.embedding_vectorstore_key = embedding_vectorstore_key + self.text_id_map = self._map_ids() def build_context( self, query: str, conversation_history: ConversationHistory | None = None, + k: int = 10, + max_context_tokens: int = 12_000, + context_name: str = "Sources", + column_delimiter: str = "|", + text_id_col: str = "source_id", + text_col: str = "text", **kwargs, ) -> ContextBuilderResult: - """Build the context for the local search mode.""" - search_results = self.text_unit_embeddings.similarity_search_by_text( - text=query, - text_embedder=lambda t: self.text_embedder.embed(t), - k=kwargs.get("k", 10), - ) - # we don't have a friendly id on text_units, so just copy the index - sources = [ - {"id": str(search_results.index(r)), "text": r.document.text} - for r in search_results - ] - # make a delimited table for the context; this imitates graphrag context building - table = ["id|text"] + [f"{s['id']}|{s['text']}" for s in sources] + """Build the context for the basic search mode.""" + if query != "": + related_texts = self.text_unit_embeddings.similarity_search_by_text( + text=query, + text_embedder=lambda t: self.text_embedder.embed(t), + k=k, + ) + related_text_list = [ + { + text_id_col: self.text_id_map[f"{chunk.document.id}"], + text_col: chunk.document.text, + } + for chunk in related_texts + ] + related_text_df = pd.DataFrame(related_text_list) + else: + related_text_df = pd.DataFrame({ + text_id_col: [], + text_col: [], + }) - columns = pd.Index(["id", "text"]) + # add these related text chunks into context until we fill up the context window + current_tokens = 0 + text_ids = [] + current_tokens = num_tokens( + text_id_col + column_delimiter + text_col + "\n", self.token_encoder + ) + for i, row in related_text_df.iterrows(): + text = row[text_id_col] + column_delimiter + row[text_col] + "\n" + tokens = num_tokens(text, self.token_encoder) + if current_tokens + tokens > max_context_tokens: + msg = f"Reached token limit: {current_tokens + tokens}. Reverting to previous context state" + log.info(msg) + break + + current_tokens += tokens + text_ids.append(i) + final_text_df = cast( + "pd.DataFrame", + related_text_df[related_text_df.index.isin(text_ids)].reset_index( + drop=True + ), + ) + final_text = final_text_df.to_csv( + index=False, escapechar="\\", sep=column_delimiter + ) return ContextBuilderResult( - context_chunks="\n\n".join(table), - context_records={"sources": pd.DataFrame(sources, columns=columns)}, + context_chunks=final_text, + context_records={context_name: final_text_df}, ) + + def _map_ids(self) -> dict[str, str]: + """Map id to short id in the text units.""" + id_map = {} + text_units = self.text_units or [] + for unit in text_units: + id_map[unit.id] = unit.short_id + return id_map diff --git a/graphrag/query/structured_search/basic_search/search.py b/graphrag/query/structured_search/basic_search/search.py index e2fb29c0..a5dca578 100644 --- a/graphrag/query/structured_search/basic_search/search.py +++ b/graphrag/query/structured_search/basic_search/search.py @@ -108,6 +108,9 @@ class BasicSearch(BaseSearch[BasicContextBuilder]): llm_calls=1, prompt_tokens=num_tokens(search_prompt, self.token_encoder), output_tokens=sum(output_tokens.values()), + llm_calls_categories=llm_calls, + prompt_tokens_categories=prompt_tokens, + output_tokens_categories=output_tokens, ) except Exception: @@ -120,6 +123,9 @@ class BasicSearch(BaseSearch[BasicContextBuilder]): llm_calls=1, prompt_tokens=num_tokens(search_prompt, self.token_encoder), output_tokens=0, + llm_calls_categories=llm_calls, + prompt_tokens_categories=prompt_tokens, + output_tokens_categories=output_tokens, ) async def stream_search( diff --git a/graphrag/query/structured_search/drift_search/search.py b/graphrag/query/structured_search/drift_search/search.py index 2099e359..10d62343 100644 --- a/graphrag/query/structured_search/drift_search/search.py +++ b/graphrag/query/structured_search/drift_search/search.py @@ -213,7 +213,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]): primer_context, token_ct = await self.context_builder.build_context(query) llm_calls["build_context"] = token_ct["llm_calls"] prompt_tokens["build_context"] = token_ct["prompt_tokens"] - output_tokens["build_context"] = token_ct["prompt_tokens"] + output_tokens["build_context"] = token_ct["output_tokens"] primer_response = await self.primer.search( query=query, top_k_reports=primer_context diff --git a/tests/unit/indexing/text_splitting/test_text_splitting.py b/tests/unit/indexing/text_splitting/test_text_splitting.py index da87d473..10a5a063 100644 --- a/tests/unit/indexing/text_splitting/test_text_splitting.py +++ b/tests/unit/indexing/text_splitting/test_text_splitting.py @@ -136,7 +136,6 @@ def test_split_single_text_on_tokens(): " by this t", "his test o", "est only.", - "nly.", ] result = split_single_text_on_tokens(text=text, tokenizer=tokenizer) @@ -197,7 +196,6 @@ def test_split_single_text_on_tokens_no_overlap(): " this test", " test only", " only.", - ".", ] result = split_single_text_on_tokens(text=text, tokenizer=tokenizer)