Mirror of https://github.com/microsoft/graphrag.git (synced 2026-01-14 09:07:20 +08:00)
Collapse create base extracted entities (#1235)

* Set up base assertions
* Replace entity_extract
* Finish collapsing workflow
* Semver
* Update smoke tests

commit 9070ea5c3c (parent 630679f8e3)
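The change applies the same collapse pattern as the sibling workflow-collapsing PRs: each datashaper verb keeps a thin wrapper that unpacks and repacks the table container, while the real logic moves into a plain `*_df` function that takes and returns a pandas DataFrame, so a composite verb can chain the cores without round-tripping through the table store. A minimal standalone sketch of the pattern (toy `TableContainer` and verb names, not graphrag's actual code):

from dataclasses import dataclass

import pandas as pd


@dataclass
class TableContainer:  # stand-in for datashaper's TableContainer
    table: pd.DataFrame


def my_verb_df(input: pd.DataFrame, to: str) -> pd.DataFrame:
    """Core logic: plain DataFrame in, plain DataFrame out."""
    input[to] = input["chunk"].str.len()  # illustrative transform
    return input


def my_verb(input_table: TableContainer, to: str) -> TableContainer:
    """Thin verb wrapper kept for workflows that still call the verb by name."""
    return TableContainer(table=my_verb_df(input_table.table, to))


# a composite "collapsed" step can now chain the cores directly:
df = pd.DataFrame({"chunk": ["alpha", "beta"]})
df = my_verb_df(df, to="chunk_len")
print(df)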
New file:

@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Collapse entity extraction."
+}
graphrag/index/verbs/entities/extraction/entity_extract.py:

@@ -54,6 +54,38 @@ async def entity_extract(
     entity_types=DEFAULT_ENTITY_TYPES,
     **kwargs,
 ) -> TableContainer:
+    """Extract entities from a piece of text."""
+    source = cast(pd.DataFrame, input.get_input())
+    output = await entity_extract_df(
+        source,
+        cache,
+        callbacks,
+        column,
+        id_column,
+        to,
+        strategy,
+        graph_to,
+        async_mode,
+        entity_types,
+        **kwargs,
+    )
+
+    return TableContainer(table=output)
+
+
+async def entity_extract_df(
+    input: pd.DataFrame,
+    cache: PipelineCache,
+    callbacks: VerbCallbacks,
+    column: str,
+    id_column: str,
+    to: str,
+    strategy: dict[str, Any] | None,
+    graph_to: str | None = None,
+    async_mode: AsyncType = AsyncType.AsyncIO,
+    entity_types=DEFAULT_ENTITY_TYPES,
+    **kwargs,
+) -> pd.DataFrame:
     """
     Extract entities from a piece of text.

@@ -135,7 +167,6 @@ async def entity_extract(
     log.debug("entity_extract strategy=%s", strategy)
     if entity_types is None:
         entity_types = DEFAULT_ENTITY_TYPES
-    output = cast(pd.DataFrame, input.get_input())
     strategy = strategy or {}
     strategy_exec = _load_strategy(
         strategy.get("type", ExtractEntityStrategyType.graph_intelligence)

@@ -159,7 +190,7 @@ async def entity_extract(
         return [result.entities, result.graphml_graph]

     results = await derive_from_rows(
-        output,
+        input,
         run_strategy,
         callbacks,
         scheduling_type=async_mode,

@@ -176,11 +207,11 @@ async def entity_extract(
             to_result.append(None)
             graph_to_result.append(None)

-    output[to] = to_result
+    input[to] = to_result
     if graph_to is not None:
-        output[graph_to] = graph_to_result
+        input[graph_to] = graph_to_result

-    return TableContainer(table=output.reset_index(drop=True))
+    return input.reset_index(drop=True)


 def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStrategy:
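After the collapse, `entity_extract_df` writes its results onto the frame it was given and returns it directly, which is why the `output = cast(...)` line disappears and the remaining `output` references become `input`. For intuition, `derive_from_rows` runs an async strategy once per row and collects results in order; a rough standalone equivalent in plain asyncio (not datashaper's actual API or scheduling):

import asyncio

import pandas as pd


async def run_strategy(row: pd.Series) -> list:
    # stand-in for the entity-extraction strategy; returns [entities, graphml]
    return [[{"name": row["chunk"].upper()}], f"<graphml for {row['chunk_id']}/>"]


async def derive(df: pd.DataFrame) -> list:
    # gather preserves input order, like the to_result/graph_to_result lists
    return await asyncio.gather(*(run_strategy(row) for _, row in df.iterrows()))


df = pd.DataFrame({"chunk": ["alpha", "beta"], "chunk_id": ["c1", "c2"]})
results = asyncio.run(derive(df))
df["entities"] = [r[0] for r in results]
df["entity_graph"] = [r[1] for r in results]
print(df)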
graphrag/index/verbs/graph/merge/merge_graphs.py:

@@ -34,6 +34,21 @@ def merge_graphs(
     edges: dict[str, Any] = DEFAULT_EDGE_OPERATIONS,
     **_kwargs,
 ) -> TableContainer:
+    """Merge multiple graphs together. The graphs are expected to be in graphml format. The verb outputs a new column containing the merged graph."""
+    input_df = cast(pd.DataFrame, input.get_input())
+    output = merge_graphs_df(input_df, callbacks, column, to, nodes, edges)
+
+    return TableContainer(table=output)
+
+
+def merge_graphs_df(
+    input: pd.DataFrame,
+    callbacks: VerbCallbacks,
+    column: str,
+    to: str,
+    nodes: dict[str, Any] = DEFAULT_NODE_OPERATIONS,
+    edges: dict[str, Any] = DEFAULT_EDGE_OPERATIONS,
+) -> pd.DataFrame:
     """
     Merge multiple graphs together. The graphs are expected to be in graphml format. The verb outputs a new column containing the merged graph.

@@ -82,7 +97,6 @@ def merge_graphs(
     - __average__: This operation takes the mean of the attribute with the last value seen.
     - __multiply__: This operation multiplies the attribute with the last value seen.
     """
-    input_df = input.get_input()
     output = pd.DataFrame()

     node_ops = {

@@ -95,15 +109,15 @@ def merge_graphs(
     }

     mega_graph = nx.Graph()
-    num_total = len(input_df)
-    for graphml in progress_iterable(input_df[column], callbacks.progress, num_total):
+    num_total = len(input)
+    for graphml in progress_iterable(input[column], callbacks.progress, num_total):
         graph = load_graph(cast(str | nx.Graph, graphml))
         merge_nodes(mega_graph, graph, node_ops)
         merge_edges(mega_graph, graph, edge_ops)

     output[to] = ["\n".join(nx.generate_graphml(mega_graph))]

-    return TableContainer(table=output)
+    return output


 def merge_nodes(
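`merge_graphs_df` folds each per-chunk graph into one `nx.Graph` and serializes it by joining `nx.generate_graphml` output, exactly as the last hunk shows. A self-contained sketch of the concat/sum semantics from the default node and edge operations (hand-rolled here; graphrag's `merge_nodes`/`merge_edges` helpers are not shown in this diff, and the `distinct` de-duplication is elided):

import networkx as nx

g1 = nx.Graph()
g1.add_node("A", source_id="1", description="first mention")
g1.add_node("B", source_id="1", description="b, first")
g1.add_edge("A", "B", weight=1, source_id="1")

g2 = nx.Graph()
g2.add_node("A", source_id="2", description="second mention")
g2.add_node("B", source_id="2", description="b, second")
g2.add_edge("A", "B", weight=2, source_id="2")

mega = nx.Graph()
for g in (g1, g2):
    for node, data in g.nodes(data=True):
        if node not in mega:
            mega.add_node(node, **data)
        else:
            # nodes: concat source_id with ", ", concat description with "\n"
            mega.nodes[node]["source_id"] += ", " + data["source_id"]
            mega.nodes[node]["description"] += "\n" + data["description"]
    for u, v, data in g.edges(data=True):
        if not mega.has_edge(u, v):
            mega.add_edge(u, v, **data)
        else:
            # edges: weight is summed, source_id concatenated
            mega[u][v]["weight"] += data["weight"]
            mega[u][v]["source_id"] += ", " + data["source_id"]

# same serialization the verb uses for the output column
graphml = "\n".join(nx.generate_graphml(mega))
print(graphml)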
graphrag/index/verbs/snapshot.py:

@@ -3,6 +3,9 @@

 """A module containing snapshot method definition."""

+from typing import cast
+
+import pandas as pd
 from datashaper import TableContainer, VerbInput, verb

 from graphrag.index.storage import PipelineStorage

@@ -17,14 +20,24 @@ async def snapshot(
     **_kwargs: dict,
 ) -> TableContainer:
     """Take a entire snapshot of the tabular data."""
-    data = input.get_input()
+    data = cast(pd.DataFrame, input.get_input())

-    for fmt in formats:
-        if fmt == "parquet":
-            await storage.set(name + ".parquet", data.to_parquet())
-        elif fmt == "json":
-            await storage.set(
-                name + ".json", data.to_json(orient="records", lines=True)
-            )
+    await snapshot_df(data, name, formats, storage)

     return TableContainer(table=data)


+async def snapshot_df(
+    input: pd.DataFrame,
+    name: str,
+    formats: list[str],
+    storage: PipelineStorage,
+):
+    """Take a entire snapshot of the tabular data."""
+    for fmt in formats:
+        if fmt == "parquet":
+            await storage.set(name + ".parquet", input.to_parquet())
+        elif fmt == "json":
+            await storage.set(
+                name + ".json", input.to_json(orient="records", lines=True)
+            )
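The extracted `snapshot_df` body is plain pandas serialization; for reference, the two payloads it writes (`to_parquet()` with no path returns bytes and requires pyarrow or fastparquet to be installed):

import pandas as pd

df = pd.DataFrame({"id": [1, 2], "text": ["alpha", "beta"]})

parquet_bytes = df.to_parquet()  # bytes payload for storage.set(name + ".parquet", ...)
json_lines = df.to_json(orient="records", lines=True)  # payload for name + ".json"
print(json_lines)
# {"id":1,"text":"alpha"}
# {"id":2,"text":"beta"}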
graphrag/index/workflows/v1/create_base_extracted_entities.py:

@@ -20,76 +20,65 @@ def build_steps(
     * `workflow:create_base_text_units`
     """
     entity_extraction_config = config.get("entity_extract", {})

+    column = entity_extraction_config.get("text_column", "chunk")
+    id_column = entity_extraction_config.get("id_column", "chunk_id")
+    async_mode = entity_extraction_config.get("async_mode", AsyncType.AsyncIO)
+    strategy = entity_extraction_config.get("strategy")
+    num_threads = entity_extraction_config.get("num_threads", 4)
+    entity_types = entity_extraction_config.get("entity_types")
+
+    graph_merge_operations_config = config.get(
+        "graph_merge_operations",
+        {
+            "nodes": {
+                "source_id": {
+                    "operation": "concat",
+                    "delimiter": ", ",
+                    "distinct": True,
+                },
+                "description": ({
+                    "operation": "concat",
+                    "separator": "\n",
+                    "distinct": False,
+                }),
+            },
+            "edges": {
+                "source_id": {
+                    "operation": "concat",
+                    "delimiter": ", ",
+                    "distinct": True,
+                },
+                "description": ({
+                    "operation": "concat",
+                    "separator": "\n",
+                    "distinct": False,
+                }),
+                "weight": "sum",
+            },
+        },
+    )
+    nodes = graph_merge_operations_config.get("nodes")
+    edges = graph_merge_operations_config.get("edges")

     graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False
     raw_entity_snapshot_enabled = config.get("raw_entity_snapshot", False) or False

     return [
         {
-            "verb": "entity_extract",
+            "verb": "create_base_extracted_entities",
             "args": {
-                **entity_extraction_config,
-                "column": entity_extraction_config.get("text_column", "chunk"),
-                "id_column": entity_extraction_config.get("id_column", "chunk_id"),
-                "async_mode": entity_extraction_config.get(
-                    "async_mode", AsyncType.AsyncIO
-                ),
-                "to": "entities",
-                "graph_to": "entity_graph",
+                "column": column,
+                "id_column": id_column,
+                "async_mode": async_mode,
+                "strategy": strategy,
+                "num_threads": num_threads,
+                "entity_types": entity_types,
+                "nodes": nodes,
+                "edges": edges,
+                "raw_entity_snapshot_enabled": raw_entity_snapshot_enabled,
+                "graphml_snapshot_enabled": graphml_snapshot_enabled,
             },
             "input": {"source": "workflow:create_base_text_units"},
         },
-        {
-            "verb": "snapshot",
-            "enabled": raw_entity_snapshot_enabled,
-            "args": {
-                "name": "raw_extracted_entities",
-                "formats": ["json"],
-            },
-        },
-        {
-            "verb": "merge_graphs",
-            "args": {
-                "column": "entity_graph",
-                "to": "entity_graph",
-                **config.get(
-                    "graph_merge_operations",
-                    {
-                        "nodes": {
-                            "source_id": {
-                                "operation": "concat",
-                                "delimiter": ", ",
-                                "distinct": True,
-                            },
-                            "description": ({
-                                "operation": "concat",
-                                "separator": "\n",
-                                "distinct": False,
-                            }),
-                        },
-                        "edges": {
-                            "source_id": {
-                                "operation": "concat",
-                                "delimiter": ", ",
-                                "distinct": True,
-                            },
-                            "description": ({
-                                "operation": "concat",
-                                "separator": "\n",
-                                "distinct": False,
-                            }),
-                            "weight": "sum",
-                        },
-                    },
-                ),
-            },
-        },
-        {
-            "verb": "snapshot_rows",
-            "enabled": graphml_snapshot_enabled,
-            "args": {
-                "base_name": "merged_graph",
-                "column": "entity_graph",
-                "formats": [{"format": "text", "extension": "graphml"}],
-            },
-        },
     ]
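With the defaults inlined into `build_steps`, the workflow now emits a single step instead of four. A quick check of the new shape (assumes graphrag is installed and that `build_steps` accepts a plain dict, as the tests below suggest):

from datashaper import AsyncType

from graphrag.index.workflows.v1.create_base_extracted_entities import build_steps

steps = build_steps({})  # all defaults
assert len(steps) == 1

step = steps[0]
assert step["verb"] == "create_base_extracted_entities"
assert step["args"]["async_mode"] == AsyncType.AsyncIO
assert step["args"]["num_threads"] == 4
print(step["args"]["nodes"]["source_id"])  # default concat operation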
graphrag/index/workflows/v1/subflows/__init__.py:

@@ -5,6 +5,7 @@

 from .create_base_documents import create_base_documents
 from .create_base_entity_graph import create_base_entity_graph
+from .create_base_extracted_entities import create_base_extracted_entities
 from .create_base_text_units import create_base_text_units
 from .create_final_communities import create_final_communities
 from .create_final_community_reports import create_final_community_reports

@@ -21,6 +22,7 @@ from .create_summarized_entities import create_summarized_entities
 __all__ = [
     "create_base_documents",
     "create_base_entity_graph",
+    "create_base_extracted_entities",
     "create_base_text_units",
     "create_final_communities",
     "create_final_community_reports",
graphrag/index/workflows/v1/subflows/create_base_extracted_entities.py (new file):

@@ -0,0 +1,86 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""All the steps to extract and format covariates."""
+
+from typing import Any, cast
+
+import pandas as pd
+from datashaper import (
+    AsyncType,
+    Table,
+    VerbCallbacks,
+    VerbInput,
+    verb,
+)
+from datashaper.table_store.types import VerbResult, create_verb_result
+
+from graphrag.index.cache import PipelineCache
+from graphrag.index.storage import PipelineStorage
+from graphrag.index.verbs.entities.extraction.entity_extract import entity_extract_df
+from graphrag.index.verbs.graph.merge.merge_graphs import merge_graphs_df
+from graphrag.index.verbs.snapshot import snapshot_df
+from graphrag.index.verbs.snapshot_rows import snapshot_rows_df
+
+
+@verb(name="create_base_extracted_entities", treats_input_tables_as_immutable=True)
+async def create_base_extracted_entities(
+    input: VerbInput,
+    cache: PipelineCache,
+    callbacks: VerbCallbacks,
+    storage: PipelineStorage,
+    column: str,
+    id_column: str,
+    nodes: dict[str, Any],
+    edges: dict[str, Any],
+    strategy: dict[str, Any] | None,
+    async_mode: AsyncType = AsyncType.AsyncIO,
+    entity_types: list[str] | None = None,
+    graphml_snapshot_enabled: bool = False,
+    raw_entity_snapshot_enabled: bool = False,
+    **kwargs: dict,
+) -> VerbResult:
+    """All the steps to extract and format covariates."""
+    source = cast(pd.DataFrame, input.get_input())
+
+    entity_graph = await entity_extract_df(
+        source,
+        cache,
+        callbacks,
+        column=column,
+        id_column=id_column,
+        strategy=strategy,
+        async_mode=async_mode,
+        entity_types=entity_types,
+        to="entities",
+        graph_to="entity_graph",
+        **kwargs,
+    )
+
+    if raw_entity_snapshot_enabled:
+        await snapshot_df(
+            entity_graph,
+            name="raw_extracted_entities",
+            storage=storage,
+            formats=["json"],
+        )
+
+    merged_graph = merge_graphs_df(
+        entity_graph,
+        callbacks,
+        column="entity_graph",
+        to="entity_graph",
+        nodes=nodes,
+        edges=edges,
+    )
+
+    if graphml_snapshot_enabled:
+        await snapshot_rows_df(
+            merged_graph,
+            base_name="merged_graph",
+            column="entity_graph",
+            storage=storage,
+            formats=[{"format": "text", "extension": "graphml"}],
+        )
+
+    return create_verb_result(cast(Table, merged_graph))
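The `@verb(...)` decorator is what lets a workflow step reference this function by the string `"create_base_extracted_entities"`: datashaper registers decorated functions in a name-to-function table and looks them up when executing steps. A toy illustration of that lookup (not datashaper's implementation):

from typing import Callable

VERBS: dict[str, Callable] = {}


def verb(name: str, **_options):
    """Toy stand-in for datashaper's @verb decorator."""
    def register(fn: Callable) -> Callable:
        VERBS[name] = fn
        return fn
    return register


@verb(name="create_base_extracted_entities", treats_input_tables_as_immutable=True)
def my_subflow():
    return "running collapsed subflow"


# a step like {"verb": "create_base_extracted_entities", ...} resolves by name:
print(VERBS["create_base_extracted_entities"]())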
tests/fixtures/min-csv/config.json (vendored):

@@ -15,7 +15,7 @@
             1,
             2000
         ],
-        "subworkflows": 2,
+        "subworkflows": 1,
         "max_runtime": 300
     },
     "create_summarized_entities": {
tests/fixtures/text/config.json (vendored):

@@ -15,7 +15,7 @@
             1,
             2000
         ],
-        "subworkflows": 2,
+        "subworkflows": 1,
     "max_runtime": 300
     },
     "create_final_covariates": {
tests/verbs/test_create_base_extracted_entities.py (new file):

@@ -0,0 +1,73 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+import networkx as nx
+
+from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
+from graphrag.index.workflows.v1.create_base_extracted_entities import (
+    build_steps,
+    workflow_name,
+)
+
+from .util import (
+    get_config_for_workflow,
+    get_workflow_output,
+    load_expected,
+    load_input_tables,
+)
+
+
+async def test_create_base_extracted_entities():
+    input_tables = load_input_tables(["workflow:create_base_text_units"])
+    expected = load_expected(workflow_name)
+
+    config = get_config_for_workflow(workflow_name)
+
+    del config["entity_extract"]["strategy"]["llm"]
+
+    steps = build_steps(config)
+
+    actual = await get_workflow_output(
+        input_tables,
+        {
+            "steps": steps,
+        },
+    )
+
+    # let's parse a sample of the raw graphml
+    actual_graphml_0 = actual["entity_graph"][:1][0]
+    actual_graph_0 = nx.parse_graphml(actual_graphml_0)
+
+    assert actual_graph_0.number_of_nodes() == 3
+    assert actual_graph_0.number_of_edges() == 2
+
+    assert actual.columns == expected.columns
+
+
+async def test_create_base_extracted_entities_with_snapshots():
+    input_tables = load_input_tables(["workflow:create_base_text_units"])
+    expected = load_expected(workflow_name)
+
+    storage = MemoryPipelineStorage()
+
+    config = get_config_for_workflow(workflow_name)
+
+    del config["entity_extract"]["strategy"]["llm"]
+    config["raw_entity_snapshot"] = True
+    config["graphml_snapshot"] = True
+
+    steps = build_steps(config)
+
+    actual = await get_workflow_output(
+        input_tables,
+        {
+            "steps": steps,
+        },
+        storage=storage,
+    )
+
+    print(storage.keys())
+
+    assert actual.columns == expected.columns
+
+    assert storage.keys() == ["raw_extracted_entities.json", "merged_graph.graphml"]
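The graphml assertions in the first test are plain networkx; a standalone illustration of the same 3-node/2-edge check (hypothetical graphml payload, not the fixture's actual output):

import networkx as nx

graphml = """<?xml version='1.0' encoding='utf-8'?>
<graphml xmlns="http://graphml.graphdrawing.org/xmlns">
  <graph edgedefault="undirected">
    <node id="A"/><node id="B"/><node id="C"/>
    <edge source="A" target="B"/><edge source="B" target="C"/>
  </graph>
</graphml>"""

g = nx.parse_graphml(graphml)
assert g.number_of_nodes() == 3
assert g.number_of_edges() == 2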