From 22df2f80d02b5f514fd70a6f96728c0a236190ad Mon Sep 17 00:00:00 2001 From: Alonso Guevara Date: Tue, 27 Aug 2024 16:15:16 -0600 Subject: [PATCH] Fix/text unit code cleanup (#1040) * Optimized _build_text_unit_context function for improved time and space complexity Refactored the _build_text_unit_context function to enhance its performance and efficiency. Key optimizations include: 1. Set for Text Unit IDs: Replaced list-based membership checks with a set (text_unit_ids_set) to achieve constant-time complexity for membership checks, reducing overall time complexity. 2. Direct Attribute Removal: Utilized pop with a default value (None) to directly remove attributes entity_order and num_relationships from text units, minimizing overhead and avoiding potential KeyError. 3. Default Dictionary for Entity Orders: Implemented defaultdict for managing entity orders, simplifying the ranking process and improving readability. These improvements result in a more efficient function with better performance, especially when handling large datasets or numerous selected entities. The refactoring ensures that the core functionality remains unchanged while enhancing both time and space complexity. * Format * Ruff fixes * semver --------- Co-authored-by: arjun-234 Co-authored-by: Arjun D. <103405661+arjun-234@users.noreply.github.com> --- .../patch-20240827212041426794.json | 4 ++ .../local_search/mixed_context.py | 55 ++++++++----------- 2 files changed, 28 insertions(+), 31 deletions(-) create mode 100644 .semversioner/next-release/patch-20240827212041426794.json diff --git a/.semversioner/next-release/patch-20240827212041426794.json b/.semversioner/next-release/patch-20240827212041426794.json new file mode 100644 index 00000000..f3646013 --- /dev/null +++ b/.semversioner/next-release/patch-20240827212041426794.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Refactor text unit build at local search" +} diff --git a/graphrag/query/structured_search/local_search/mixed_context.py b/graphrag/query/structured_search/local_search/mixed_context.py index 117101e9..21eac2e8 100644 --- a/graphrag/query/structured_search/local_search/mixed_context.py +++ b/graphrag/query/structured_search/local_search/mixed_context.py @@ -309,42 +309,36 @@ class LocalSearchMixedContext(LocalContextBuilder): context_name: str = "Sources", ) -> tuple[str, dict[str, pd.DataFrame]]: """Rank matching text units and add them to the context window until it hits the max_tokens limit.""" - if len(selected_entities) == 0 or len(self.text_units) == 0: + if not selected_entities or not self.text_units: return ("", {context_name.lower(): pd.DataFrame()}) - selected_text_units = list[TextUnit]() - # for each matching text unit, rank first by the order of the entities that match it, then by the number of matching relationships - # that the text unit has with the matching entities - for index, entity in enumerate(selected_entities): - if entity.text_unit_ids: - for text_id in entity.text_unit_ids: - if ( - text_id not in [unit.id for unit in selected_text_units] - and text_id in self.text_units - ): - selected_unit = self.text_units[text_id] - num_relationships = count_relationships( - selected_unit, entity, self.relationships - ) - if selected_unit.attributes is None: - selected_unit.attributes = {} - selected_unit.attributes["entity_order"] = index - selected_unit.attributes["num_relationships"] = ( - num_relationships - ) - selected_text_units.append(selected_unit) + selected_text_units = [] + text_unit_ids_set = set() + + for index, entity in enumerate(selected_entities): + for text_id in entity.text_unit_ids or []: + if text_id not in text_unit_ids_set and text_id in self.text_units: + text_unit_ids_set.add(text_id) + selected_unit = self.text_units[text_id] + num_relationships = count_relationships( + selected_unit, entity, self.relationships + ) + if selected_unit.attributes is None: + selected_unit.attributes = {} + selected_unit.attributes["entity_order"] = index + selected_unit.attributes["num_relationships"] = num_relationships + selected_text_units.append(selected_unit) - # sort selected text units by ascending order of entity order and descending order of number of relationships selected_text_units.sort( key=lambda x: ( - x.attributes["entity_order"], # type: ignore - -x.attributes["num_relationships"], # type: ignore + x.attributes["entity_order"], + -x.attributes["num_relationships"], ) ) for unit in selected_text_units: - del unit.attributes["entity_order"] # type: ignore - del unit.attributes["num_relationships"] # type: ignore + unit.attributes.pop("entity_order", None) + unit.attributes.pop("num_relationships", None) context_text, context_data = build_text_unit_context( text_units=selected_text_units, @@ -362,8 +356,8 @@ class LocalSearchMixedContext(LocalContextBuilder): ) context_key = context_name.lower() if context_key not in context_data: + candidate_context_data["in_context"] = False context_data[context_key] = candidate_context_data - context_data[context_key]["in_context"] = False else: if ( "id" in candidate_context_data.columns @@ -371,12 +365,11 @@ class LocalSearchMixedContext(LocalContextBuilder): ): candidate_context_data["in_context"] = candidate_context_data[ "id" - ].isin( # cspell:disable-line - context_data[context_key]["id"] - ) + ].isin(context_data[context_key]["id"]) context_data[context_key] = candidate_context_data else: context_data[context_key]["in_context"] = True + return (str(context_text), context_data) def _build_local_context(