Collapse create base entity graph (#1233)

* Collapse create_base_entity_graph

* Format/typing

* Semver

* Fix smoke tests

* Simplify assignment
Nathan Evans 2024-09-30 15:39:42 -07:00 committed by GitHub
parent 00d5e77568
commit 5220bb7ecc
GPG Key ID: B5690EEEBB952194
12 changed files with 331 additions and 75 deletions

View File

@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Collapse create-base-entity-graph."
}

View File

@@ -73,6 +73,10 @@ class MemoryPipelineStorage(FilePipelineStorage):
"""Create a child storage instance."""
return self
def keys(self) -> list[str]:
"""Return the keys in the storage."""
return list(self._storage.keys())
def create_memory_storage() -> PipelineStorage:
"""Create memory storage."""

View File

@@ -57,28 +57,45 @@ def cluster_graph(
```
"""
output_df = cast(pd.DataFrame, input.get_input())
results = output_df[column].apply(lambda graph: run_layout(strategy, graph))
output_df = cluster_graph_df(
cast(pd.DataFrame, input.get_input()),
callbacks,
strategy,
column,
to,
level_to=level_to,
)
return TableContainer(table=output_df)
def cluster_graph_df(
input: pd.DataFrame,
callbacks: VerbCallbacks,
strategy: dict[str, Any],
column: str,
to: str,
level_to: str | None = None,
) -> pd.DataFrame:
"""Apply a hierarchical clustering algorithm to a graph."""
results = input[column].apply(lambda graph: run_layout(strategy, graph))
community_map_to = "communities"
output_df[community_map_to] = results
input[community_map_to] = results
level_to = level_to or f"{to}_level"
output_df[level_to] = output_df.apply(
input[level_to] = input.apply(
lambda x: list({level for level, _, _ in x[community_map_to]}), axis=1
)
output_df[to] = [None] * len(output_df)
input[to] = None
num_total = len(output_df)
num_total = len(input)
# Create a seed for this run (if not provided)
seed = strategy.get("seed", Random().randint(0, 0xFFFFFFFF)) # noqa S311
# Go through each of the rows
graph_level_pairs_column: list[list[tuple[int, str]]] = []
for _, row in progress_iterable(
output_df.iterrows(), callbacks.progress, num_total
):
for _, row in progress_iterable(input.iterrows(), callbacks.progress, num_total):
levels = row[level_to]
graph_level_pairs: list[tuple[int, str]] = []
@@ -96,21 +113,18 @@ def cluster_graph(
)
graph_level_pairs.append((level, graph))
graph_level_pairs_column.append(graph_level_pairs)
output_df[to] = graph_level_pairs_column
input[to] = graph_level_pairs_column
# explode the list of (level, graph) pairs into separate rows
output_df = output_df.explode(to, ignore_index=True)
input = input.explode(to, ignore_index=True)
# split the (level, graph) pairs into separate columns
# TODO: There is probably a better way to do this
output_df[[level_to, to]] = pd.DataFrame(
output_df[to].tolist(), index=output_df.index
)
input[[level_to, to]] = pd.DataFrame(input[to].tolist(), index=input.index)
# clean up the community map
output_df.drop(columns=[community_map_to], inplace=True)
return TableContainer(table=output_df)
input.drop(columns=[community_map_to], inplace=True)
return input
# TODO: This should support str | nx.Graph as a graphml param
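
The refactor above splits the verb into a thin wrapper plus a DataFrame-level helper. A hedged sketch of calling cluster_graph_df directly, outside of "verb land" (the tiny graphml fixture and NoopVerbCallbacks are illustrative assumptions; the leiden strategy requires graspologic):

```python
import networkx as nx
import pandas as pd
from datashaper import NoopVerbCallbacks  # assumed no-op callbacks shim

from graphrag.index.verbs.graph.clustering.cluster_graph import cluster_graph_df

# One row holding a small graphml graph, mirroring the entity_graph column
graph = nx.Graph()
graph.add_edges_from([("a", "b"), ("b", "c"), ("a", "c")])
source = pd.DataFrame({"entity_graph": ["\n".join(nx.generate_graphml(graph))]})

clustered = cluster_graph_df(
    source,
    NoopVerbCallbacks(),
    strategy={"type": "leiden"},
    column="entity_graph",
    to="clustered_graph",
    level_to="level",
)
# One row per (level, graph) pair after the explode
print(clustered[["level", "clustered_graph"]])
```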

View File

@@ -63,8 +63,56 @@ async def embed_graph(
random_seed: 86 # Optional, The random seed to use for the embedding, default: 86
```
"""
output_df = cast(pd.DataFrame, input.get_input())
input_df = cast(pd.DataFrame, input.get_input())
output_df = await embed_graph_df(
input_df,
callbacks,
strategy,
column,
to,
**kwargs,
)
return TableContainer(table=output_df)
async def embed_graph_df(
input: pd.DataFrame,
callbacks: VerbCallbacks,
strategy: dict[str, Any],
column: str,
to: str,
**kwargs,
) -> pd.DataFrame:
"""
Embed a graph into a vector space. The graph is expected to be in graphml format. The verb outputs a new column containing a mapping between node_id and vector.
## Usage
```yaml
verb: embed_graph
args:
column: clustered_graph # The name of the column containing the graph, should be a graphml graph
to: embeddings # The name of the column to output the embeddings to
strategy: <strategy config> # See strategies section below
```
## Strategies
The embed_graph verb uses a strategy to embed the graph. The strategy is an object which defines the strategy to use. The following strategies are available:
### node2vec
This strategy uses the node2vec algorithm to embed a graph. The strategy config is as follows:
```yaml
strategy:
type: node2vec
dimensions: 1536 # Optional, The number of dimensions to use for the embedding, default: 1536
num_walks: 10 # Optional, The number of walks to use for the embedding, default: 10
walk_length: 40 # Optional, The walk length to use for the embedding, default: 40
window_size: 2 # Optional, The window size to use for the embedding, default: 2
iterations: 3 # Optional, The number of iterations to use for the embedding, default: 3
random_seed: 86 # Optional, The random seed to use for the embedding, default: 86
```
"""
strategy_type = strategy.get("type", EmbedGraphStrategyType.node2vec)
strategy_args = {**strategy}
@@ -72,13 +120,13 @@ async def embed_graph(
return run_embeddings(strategy_type, cast(Any, row[column]), strategy_args)
results = await derive_from_rows(
output_df,
input,
run_strategy,
callbacks=callbacks,
num_threads=kwargs.get("num_threads", None),
)
output_df[to] = list(results)
return TableContainer(table=output_df)
input[to] = list(results)
return input
def run_embeddings(
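
The same split applies here: embed_graph stays as the verb shell while embed_graph_df does the work. A hedged sketch of awaiting the helper directly (fixture and callbacks as above are illustrative assumptions; node2vec requires graspologic):

```python
import asyncio

import networkx as nx
import pandas as pd
from datashaper import NoopVerbCallbacks  # assumed no-op callbacks shim

from graphrag.index.verbs.graph.embed.embed_graph import embed_graph_df


async def main() -> None:
    graph = nx.Graph()
    graph.add_edges_from([("a", "b"), ("b", "c"), ("a", "c")])
    source = pd.DataFrame({
        "clustered_graph": ["\n".join(nx.generate_graphml(graph))]
    })

    embedded = await embed_graph_df(
        source,
        NoopVerbCallbacks(),
        strategy={"type": "node2vec"},
        column="clustered_graph",
        to="embeddings",
    )
    # embeddings holds a node_id -> vector mapping per row
    print(embedded["embeddings"][0])


asyncio.run(main())
```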

View File

@@ -5,8 +5,9 @@
import json
from dataclasses import dataclass
from typing import Any
from typing import Any, cast
import pandas as pd
from datashaper import TableContainer, VerbInput, verb
from graphrag.index.storage import PipelineStorage
@@ -31,9 +32,30 @@ async def snapshot_rows(
**_kwargs: dict,
) -> TableContainer:
"""Take a by-row snapshot of the tabular data."""
data = input.get_input()
source = cast(pd.DataFrame, input.get_input())
output = await snapshot_rows_df(
source,
column,
base_name,
storage,
formats,
row_name_column,
)
return TableContainer(table=output)
# todo: once this is out of "verb land", it does not need to return the input
async def snapshot_rows_df(
input: pd.DataFrame,
column: str | None,
base_name: str,
storage: PipelineStorage,
formats: list[str | dict[str, Any]],
row_name_column: str | None = None,
) -> pd.DataFrame:
"""Take a by-row snapshot of the tabular data."""
parsed_formats = _parse_formats(formats)
num_rows = len(data)
num_rows = len(input)
def get_row_name(row: Any, row_idx: Any):
if row_name_column is None:
@@ -42,7 +64,7 @@ async def snapshot_rows(
return f"{base_name}.{row_idx}"
return f"{base_name}.{row[row_name_column]}"
for row_idx, row in data.iterrows():
for row_idx, row in input.iterrows():
for fmt in parsed_formats:
row_name = get_row_name(row, row_idx)
extension = fmt.extension
@@ -60,8 +82,7 @@ async def snapshot_rows(
msg = "column must be specified for text format"
raise ValueError(msg)
await storage.set(f"{row_name}.{extension}", str(row[column]))
return TableContainer(table=data)
return input
def _parse_formats(formats: list[str | dict[str, Any]]) -> list[FormatSpecifier]:
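
snapshot_rows_df pairs naturally with the in-memory storage above. A minimal sketch (the single-row fixture is illustrative):

```python
import asyncio

import pandas as pd

from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.index.verbs.snapshot_rows import snapshot_rows_df


async def main() -> None:
    storage = MemoryPipelineStorage()
    data = pd.DataFrame({"clustered_graph": ["<graphml/>"]})
    await snapshot_rows_df(
        data,
        column="clustered_graph",
        base_name="clustered_graph",
        storage=storage,
        formats=[{"format": "text", "extension": "graphml"}],
    )
    print(storage.keys())  # ['clustered_graph.0.graphml']


asyncio.run(main())
```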

View File

@@ -40,52 +40,13 @@ def build_steps(
return [
{
"verb": "cluster_graph",
"verb": "create_base_entity_graph",
"args": {
**clustering_config,
"column": "entity_graph",
"to": "clustered_graph",
"level_to": "level",
"clustering_config": clustering_config,
"graphml_snapshot_enabled": graphml_snapshot_enabled,
"embed_graph_enabled": embed_graph_enabled,
"embedding_config": embed_graph_config,
},
"input": ({"source": "workflow:create_summarized_entities"}),
},
{
"verb": "snapshot_rows",
"enabled": graphml_snapshot_enabled,
"args": {
"base_name": "clustered_graph",
"column": "clustered_graph",
"formats": [{"format": "text", "extension": "graphml"}],
},
},
{
"verb": "embed_graph",
"enabled": embed_graph_enabled,
"args": {
"column": "clustered_graph",
"to": "embeddings",
**embed_graph_config,
},
},
{
"verb": "snapshot_rows",
"enabled": graphml_snapshot_enabled,
"args": {
"base_name": "embedded_graph",
"column": "entity_graph",
"formats": [{"format": "text", "extension": "graphml"}],
},
},
{
"verb": "select",
"args": {
# only selecting for documentation sake, so we know what is contained in
# this workflow
"columns": (
["level", "clustered_graph", "embeddings"]
if embed_graph_enabled
else ["level", "clustered_graph"]
),
},
},
]
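
With the previous multi-step body collapsed, the workflow now emits a single verb invocation. A quick hedged check (an empty config is an illustrative assumption; real pipelines supply the clustering, embedding, and snapshot settings read above):

```python
from graphrag.index.workflows.v1.create_base_entity_graph import build_steps

steps = build_steps({})  # assumed: defaults kick in for an empty config
assert len(steps) == 1
assert steps[0]["verb"] == "create_base_entity_graph"
```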

View File

@@ -4,6 +4,7 @@
"""The Indexing Engine workflows -> subflows package root."""
from .create_base_documents import create_base_documents
from .create_base_entity_graph import create_base_entity_graph
from .create_base_text_units import create_base_text_units
from .create_final_communities import create_final_communities
from .create_final_community_reports import create_final_community_reports
@@ -18,6 +19,7 @@ from .create_final_text_units import create_final_text_units
__all__ = [
"create_base_documents",
"create_base_entity_graph",
"create_base_text_units",
"create_final_communities",
"create_final_community_reports",

View File

@@ -0,0 +1,85 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All the steps to transform final documents."""
from typing import Any, cast
import pandas as pd
from datashaper import (
Table,
VerbCallbacks,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result
from graphrag.index.storage import PipelineStorage
from graphrag.index.verbs.graph.clustering.cluster_graph import cluster_graph_df
from graphrag.index.verbs.graph.embed.embed_graph import embed_graph_df
from graphrag.index.verbs.snapshot_rows import snapshot_rows_df
@verb(
name="create_base_entity_graph",
treats_input_tables_as_immutable=True,
)
async def create_base_entity_graph(
input: VerbInput,
callbacks: VerbCallbacks,
storage: PipelineStorage,
clustering_config: dict[str, Any],
embedding_config: dict[str, Any],
graphml_snapshot_enabled: bool = False,
embed_graph_enabled: bool = False,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to transform final documents."""
source = cast(pd.DataFrame, input.get_input())
clustering_strategy = clustering_config.get("strategy", {"type": "leiden"})
clustered = cluster_graph_df(
source,
callbacks,
column="entity_graph",
strategy=clustering_strategy,
to="clustered_graph",
level_to="level",
)
if graphml_snapshot_enabled:
await snapshot_rows_df(
clustered,
column="clustered_graph",
base_name="clustered_graph",
storage=storage,
formats=[{"format": "text", "extension": "graphml"}],
)
embedding_strategy = embedding_config.get("strategy")
if embed_graph_enabled and embedding_strategy:
clustered = await embed_graph_df(
clustered,
callbacks,
column="clustered_graph",
strategy=embedding_strategy,
to="embeddings",
)
# take second snapshot after embedding
# todo: this could be skipped if embedding isn't performed, otherwise it is a copy of the regular graph?
if graphml_snapshot_enabled:
await snapshot_rows_df(
clustered,
column="entity_graph",
base_name="embedded_graph",
storage=storage,
formats=[{"format": "text", "extension": "graphml"}],
)
final_columns = ["level", "clustered_graph"]
if embed_graph_enabled:
final_columns.append("embeddings")
return create_verb_result(cast(Table, clustered[final_columns]))
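
The two config dicts the verb reads arrive from the workflow above; illustrative shapes, with only the "strategy" key grounded in this diff (the inner fields are assumptions based on the default leiden and node2vec strategies):

```python
# Illustrative only: the verb reads clustering_config["strategy"] with a
# {"type": "leiden"} fallback, and embedding_config["strategy"] with no fallback.
clustering_config = {"strategy": {"type": "leiden", "max_cluster_size": 10}}
embedding_config = {"strategy": {"type": "node2vec", "dimensions": 1536}}
```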

View File

@@ -31,7 +31,7 @@
1,
2000
],
"subworkflows": 2,
"subworkflows": 1,
"max_runtime": 10
},
"create_final_entities": {

View File

@@ -48,7 +48,7 @@
1,
2000
],
"subworkflows": 2,
"subworkflows": 1,
"max_runtime": 10
},
"create_final_entities": {

View File

@@ -0,0 +1,114 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
import networkx as nx
from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.index.workflows.v1.create_base_entity_graph import (
build_steps,
workflow_name,
)
from .util import (
get_config_for_workflow,
get_workflow_output,
load_expected,
load_input_tables,
remove_disabled_steps,
)
async def test_create_base_entity_graph():
input_tables = load_input_tables([
"workflow:create_summarized_entities",
])
expected = load_expected(workflow_name)
config = get_config_for_workflow(workflow_name)
steps = remove_disabled_steps(build_steps(config))
actual = await get_workflow_output(
input_tables,
{
"steps": steps,
},
)
# the serialization of the graph may differ so we can't assert the dataframes directly
assert actual.shape == expected.shape, "Graph dataframe shapes differ"
# let's parse a sample of the raw graphml
actual_graphml_0 = actual["clustered_graph"][:1][0]
actual_graph_0 = nx.parse_graphml(actual_graphml_0)
expected_graphml_0 = expected["clustered_graph"][:1][0]
expected_graph_0 = nx.parse_graphml(expected_graphml_0)
assert (
actual_graph_0.number_of_nodes() == expected_graph_0.number_of_nodes()
), "Graphml node count differs"
assert (
actual_graph_0.number_of_edges() == expected_graph_0.number_of_edges()
), "Graphml edge count differs"
async def test_create_base_entity_graph_with_embeddings():
input_tables = load_input_tables([
"workflow:create_summarized_entities",
])
expected = load_expected(workflow_name)
config = get_config_for_workflow(workflow_name)
config["embed_graph_enabled"] = True
steps = remove_disabled_steps(build_steps(config))
actual = await get_workflow_output(
input_tables,
{
"steps": steps,
},
)
assert (
len(actual.columns) == len(expected.columns) + 1
), "Graph dataframe missing embedding column"
assert "embeddings" in actual.columns, "Graph dataframe missing embedding column"
async def test_create_base_entity_graph_with_snapshots():
input_tables = load_input_tables([
"workflow:create_summarized_entities",
])
expected = load_expected(workflow_name)
storage = MemoryPipelineStorage()
config = get_config_for_workflow(workflow_name)
config["graphml_snapshot"] = True
steps = remove_disabled_steps(build_steps(config))
actual = await get_workflow_output(
input_tables,
{
"steps": steps,
},
storage=storage,
)
assert actual.shape == expected.shape, "Graph dataframe shapes differ"
assert storage.keys() == [
"clustered_graph.0.graphml",
"clustered_graph.1.graphml",
"clustered_graph.2.graphml",
"clustered_graph.3.graphml",
"embedded_graph.0.graphml",
"embedded_graph.1.graphml",
"embedded_graph.2.graphml",
"embedded_graph.3.graphml",
], "Graph snapshot keys differ"

View File

@@ -14,6 +14,7 @@ from graphrag.index import (
create_pipeline_config,
)
from graphrag.index.run.utils import _create_run_context
from graphrag.index.storage.typing import PipelineStorage
pd.set_option("display.max_columns", None)
@@ -57,7 +58,9 @@ def get_config_for_workflow(name: str) -> PipelineWorkflowConfig:
async def get_workflow_output(
input_tables: dict[str, pd.DataFrame], schema: dict
input_tables: dict[str, pd.DataFrame],
schema: dict,
storage: PipelineStorage | None = None,
) -> pd.DataFrame:
"""Pass in the input tables, the schema, and the output name"""
@@ -67,7 +70,7 @@ async def get_workflow_output(
input_tables=input_tables,
)
context = _create_run_context(None, None, None)
context = _create_run_context(storage, None, None)
await workflow.run(context=context)