diff --git a/.semversioner/next-release/patch-20240827212041426794.json b/.semversioner/next-release/patch-20240827212041426794.json new file mode 100644 index 00000000..f3646013 --- /dev/null +++ b/.semversioner/next-release/patch-20240827212041426794.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Refactor text unit build at local search" +} diff --git a/graphrag/query/structured_search/local_search/mixed_context.py b/graphrag/query/structured_search/local_search/mixed_context.py index 117101e9..21eac2e8 100644 --- a/graphrag/query/structured_search/local_search/mixed_context.py +++ b/graphrag/query/structured_search/local_search/mixed_context.py @@ -309,42 +309,36 @@ class LocalSearchMixedContext(LocalContextBuilder): context_name: str = "Sources", ) -> tuple[str, dict[str, pd.DataFrame]]: """Rank matching text units and add them to the context window until it hits the max_tokens limit.""" - if len(selected_entities) == 0 or len(self.text_units) == 0: + if not selected_entities or not self.text_units: return ("", {context_name.lower(): pd.DataFrame()}) - selected_text_units = list[TextUnit]() - # for each matching text unit, rank first by the order of the entities that match it, then by the number of matching relationships - # that the text unit has with the matching entities - for index, entity in enumerate(selected_entities): - if entity.text_unit_ids: - for text_id in entity.text_unit_ids: - if ( - text_id not in [unit.id for unit in selected_text_units] - and text_id in self.text_units - ): - selected_unit = self.text_units[text_id] - num_relationships = count_relationships( - selected_unit, entity, self.relationships - ) - if selected_unit.attributes is None: - selected_unit.attributes = {} - selected_unit.attributes["entity_order"] = index - selected_unit.attributes["num_relationships"] = ( - num_relationships - ) - selected_text_units.append(selected_unit) + selected_text_units = [] + text_unit_ids_set = set() + + for index, entity in enumerate(selected_entities): + for text_id in entity.text_unit_ids or []: + if text_id not in text_unit_ids_set and text_id in self.text_units: + text_unit_ids_set.add(text_id) + selected_unit = self.text_units[text_id] + num_relationships = count_relationships( + selected_unit, entity, self.relationships + ) + if selected_unit.attributes is None: + selected_unit.attributes = {} + selected_unit.attributes["entity_order"] = index + selected_unit.attributes["num_relationships"] = num_relationships + selected_text_units.append(selected_unit) - # sort selected text units by ascending order of entity order and descending order of number of relationships selected_text_units.sort( key=lambda x: ( - x.attributes["entity_order"], # type: ignore - -x.attributes["num_relationships"], # type: ignore + x.attributes["entity_order"], + -x.attributes["num_relationships"], ) ) for unit in selected_text_units: - del unit.attributes["entity_order"] # type: ignore - del unit.attributes["num_relationships"] # type: ignore + unit.attributes.pop("entity_order", None) + unit.attributes.pop("num_relationships", None) context_text, context_data = build_text_unit_context( text_units=selected_text_units, @@ -362,8 +356,8 @@ class LocalSearchMixedContext(LocalContextBuilder): ) context_key = context_name.lower() if context_key not in context_data: + candidate_context_data["in_context"] = False context_data[context_key] = candidate_context_data - context_data[context_key]["in_context"] = False else: if ( "id" in candidate_context_data.columns @@ -371,12 +365,11 @@ class LocalSearchMixedContext(LocalContextBuilder): ): candidate_context_data["in_context"] = candidate_context_data[ "id" - ].isin( # cspell:disable-line - context_data[context_key]["id"] - ) + ].isin(context_data[context_key]["id"]) context_data[context_key] = candidate_context_data else: context_data[context_key]["in_context"] = True + return (str(context_text), context_data) def _build_local_context(