Collapse create base extracted entities (#1235)

* Set up base assertions

* Replace entity_extract

* Finish collapsing workflow

* Semver

* Update smoke tests
Nathan Evans 2024-09-30 17:32:56 -07:00 committed by GitHub
parent 630679f8e3
commit 9070ea5c3c
10 changed files with 294 additions and 82 deletions

View File

@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Collapse entity extraction."
+}
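For context, semversioner change files like this one are collected at release time to pick the next version. A minimal sketch of the bump rule a "patch" entry implies (illustrative only, not semversioner's implementation):

```python
# Illustrative only: the version-bump rule implied by a "patch" change file.
def bump(version: str, change_type: str) -> str:
    major, minor, patch = (int(part) for part in version.split("."))
    if change_type == "major":
        return f"{major + 1}.0.0"
    if change_type == "minor":
        return f"{major}.{minor + 1}.0"
    return f"{major}.{minor}.{patch + 1}"


assert bump("0.3.2", "patch") == "0.3.3"
```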

View File

@@ -54,6 +54,38 @@ async def entity_extract(
    entity_types=DEFAULT_ENTITY_TYPES,
    **kwargs,
) -> TableContainer:
+    """Extract entities from a piece of text."""
+    source = cast(pd.DataFrame, input.get_input())
+    output = await entity_extract_df(
+        source,
+        cache,
+        callbacks,
+        column,
+        id_column,
+        to,
+        strategy,
+        graph_to,
+        async_mode,
+        entity_types,
+        **kwargs,
+    )
+    return TableContainer(table=output)
+
+
+async def entity_extract_df(
+    input: pd.DataFrame,
+    cache: PipelineCache,
+    callbacks: VerbCallbacks,
+    column: str,
+    id_column: str,
+    to: str,
+    strategy: dict[str, Any] | None,
+    graph_to: str | None = None,
+    async_mode: AsyncType = AsyncType.AsyncIO,
+    entity_types=DEFAULT_ENTITY_TYPES,
+    **kwargs,
+) -> pd.DataFrame:
    """
    Extract entities from a piece of text.
@@ -135,7 +167,6 @@ async def entity_extract(
    log.debug("entity_extract strategy=%s", strategy)
    if entity_types is None:
        entity_types = DEFAULT_ENTITY_TYPES
-    output = cast(pd.DataFrame, input.get_input())
    strategy = strategy or {}
    strategy_exec = _load_strategy(
        strategy.get("type", ExtractEntityStrategyType.graph_intelligence)
@@ -159,7 +190,7 @@ async def entity_extract(
        return [result.entities, result.graphml_graph]

    results = await derive_from_rows(
-        output,
+        input,
        run_strategy,
        callbacks,
        scheduling_type=async_mode,
@@ -176,11 +207,11 @@ async def entity_extract(
            to_result.append(None)
            graph_to_result.append(None)

-    output[to] = to_result
+    input[to] = to_result
    if graph_to is not None:
-        output[graph_to] = graph_to_result
+        input[graph_to] = graph_to_result

-    return TableContainer(table=output.reset_index(drop=True))
+    return input.reset_index(drop=True)


def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStrategy:
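The pattern in this file, a thin datashaper verb that delegates to a pure-DataFrame function, is what lets the collapsed workflow below call `entity_extract_df` directly. A minimal sketch of the same split, with illustrative names (`uppercase_df` is not a graphrag API):

```python
import pandas as pd


def uppercase_df(df: pd.DataFrame, column: str, to: str) -> pd.DataFrame:
    """Pure-DataFrame core: callable from composite flows and plain unit tests."""
    df[to] = df[column].str.upper()
    return df.reset_index(drop=True)


def uppercase_verb(df: pd.DataFrame, column: str, to: str) -> pd.DataFrame:
    """Thin adapter that keeps the old verb contract; all logic lives in the core."""
    return uppercase_df(df, column, to)


print(uppercase_verb(pd.DataFrame({"chunk": ["alpha", "beta"]}), column="chunk", to="entities"))
```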

View File

@@ -34,6 +34,21 @@ def merge_graphs(
    edges: dict[str, Any] = DEFAULT_EDGE_OPERATIONS,
    **_kwargs,
) -> TableContainer:
+    """Merge multiple graphs together. The graphs are expected to be in graphml format. The verb outputs a new column containing the merged graph."""
+    input_df = cast(pd.DataFrame, input.get_input())
+    output = merge_graphs_df(input_df, callbacks, column, to, nodes, edges)
+    return TableContainer(table=output)
+
+
+def merge_graphs_df(
+    input: pd.DataFrame,
+    callbacks: VerbCallbacks,
+    column: str,
+    to: str,
+    nodes: dict[str, Any] = DEFAULT_NODE_OPERATIONS,
+    edges: dict[str, Any] = DEFAULT_EDGE_OPERATIONS,
+) -> pd.DataFrame:
    """
    Merge multiple graphs together. The graphs are expected to be in graphml format. The verb outputs a new column containing the merged graph.
@@ -82,7 +97,6 @@ def merge_graphs(
    - __average__: This operation takes the mean of the attribute with the last value seen.
    - __multiply__: This operation multiplies the attribute with the last value seen.
    """
-    input_df = input.get_input()
    output = pd.DataFrame()

    node_ops = {
@@ -95,15 +109,15 @@ def merge_graphs(
    }

    mega_graph = nx.Graph()
-    num_total = len(input_df)
-    for graphml in progress_iterable(input_df[column], callbacks.progress, num_total):
+    num_total = len(input)
+    for graphml in progress_iterable(input[column], callbacks.progress, num_total):
        graph = load_graph(cast(str | nx.Graph, graphml))

        merge_nodes(mega_graph, graph, node_ops)
        merge_edges(mega_graph, graph, edge_ops)

    output[to] = ["\n".join(nx.generate_graphml(mega_graph))]

-    return TableContainer(table=output)
+    return output


def merge_nodes(
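Conceptually, `merge_graphs_df` folds every GraphML blob in the input column into one mega graph and re-serializes it into a single cell. A rough standalone sketch of that fold using plain networkx, where `nx.compose` stands in for the configurable `merge_nodes`/`merge_edges` operations (compose keeps last-seen attributes on conflict, unlike the operation tables above):

```python
import networkx as nx

g1 = nx.Graph()
g1.add_edge("A", "B", weight=1)
g2 = nx.Graph()
g2.add_edge("B", "C", weight=2)

# Fold the graphs into one mega graph.
mega = nx.Graph()
for graph in (g1, g2):
    mega = nx.compose(mega, graph)

# Same serialization as above: one GraphML string for a single output cell.
graphml_text = "\n".join(nx.generate_graphml(mega))
print(mega.number_of_nodes(), mega.number_of_edges())  # 3 2
```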

View File

@@ -3,6 +3,9 @@
"""A module containing snapshot method definition."""
+from typing import cast
+
+import pandas as pd
from datashaper import TableContainer, VerbInput, verb

from graphrag.index.storage import PipelineStorage
@@ -17,14 +20,24 @@ async def snapshot(
    **_kwargs: dict,
) -> TableContainer:
    """Take an entire snapshot of the tabular data."""
-    data = input.get_input()
+    data = cast(pd.DataFrame, input.get_input())
-    for fmt in formats:
-        if fmt == "parquet":
-            await storage.set(name + ".parquet", data.to_parquet())
-        elif fmt == "json":
-            await storage.set(
-                name + ".json", data.to_json(orient="records", lines=True)
-            )
+    await snapshot_df(data, name, formats, storage)
    return TableContainer(table=data)
+
+
+async def snapshot_df(
+    input: pd.DataFrame,
+    name: str,
+    formats: list[str],
+    storage: PipelineStorage,
+):
+    """Take an entire snapshot of the tabular data."""
+    for fmt in formats:
+        if fmt == "parquet":
+            await storage.set(name + ".parquet", input.to_parquet())
+        elif fmt == "json":
+            await storage.set(
+                name + ".json", input.to_json(orient="records", lines=True)
+            )
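With this split, `snapshot_df` persists a DataFrame without touching the verb machinery. A sketch with an in-memory stand-in for `PipelineStorage` (`FakeStorage` is hypothetical; only its async `set` mirrors the call shape used above):

```python
import asyncio

import pandas as pd


class FakeStorage:
    """Hypothetical in-memory stand-in for PipelineStorage."""

    def __init__(self) -> None:
        self.files: dict[str, object] = {}

    async def set(self, key: str, value: object) -> None:
        self.files[key] = value


async def snapshot_sketch(df: pd.DataFrame, name: str, formats: list[str], storage: FakeStorage) -> None:
    for fmt in formats:
        if fmt == "parquet":  # to_parquet() without a path returns bytes; needs pyarrow
            await storage.set(name + ".parquet", df.to_parquet())
        elif fmt == "json":
            await storage.set(name + ".json", df.to_json(orient="records", lines=True))


storage = FakeStorage()
asyncio.run(snapshot_sketch(pd.DataFrame({"a": [1]}), "raw_extracted_entities", ["json"], storage))
print(list(storage.files))  # ['raw_extracted_entities.json']
```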

View File

@@ -20,76 +20,65 @@ def build_steps(
    * `workflow:create_base_text_units`
    """
    entity_extraction_config = config.get("entity_extract", {})
+    column = entity_extraction_config.get("text_column", "chunk")
+    id_column = entity_extraction_config.get("id_column", "chunk_id")
+    async_mode = entity_extraction_config.get("async_mode", AsyncType.AsyncIO)
+    strategy = entity_extraction_config.get("strategy")
+    num_threads = entity_extraction_config.get("num_threads", 4)
+    entity_types = entity_extraction_config.get("entity_types")
+
+    graph_merge_operations_config = config.get(
+        "graph_merge_operations",
+        {
+            "nodes": {
+                "source_id": {
+                    "operation": "concat",
+                    "delimiter": ", ",
+                    "distinct": True,
+                },
+                "description": ({
+                    "operation": "concat",
+                    "separator": "\n",
+                    "distinct": False,
+                }),
+            },
+            "edges": {
+                "source_id": {
+                    "operation": "concat",
+                    "delimiter": ", ",
+                    "distinct": True,
+                },
+                "description": ({
+                    "operation": "concat",
+                    "separator": "\n",
+                    "distinct": False,
+                }),
+                "weight": "sum",
+            },
+        },
+    )
+    nodes = graph_merge_operations_config.get("nodes")
+    edges = graph_merge_operations_config.get("edges")
    graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False
    raw_entity_snapshot_enabled = config.get("raw_entity_snapshot", False) or False
    return [
        {
-            "verb": "entity_extract",
+            "verb": "create_base_extracted_entities",
            "args": {
-                **entity_extraction_config,
-                "column": entity_extraction_config.get("text_column", "chunk"),
-                "id_column": entity_extraction_config.get("id_column", "chunk_id"),
-                "async_mode": entity_extraction_config.get(
-                    "async_mode", AsyncType.AsyncIO
-                ),
-                "to": "entities",
-                "graph_to": "entity_graph",
+                "column": column,
+                "id_column": id_column,
+                "async_mode": async_mode,
+                "strategy": strategy,
+                "num_threads": num_threads,
+                "entity_types": entity_types,
+                "nodes": nodes,
+                "edges": edges,
+                "raw_entity_snapshot_enabled": raw_entity_snapshot_enabled,
+                "graphml_snapshot_enabled": graphml_snapshot_enabled,
            },
            "input": {"source": "workflow:create_base_text_units"},
        },
-        {
-            "verb": "snapshot",
-            "enabled": raw_entity_snapshot_enabled,
-            "args": {
-                "name": "raw_extracted_entities",
-                "formats": ["json"],
-            },
-        },
-        {
-            "verb": "merge_graphs",
-            "args": {
-                "column": "entity_graph",
-                "to": "entity_graph",
-                **config.get(
-                    "graph_merge_operations",
-                    {
-                        "nodes": {
-                            "source_id": {
-                                "operation": "concat",
-                                "delimiter": ", ",
-                                "distinct": True,
-                            },
-                            "description": ({
-                                "operation": "concat",
-                                "separator": "\n",
-                                "distinct": False,
-                            }),
-                        },
-                        "edges": {
-                            "source_id": {
-                                "operation": "concat",
-                                "delimiter": ", ",
-                                "distinct": True,
-                            },
-                            "description": ({
-                                "operation": "concat",
-                                "separator": "\n",
-                                "distinct": False,
-                            }),
-                            "weight": "sum",
-                        },
-                    },
-                ),
-            },
-        },
-        {
-            "verb": "snapshot_rows",
-            "enabled": graphml_snapshot_enabled,
-            "args": {
-                "base_name": "merged_graph",
-                "column": "entity_graph",
-                "formats": [{"format": "text", "extension": "graphml"}],
-            },
-        },
    ]
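The net effect of this hunk: `build_steps` used to emit four steps (entity_extract, snapshot, merge_graphs, snapshot_rows) and now emits a single composite step, which is why the smoke tests below drop `"subworkflows"` from 2 to 1. A toy builder showing the collapsed shape (dicts only, illustrative rather than graphrag's actual `build_steps`):

```python
# Toy builder with the collapsed shape; keys mirror the hunk above.
def build_steps_sketch(config: dict) -> list[dict]:
    ee = config.get("entity_extract", {})
    return [
        {
            "verb": "create_base_extracted_entities",
            "args": {
                "column": ee.get("text_column", "chunk"),
                "id_column": ee.get("id_column", "chunk_id"),
                "raw_entity_snapshot_enabled": config.get("raw_entity_snapshot", False),
                "graphml_snapshot_enabled": config.get("graphml_snapshot", False),
            },
            "input": {"source": "workflow:create_base_text_units"},
        },
    ]


assert len(build_steps_sketch({})) == 1  # one subworkflow, as the smoke tests now expect
```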

View File

@@ -5,6 +5,7 @@
from .create_base_documents import create_base_documents
from .create_base_entity_graph import create_base_entity_graph
+from .create_base_extracted_entities import create_base_extracted_entities
from .create_base_text_units import create_base_text_units
from .create_final_communities import create_final_communities
from .create_final_community_reports import create_final_community_reports
@@ -21,6 +22,7 @@ from .create_summarized_entities import create_summarized_entities
__all__ = [
    "create_base_documents",
    "create_base_entity_graph",
+    "create_base_extracted_entities",
    "create_base_text_units",
    "create_final_communities",
    "create_final_community_reports",

View File

@@ -0,0 +1,86 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""All the steps to extract and format base entities."""
+
+from typing import Any, cast
+
+import pandas as pd
+from datashaper import (
+    AsyncType,
+    Table,
+    VerbCallbacks,
+    VerbInput,
+    verb,
+)
+from datashaper.table_store.types import VerbResult, create_verb_result
+
+from graphrag.index.cache import PipelineCache
+from graphrag.index.storage import PipelineStorage
+from graphrag.index.verbs.entities.extraction.entity_extract import entity_extract_df
+from graphrag.index.verbs.graph.merge.merge_graphs import merge_graphs_df
+from graphrag.index.verbs.snapshot import snapshot_df
+from graphrag.index.verbs.snapshot_rows import snapshot_rows_df
+
+
+@verb(name="create_base_extracted_entities", treats_input_tables_as_immutable=True)
+async def create_base_extracted_entities(
+    input: VerbInput,
+    cache: PipelineCache,
+    callbacks: VerbCallbacks,
+    storage: PipelineStorage,
+    column: str,
+    id_column: str,
+    nodes: dict[str, Any],
+    edges: dict[str, Any],
+    strategy: dict[str, Any] | None,
+    async_mode: AsyncType = AsyncType.AsyncIO,
+    entity_types: list[str] | None = None,
+    graphml_snapshot_enabled: bool = False,
+    raw_entity_snapshot_enabled: bool = False,
+    **kwargs: dict,
+) -> VerbResult:
+    """All the steps to extract and format base entities."""
+    source = cast(pd.DataFrame, input.get_input())
+
+    entity_graph = await entity_extract_df(
+        source,
+        cache,
+        callbacks,
+        column=column,
+        id_column=id_column,
+        strategy=strategy,
+        async_mode=async_mode,
+        entity_types=entity_types,
+        to="entities",
+        graph_to="entity_graph",
+        **kwargs,
+    )
+
+    if raw_entity_snapshot_enabled:
+        await snapshot_df(
+            entity_graph,
+            name="raw_extracted_entities",
+            storage=storage,
+            formats=["json"],
+        )
+
+    merged_graph = merge_graphs_df(
+        entity_graph,
+        callbacks,
+        column="entity_graph",
+        to="entity_graph",
+        nodes=nodes,
+        edges=edges,
+    )
+
+    if graphml_snapshot_enabled:
+        await snapshot_rows_df(
+            merged_graph,
+            base_name="merged_graph",
+            column="entity_graph",
+            storage=storage,
+            formats=[{"format": "text", "extension": "graphml"}],
+        )
+
+    return create_verb_result(cast(Table, merged_graph))
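The composite verb sequences extraction, an optional raw snapshot, the graph merge, and an optional GraphML snapshot in plain Python instead of four pipeline steps. A self-contained sketch of that control flow with stand-in functions (`extract`, `merge`, and `snapshot` below are not graphrag APIs):

```python
import asyncio

import pandas as pd


async def extract(df: pd.DataFrame) -> pd.DataFrame:
    """Stand-in for entity_extract_df: tag each row with a dummy graph."""
    out = df.copy()
    out["entity_graph"] = "<graphml/>"
    return out


def merge(df: pd.DataFrame) -> pd.DataFrame:
    """Stand-in for merge_graphs_df: collapse to a single merged-graph row."""
    return pd.DataFrame({"entity_graph": ["<graphml/>"]})


async def snapshot(df: pd.DataFrame, name: str) -> None:
    """Stand-in for snapshot_df / snapshot_rows_df."""
    print(f"snapshot: {name} ({len(df)} rows)")


async def composite(df: pd.DataFrame, *, raw_snapshot: bool, graphml_snapshot: bool) -> pd.DataFrame:
    extracted = await extract(df)
    if raw_snapshot:
        await snapshot(extracted, "raw_extracted_entities")
    merged = merge(extracted)
    if graphml_snapshot:
        await snapshot(merged, "merged_graph")
    return merged


result = asyncio.run(composite(pd.DataFrame({"chunk": ["x"]}), raw_snapshot=True, graphml_snapshot=True))
print(len(result))  # 1
```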

View File

@@ -15,7 +15,7 @@
            1,
            2000
        ],
-        "subworkflows": 2,
+        "subworkflows": 1,
        "max_runtime": 300
    },
    "create_summarized_entities": {

View File

@@ -15,7 +15,7 @@
            1,
            2000
        ],
-        "subworkflows": 2,
+        "subworkflows": 1,
        "max_runtime": 300
    },
    "create_final_covariates": {

View File

@@ -0,0 +1,73 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+import networkx as nx
+
+from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
+from graphrag.index.workflows.v1.create_base_extracted_entities import (
+    build_steps,
+    workflow_name,
+)
+
+from .util import (
+    get_config_for_workflow,
+    get_workflow_output,
+    load_expected,
+    load_input_tables,
+)
+
+
+async def test_create_base_extracted_entities():
+    input_tables = load_input_tables(["workflow:create_base_text_units"])
+    expected = load_expected(workflow_name)
+
+    config = get_config_for_workflow(workflow_name)
+    del config["entity_extract"]["strategy"]["llm"]
+
+    steps = build_steps(config)
+
+    actual = await get_workflow_output(
+        input_tables,
+        {
+            "steps": steps,
+        },
+    )
+
+    # let's parse a sample of the raw graphml
+    actual_graphml_0 = actual["entity_graph"][:1][0]
+    actual_graph_0 = nx.parse_graphml(actual_graphml_0)
+    assert actual_graph_0.number_of_nodes() == 3
+    assert actual_graph_0.number_of_edges() == 2
+
+    assert actual.columns == expected.columns
+
+
+async def test_create_base_extracted_entities_with_snapshots():
+    input_tables = load_input_tables(["workflow:create_base_text_units"])
+    expected = load_expected(workflow_name)
+    storage = MemoryPipelineStorage()
+
+    config = get_config_for_workflow(workflow_name)
+    del config["entity_extract"]["strategy"]["llm"]
+    config["raw_entity_snapshot"] = True
+    config["graphml_snapshot"] = True
+
+    steps = build_steps(config)
+
+    actual = await get_workflow_output(
+        input_tables,
+        {
+            "steps": steps,
+        },
+        storage=storage,
+    )
+
+    print(storage.keys())
+
+    assert actual.columns == expected.columns
+    assert storage.keys() == ["raw_extracted_entities.json", "merged_graph.graphml"]