diff --git a/.semversioner/next-release/patch-20240930230641593846.json b/.semversioner/next-release/patch-20240930230641593846.json
new file mode 100644
index 00000000..90e7a6eb
--- /dev/null
+++ b/.semversioner/next-release/patch-20240930230641593846.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Collapse entity extraction."
+}
diff --git a/graphrag/index/verbs/entities/extraction/entity_extract.py b/graphrag/index/verbs/entities/extraction/entity_extract.py
index 4e961f67..2487c900 100644
--- a/graphrag/index/verbs/entities/extraction/entity_extract.py
+++ b/graphrag/index/verbs/entities/extraction/entity_extract.py
@@ -54,6 +54,38 @@ async def entity_extract(
     entity_types=DEFAULT_ENTITY_TYPES,
     **kwargs,
 ) -> TableContainer:
+    """Extract entities from a piece of text."""
+    source = cast(pd.DataFrame, input.get_input())
+    output = await entity_extract_df(
+        source,
+        cache,
+        callbacks,
+        column,
+        id_column,
+        to,
+        strategy,
+        graph_to,
+        async_mode,
+        entity_types,
+        **kwargs,
+    )
+
+    return TableContainer(table=output)
+
+
+async def entity_extract_df(
+    input: pd.DataFrame,
+    cache: PipelineCache,
+    callbacks: VerbCallbacks,
+    column: str,
+    id_column: str,
+    to: str,
+    strategy: dict[str, Any] | None,
+    graph_to: str | None = None,
+    async_mode: AsyncType = AsyncType.AsyncIO,
+    entity_types=DEFAULT_ENTITY_TYPES,
+    **kwargs,
+) -> pd.DataFrame:
     """
     Extract entities from a piece of text.
 
@@ -135,7 +167,6 @@ async def entity_extract(
     log.debug("entity_extract strategy=%s", strategy)
     if entity_types is None:
         entity_types = DEFAULT_ENTITY_TYPES
-    output = cast(pd.DataFrame, input.get_input())
     strategy = strategy or {}
     strategy_exec = _load_strategy(
         strategy.get("type", ExtractEntityStrategyType.graph_intelligence)
@@ -159,7 +190,7 @@ async def entity_extract(
         return [result.entities, result.graphml_graph]
 
     results = await derive_from_rows(
-        output,
+        input,
         run_strategy,
         callbacks,
         scheduling_type=async_mode,
@@ -176,11 +207,11 @@ async def entity_extract(
             to_result.append(None)
             graph_to_result.append(None)
 
-    output[to] = to_result
+    input[to] = to_result
     if graph_to is not None:
-        output[graph_to] = graph_to_result
+        input[graph_to] = graph_to_result
 
-    return TableContainer(table=output.reset_index(drop=True))
+    return input.reset_index(drop=True)
 
 
 def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStrategy:
diff --git a/graphrag/index/verbs/graph/merge/merge_graphs.py b/graphrag/index/verbs/graph/merge/merge_graphs.py
index a551e4ce..d5a4c9f5 100644
--- a/graphrag/index/verbs/graph/merge/merge_graphs.py
+++ b/graphrag/index/verbs/graph/merge/merge_graphs.py
@@ -34,6 +34,21 @@ def merge_graphs(
     edges: dict[str, Any] = DEFAULT_EDGE_OPERATIONS,
     **_kwargs,
 ) -> TableContainer:
+    """Merge multiple graphs together. The graphs are expected to be in graphml format. The verb outputs a new column containing the merged graph."""
+    input_df = cast(pd.DataFrame, input.get_input())
+    output = merge_graphs_df(input_df, callbacks, column, to, nodes, edges)
+
+    return TableContainer(table=output)
+
+
+def merge_graphs_df(
+    input: pd.DataFrame,
+    callbacks: VerbCallbacks,
+    column: str,
+    to: str,
+    nodes: dict[str, Any] = DEFAULT_NODE_OPERATIONS,
+    edges: dict[str, Any] = DEFAULT_EDGE_OPERATIONS,
+) -> pd.DataFrame:
     """
     Merge multiple graphs together. The graphs are expected to be in graphml format. The verb outputs a new column containing the merged graph.
 
@@ -82,7 +97,6 @@ def merge_graphs(
     - __average__: This operation takes the mean of the attribute with the last value seen.
     - __multiply__: This operation multiplies the attribute with the last value seen.
     """
-    input_df = input.get_input()
     output = pd.DataFrame()
 
     node_ops = {
@@ -95,15 +109,15 @@ def merge_graphs(
     }
     mega_graph = nx.Graph()
-    num_total = len(input_df)
-    for graphml in progress_iterable(input_df[column], callbacks.progress, num_total):
+    num_total = len(input)
+    for graphml in progress_iterable(input[column], callbacks.progress, num_total):
         graph = load_graph(cast(str | nx.Graph, graphml))
 
         merge_nodes(mega_graph, graph, node_ops)
         merge_edges(mega_graph, graph, edge_ops)
 
     output[to] = ["\n".join(nx.generate_graphml(mega_graph))]
 
-    return TableContainer(table=output)
+    return output
 
 
 def merge_nodes(
diff --git a/graphrag/index/verbs/snapshot.py b/graphrag/index/verbs/snapshot.py
index a90fc283..3789a4c9 100644
--- a/graphrag/index/verbs/snapshot.py
+++ b/graphrag/index/verbs/snapshot.py
@@ -3,6 +3,9 @@
 
 """A module containing snapshot method definition."""
 
+from typing import cast
+
+import pandas as pd
 from datashaper import TableContainer, VerbInput, verb
 
 from graphrag.index.storage import PipelineStorage
@@ -17,14 +20,24 @@ async def snapshot(
     **_kwargs: dict,
 ) -> TableContainer:
     """Take a entire snapshot of the tabular data."""
-    data = input.get_input()
+    data = cast(pd.DataFrame, input.get_input())
 
-    for fmt in formats:
-        if fmt == "parquet":
-            await storage.set(name + ".parquet", data.to_parquet())
-        elif fmt == "json":
-            await storage.set(
-                name + ".json", data.to_json(orient="records", lines=True)
-            )
+    await snapshot_df(data, name, formats, storage)
 
     return TableContainer(table=data)
+
+
+async def snapshot_df(
+    input: pd.DataFrame,
+    name: str,
+    formats: list[str],
+    storage: PipelineStorage,
+):
+    """Take an entire snapshot of the tabular data."""
+    for fmt in formats:
+        if fmt == "parquet":
+            await storage.set(name + ".parquet", input.to_parquet())
+        elif fmt == "json":
+            await storage.set(
+                name + ".json", input.to_json(orient="records", lines=True)
+            )
diff --git a/graphrag/index/workflows/v1/create_base_extracted_entities.py b/graphrag/index/workflows/v1/create_base_extracted_entities.py
index 30d608e9..7d0ea603 100644
--- a/graphrag/index/workflows/v1/create_base_extracted_entities.py
+++ b/graphrag/index/workflows/v1/create_base_extracted_entities.py
@@ -20,76 +20,65 @@ def build_steps(
     * `workflow:create_base_text_units`
     """
     entity_extraction_config = config.get("entity_extract", {})
+
+    column = entity_extraction_config.get("text_column", "chunk")
+    id_column = entity_extraction_config.get("id_column", "chunk_id")
+    async_mode = entity_extraction_config.get("async_mode", AsyncType.AsyncIO)
+    strategy = entity_extraction_config.get("strategy")
+    num_threads = entity_extraction_config.get("num_threads", 4)
+    entity_types = entity_extraction_config.get("entity_types")
+
+    graph_merge_operations_config = config.get(
+        "graph_merge_operations",
+        {
+            "nodes": {
+                "source_id": {
+                    "operation": "concat",
+                    "delimiter": ", ",
+                    "distinct": True,
+                },
+                "description": ({
+                    "operation": "concat",
+                    "separator": "\n",
+                    "distinct": False,
+                }),
+            },
+            "edges": {
+                "source_id": {
+                    "operation": "concat",
+                    "delimiter": ", ",
+                    "distinct": True,
+                },
+                "description": ({
+                    "operation": "concat",
+                    "separator": "\n",
+                    "distinct": False,
+                }),
+                "weight": "sum",
+            },
+        },
+    )
+    nodes = graph_merge_operations_config.get("nodes")
+    edges = graph_merge_operations_config.get("edges")
+
     graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False
     raw_entity_snapshot_enabled = config.get("raw_entity_snapshot", False) or False
 
     return [
         {
-            "verb": "entity_extract",
+            "verb": "create_base_extracted_entities",
             "args": {
-                **entity_extraction_config,
-                "column": entity_extraction_config.get("text_column", "chunk"),
-                "id_column": entity_extraction_config.get("id_column", "chunk_id"),
-                "async_mode": entity_extraction_config.get(
-                    "async_mode", AsyncType.AsyncIO
-                ),
-                "to": "entities",
-                "graph_to": "entity_graph",
+                "column": column,
+                "id_column": id_column,
+                "async_mode": async_mode,
+                "strategy": strategy,
+                "num_threads": num_threads,
+                "entity_types": entity_types,
+                "nodes": nodes,
+                "edges": edges,
+                "raw_entity_snapshot_enabled": raw_entity_snapshot_enabled,
+                "graphml_snapshot_enabled": graphml_snapshot_enabled,
             },
             "input": {"source": "workflow:create_base_text_units"},
         },
-        {
-            "verb": "snapshot",
-            "enabled": raw_entity_snapshot_enabled,
-            "args": {
-                "name": "raw_extracted_entities",
-                "formats": ["json"],
-            },
-        },
-        {
-            "verb": "merge_graphs",
-            "args": {
-                "column": "entity_graph",
-                "to": "entity_graph",
-                **config.get(
-                    "graph_merge_operations",
-                    {
-                        "nodes": {
-                            "source_id": {
-                                "operation": "concat",
-                                "delimiter": ", ",
-                                "distinct": True,
-                            },
-                            "description": ({
-                                "operation": "concat",
-                                "separator": "\n",
-                                "distinct": False,
-                            }),
-                        },
-                        "edges": {
-                            "source_id": {
-                                "operation": "concat",
-                                "delimiter": ", ",
-                                "distinct": True,
-                            },
-                            "description": ({
-                                "operation": "concat",
-                                "separator": "\n",
-                                "distinct": False,
-                            }),
-                            "weight": "sum",
-                        },
-                    },
-                ),
-            },
-        },
-        {
-            "verb": "snapshot_rows",
-            "enabled": graphml_snapshot_enabled,
-            "args": {
-                "base_name": "merged_graph",
-                "column": "entity_graph",
-                "formats": [{"format": "text", "extension": "graphml"}],
-            },
-        },
     ]
diff --git a/graphrag/index/workflows/v1/subflows/__init__.py b/graphrag/index/workflows/v1/subflows/__init__.py
index 95c4774d..857d20c4 100644
--- a/graphrag/index/workflows/v1/subflows/__init__.py
+++ b/graphrag/index/workflows/v1/subflows/__init__.py
@@ -5,6 +5,7 @@
 
 from .create_base_documents import create_base_documents
 from .create_base_entity_graph import create_base_entity_graph
+from .create_base_extracted_entities import create_base_extracted_entities
 from .create_base_text_units import create_base_text_units
 from .create_final_communities import create_final_communities
 from .create_final_community_reports import create_final_community_reports
@@ -21,6 +22,7 @@ from .create_summarized_entities import create_summarized_entities
 __all__ = [
     "create_base_documents",
     "create_base_entity_graph",
+    "create_base_extracted_entities",
     "create_base_text_units",
     "create_final_communities",
     "create_final_community_reports",
diff --git a/graphrag/index/workflows/v1/subflows/create_base_extracted_entities.py b/graphrag/index/workflows/v1/subflows/create_base_extracted_entities.py
new file mode 100644
index 00000000..464b5a11
--- /dev/null
+++ b/graphrag/index/workflows/v1/subflows/create_base_extracted_entities.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""All the steps to extract and format base entities."""
+
+from typing import Any, cast
+
+import pandas as pd
+from datashaper import (
+    AsyncType,
+    Table,
+    VerbCallbacks,
+    VerbInput,
+    verb,
+)
+from datashaper.table_store.types import VerbResult, create_verb_result
+
+from graphrag.index.cache import PipelineCache
+from graphrag.index.storage import PipelineStorage
+from graphrag.index.verbs.entities.extraction.entity_extract import entity_extract_df
+from graphrag.index.verbs.graph.merge.merge_graphs import merge_graphs_df
+from graphrag.index.verbs.snapshot import snapshot_df
+from graphrag.index.verbs.snapshot_rows import snapshot_rows_df
+
+
+@verb(name="create_base_extracted_entities", treats_input_tables_as_immutable=True)
+async def create_base_extracted_entities(
+    input: VerbInput,
+    cache: PipelineCache,
+    callbacks: VerbCallbacks,
+    storage: PipelineStorage,
+    column: str,
+    id_column: str,
+    nodes: dict[str, Any],
+    edges: dict[str, Any],
+    strategy: dict[str, Any] | None,
+    async_mode: AsyncType = AsyncType.AsyncIO,
+    entity_types: list[str] | None = None,
+    graphml_snapshot_enabled: bool = False,
+    raw_entity_snapshot_enabled: bool = False,
+    **kwargs: dict,
+) -> VerbResult:
+    """All the steps to extract and format base entities."""
+    source = cast(pd.DataFrame, input.get_input())
+
+    entity_graph = await entity_extract_df(
+        source,
+        cache,
+        callbacks,
+        column=column,
+        id_column=id_column,
+        strategy=strategy,
+        async_mode=async_mode,
+        entity_types=entity_types,
+        to="entities",
+        graph_to="entity_graph",
+        **kwargs,
+    )
+
+    if raw_entity_snapshot_enabled:
+        await snapshot_df(
+            entity_graph,
+            name="raw_extracted_entities",
+            storage=storage,
+            formats=["json"],
+        )
+
+    merged_graph = merge_graphs_df(
+        entity_graph,
+        callbacks,
+        column="entity_graph",
+        to="entity_graph",
+        nodes=nodes,
+        edges=edges,
+    )
+
+    if graphml_snapshot_enabled:
+        await snapshot_rows_df(
+            merged_graph,
+            base_name="merged_graph",
+            column="entity_graph",
+            storage=storage,
+            formats=[{"format": "text", "extension": "graphml"}],
+        )
+
+    return create_verb_result(cast(Table, merged_graph))
diff --git a/tests/fixtures/min-csv/config.json b/tests/fixtures/min-csv/config.json
index 86ed82c6..73a46f0d 100644
--- a/tests/fixtures/min-csv/config.json
+++ b/tests/fixtures/min-csv/config.json
@@ -15,7 +15,7 @@
             1,
             2000
         ],
-        "subworkflows": 2,
+        "subworkflows": 1,
         "max_runtime": 300
     },
     "create_summarized_entities": {
diff --git a/tests/fixtures/text/config.json b/tests/fixtures/text/config.json
index baf74511..473e7d8b 100644
--- a/tests/fixtures/text/config.json
+++ b/tests/fixtures/text/config.json
@@ -15,7 +15,7 @@
             1,
             2000
         ],
-        "subworkflows": 2,
+        "subworkflows": 1,
         "max_runtime": 300
     },
     "create_final_covariates": {
diff --git a/tests/verbs/test_create_base_extracted_entities.py b/tests/verbs/test_create_base_extracted_entities.py
new file mode 100644
index 00000000..6483471b
--- /dev/null
+++ b/tests/verbs/test_create_base_extracted_entities.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+import networkx as nx
+
+from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
+from graphrag.index.workflows.v1.create_base_extracted_entities import (
+    build_steps,
+    workflow_name,
+)
+
+from .util import (
+    get_config_for_workflow,
+    get_workflow_output,
+    load_expected,
+    load_input_tables,
+)
+
+
+async def test_create_base_extracted_entities():
+    input_tables = load_input_tables(["workflow:create_base_text_units"])
+    expected = load_expected(workflow_name)
+
+    config = get_config_for_workflow(workflow_name)
+
+    del config["entity_extract"]["strategy"]["llm"]
+
+    steps = build_steps(config)
+
+    actual = await get_workflow_output(
+        input_tables,
+        {
+            "steps": steps,
+        },
+    )
+
+    # let's parse a sample of the raw graphml
+    actual_graphml_0 = actual["entity_graph"][:1][0]
+    actual_graph_0 = nx.parse_graphml(actual_graphml_0)
+
+    assert actual_graph_0.number_of_nodes() == 3
+    assert actual_graph_0.number_of_edges() == 2
+
+    assert actual.columns == expected.columns
+
+
+async def test_create_base_extracted_entities_with_snapshots():
+    input_tables = load_input_tables(["workflow:create_base_text_units"])
+    expected = load_expected(workflow_name)
+
+    storage = MemoryPipelineStorage()
+
+    config = get_config_for_workflow(workflow_name)
+
+    del config["entity_extract"]["strategy"]["llm"]
+    config["raw_entity_snapshot"] = True
+    config["graphml_snapshot"] = True
+
+    steps = build_steps(config)
+
+    actual = await get_workflow_output(
+        input_tables,
+        {
+            "steps": steps,
+        },
+        storage=storage,
+    )
+
+    print(storage.keys())
+
+    assert actual.columns == expected.columns
+
+    assert storage.keys() == ["raw_extracted_entities.json", "merged_graph.graphml"]
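
For context, a minimal usage sketch of the collapsed flow, driving the new DataFrame-level helpers directly rather than through a datashaper workflow. The chunk table is hypothetical, and `InMemoryCache` / `NoopVerbCallbacks` are assumptions about the graphrag and datashaper public APIs; the default graph_intelligence strategy also needs a configured LLM before this runs end to end.

# Usage sketch (not part of the patch): compose entity_extract_df and
# merge_graphs_df directly on DataFrames, mirroring what the new
# create_base_extracted_entities verb does internally.
import asyncio

import pandas as pd
from datashaper import NoopVerbCallbacks  # assumed datashaper export

from graphrag.index.cache import InMemoryCache  # assumed graphrag export
from graphrag.index.verbs.entities.extraction.entity_extract import entity_extract_df
from graphrag.index.verbs.graph.merge.merge_graphs import merge_graphs_df


async def main() -> None:
    # One text chunk per row, shaped like the create_base_text_units output.
    chunks = pd.DataFrame({
        "chunk": ["Alice works with Bob at Contoso."],
        "chunk_id": ["chunk-0"],
    })

    # Adds the "entities" and "entity_graph" columns to the input frame.
    extracted = await entity_extract_df(
        chunks,
        InMemoryCache(),
        NoopVerbCallbacks(),
        column="chunk",
        id_column="chunk_id",
        to="entities",
        strategy=None,  # falls back to graph_intelligence, which needs LLM config
        graph_to="entity_graph",
    )

    # Fold the per-row graphs into a single graphml string, as the verb does.
    merged = merge_graphs_df(
        extracted,
        NoopVerbCallbacks(),
        column="entity_graph",
        to="entity_graph",
        nodes={"source_id": {"operation": "concat", "distinct": True}},
        edges={"weight": "sum"},
    )
    print(merged["entity_graph"][0][:200])


asyncio.run(main())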