Mirror of https://github.com/microsoft/graphrag.git (synced 2026-01-14 09:07:20 +08:00)
Collapse create base entity graph (#1233)

* Collapse create_base_entity_graph
* Format/typing
* Semver
* Fix smoke tests
* Simplify assignment

Parent: 00d5e77568
Commit: 5220bb7ecc
New semver change file:

@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Collapse create-base-entity-graph."
+}
graphrag/index/storage/memory_pipeline_storage.py

@@ -73,6 +73,10 @@ class MemoryPipelineStorage(FilePipelineStorage):
         """Create a child storage instance."""
         return self
 
+    def keys(self) -> list[str]:
+        """Return the keys in the storage."""
+        return list(self._storage.keys())
+
 
 def create_memory_storage() -> PipelineStorage:
     """Create memory storage."""
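Aside: a minimal sketch of why the new keys() accessor is useful for the tests below. The InMemoryStorage class here is a hypothetical stand-in, not the real MemoryPipelineStorage; it shows how a test can assert exactly which artifacts were written.

import asyncio

# Hypothetical stand-in for MemoryPipelineStorage, illustrating the new accessor.
class InMemoryStorage:
    def __init__(self) -> None:
        self._storage: dict[str, str] = {}

    async def set(self, key: str, value: str) -> None:
        self._storage[key] = value

    def keys(self) -> list[str]:
        """Return the keys in the storage."""
        return list(self._storage.keys())

async def main() -> None:
    storage = InMemoryStorage()
    await storage.set("clustered_graph.0.graphml", "<graphml/>")
    assert storage.keys() == ["clustered_graph.0.graphml"]

asyncio.run(main())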
graphrag/index/verbs/graph/clustering/cluster_graph.py

@@ -57,28 +57,45 @@ def cluster_graph(
 
     ```
     """
-    output_df = cast(pd.DataFrame, input.get_input())
-    results = output_df[column].apply(lambda graph: run_layout(strategy, graph))
+    output_df = cluster_graph_df(
+        cast(pd.DataFrame, input.get_input()),
+        callbacks,
+        strategy,
+        column,
+        to,
+        level_to=level_to,
+    )
+    return TableContainer(table=output_df)
+
+
+def cluster_graph_df(
+    input: pd.DataFrame,
+    callbacks: VerbCallbacks,
+    strategy: dict[str, Any],
+    column: str,
+    to: str,
+    level_to: str | None = None,
+) -> pd.DataFrame:
+    """Apply a hierarchical clustering algorithm to a graph."""
+    results = input[column].apply(lambda graph: run_layout(strategy, graph))
 
     community_map_to = "communities"
-    output_df[community_map_to] = results
+    input[community_map_to] = results
 
     level_to = level_to or f"{to}_level"
-    output_df[level_to] = output_df.apply(
+    input[level_to] = input.apply(
         lambda x: list({level for level, _, _ in x[community_map_to]}), axis=1
     )
-    output_df[to] = [None] * len(output_df)
+    input[to] = None
 
-    num_total = len(output_df)
+    num_total = len(input)
 
     # Create a seed for this run (if not provided)
     seed = strategy.get("seed", Random().randint(0, 0xFFFFFFFF))  # noqa S311
 
     # Go through each of the rows
     graph_level_pairs_column: list[list[tuple[int, str]]] = []
-    for _, row in progress_iterable(
-        output_df.iterrows(), callbacks.progress, num_total
-    ):
+    for _, row in progress_iterable(input.iterrows(), callbacks.progress, num_total):
         levels = row[level_to]
         graph_level_pairs: list[tuple[int, str]] = []
 
@@ -96,21 +113,18 @@ def cluster_graph(
             )
             graph_level_pairs.append((level, graph))
         graph_level_pairs_column.append(graph_level_pairs)
-    output_df[to] = graph_level_pairs_column
+    input[to] = graph_level_pairs_column
 
     # explode the list of (level, graph) pairs into separate rows
-    output_df = output_df.explode(to, ignore_index=True)
+    input = input.explode(to, ignore_index=True)
 
     # split the (level, graph) pairs into separate columns
     # TODO: There is probably a better way to do this
-    output_df[[level_to, to]] = pd.DataFrame(
-        output_df[to].tolist(), index=output_df.index
-    )
+    input[[level_to, to]] = pd.DataFrame(input[to].tolist(), index=input.index)
 
     # clean up the community map
-    output_df.drop(columns=[community_map_to], inplace=True)
-
-    return TableContainer(table=output_df)
+    input.drop(columns=[community_map_to], inplace=True)
+    return input
 
 
 # TODO: This should support str | nx.Graph as a graphml param
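Aside: the explode-then-split pattern at the end of cluster_graph_df can be shown standalone with plain pandas. The column names and graphml strings below are placeholders:

import pandas as pd

# one row holding a list of (level, graph) pairs
df = pd.DataFrame({"pairs": [[(0, "graphml_level_0"), (1, "graphml_level_1")]]})
# explode to one row per (level, graph) pair
df = df.explode("pairs", ignore_index=True)
# split each tuple into separate "level" and "graph" columns
df[["level", "graph"]] = pd.DataFrame(df["pairs"].tolist(), index=df.index)
print(df[["level", "graph"]])
# two rows: (0, graphml_level_0) and (1, graphml_level_1)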
graphrag/index/verbs/graph/embed/embed_graph.py

@@ -63,8 +63,56 @@ async def embed_graph(
         random_seed: 86 # Optional, The random seed to use for the embedding, default: 86
     ```
     """
-    output_df = cast(pd.DataFrame, input.get_input())
+    input_df = cast(pd.DataFrame, input.get_input())
+
+    output_df = await embed_graph_df(
+        input_df,
+        callbacks,
+        strategy,
+        column,
+        to,
+        **kwargs,
+    )
+    return TableContainer(table=output_df)
+
+
+async def embed_graph_df(
+    input: pd.DataFrame,
+    callbacks: VerbCallbacks,
+    strategy: dict[str, Any],
+    column: str,
+    to: str,
+    **kwargs,
+) -> pd.DataFrame:
+    """
+    Embed a graph into a vector space. The graph is expected to be in graphml format. The verb outputs a new column containing a mapping between node_id and vector.
+
+    ## Usage
+    ```yaml
+    verb: embed_graph
+    args:
+        column: clustered_graph # The name of the column containing the graph, should be a graphml graph
+        to: embeddings # The name of the column to output the embeddings to
+        strategy: <strategy config> # See strategies section below
+    ```
+
+    ## Strategies
+    The embed_graph verb uses a strategy to embed the graph. The strategy is an object which defines the strategy to use. The following strategies are available:
+
+    ### node2vec
+    This strategy uses the node2vec algorithm to embed a graph. The strategy config is as follows:
+
+    ```yaml
+    strategy:
+        type: node2vec
+        dimensions: 1536 # Optional, The number of dimensions to use for the embedding, default: 1536
+        num_walks: 10 # Optional, The number of walks to use for the embedding, default: 10
+        walk_length: 40 # Optional, The walk length to use for the embedding, default: 40
+        window_size: 2 # Optional, The window size to use for the embedding, default: 2
+        iterations: 3 # Optional, The number of iterations to use for the embedding, default: 3
+        random_seed: 86 # Optional, The random seed to use for the embedding, default: 86
+    ```
+    """
     strategy_type = strategy.get("type", EmbedGraphStrategyType.node2vec)
     strategy_args = {**strategy}
 
@@ -72,13 +120,13 @@ async def embed_graph(
         return run_embeddings(strategy_type, cast(Any, row[column]), strategy_args)
 
     results = await derive_from_rows(
-        output_df,
+        input,
         run_strategy,
         callbacks=callbacks,
         num_threads=kwargs.get("num_threads", None),
     )
-    output_df[to] = list(results)
-    return TableContainer(table=output_df)
+    input[to] = list(results)
+    return input
 
 
 def run_embeddings(
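Aside: a hedged sketch of driving the new embed_graph_df directly. The strategy dict mirrors the node2vec defaults documented in the docstring above; graph_df and callbacks are placeholders, so the call itself is left commented:

# node2vec strategy using the documented defaults
strategy = {
    "type": "node2vec",
    "dimensions": 1536,
    "num_walks": 10,
    "walk_length": 40,
    "window_size": 2,
    "iterations": 3,
    "random_seed": 86,
}
# embedded = await embed_graph_df(
#     graph_df, callbacks, strategy, column="clustered_graph", to="embeddings"
# )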
graphrag/index/verbs/snapshot_rows.py

@@ -5,8 +5,9 @@
 
 import json
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, cast
 
+import pandas as pd
 from datashaper import TableContainer, VerbInput, verb
 
 from graphrag.index.storage import PipelineStorage
@@ -31,9 +32,30 @@ async def snapshot_rows(
     **_kwargs: dict,
 ) -> TableContainer:
     """Take a by-row snapshot of the tabular data."""
-    data = input.get_input()
+    source = cast(pd.DataFrame, input.get_input())
+    output = await snapshot_rows_df(
+        source,
+        column,
+        base_name,
+        storage,
+        formats,
+        row_name_column,
+    )
+    return TableContainer(table=output)
+
+
+# todo: once this is out of "verb land", it does not need to return the input
+async def snapshot_rows_df(
+    input: pd.DataFrame,
+    column: str | None,
+    base_name: str,
+    storage: PipelineStorage,
+    formats: list[str | dict[str, Any]],
+    row_name_column: str | None = None,
+) -> pd.DataFrame:
+    """Take a by-row snapshot of the tabular data."""
     parsed_formats = _parse_formats(formats)
-    num_rows = len(data)
+    num_rows = len(input)
 
     def get_row_name(row: Any, row_idx: Any):
         if row_name_column is None:
@@ -42,7 +64,7 @@ async def snapshot_rows(
             return f"{base_name}.{row_idx}"
         return f"{base_name}.{row[row_name_column]}"
 
-    for row_idx, row in data.iterrows():
+    for row_idx, row in input.iterrows():
         for fmt in parsed_formats:
             row_name = get_row_name(row, row_idx)
             extension = fmt.extension
@@ -60,8 +82,7 @@ async def snapshot_rows(
                 msg = "column must be specified for text format"
                 raise ValueError(msg)
             await storage.set(f"{row_name}.{extension}", str(row[column]))
 
-    return TableContainer(table=data)
+    return input
 
 
 def _parse_formats(formats: list[str | dict[str, Any]]) -> list[FormatSpecifier]:
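Aside: when no row_name_column is given, snapshot keys follow the {base_name}.{row_idx}.{extension} pattern, which is exactly what the snapshot test below asserts against:

base_name, extension = "clustered_graph", "graphml"
keys = [f"{base_name}.{row_idx}.{extension}" for row_idx in range(4)]
print(keys)
# ['clustered_graph.0.graphml', 'clustered_graph.1.graphml',
#  'clustered_graph.2.graphml', 'clustered_graph.3.graphml']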
graphrag/index/workflows/v1/create_base_entity_graph.py

@@ -40,52 +40,13 @@ def build_steps(
 
     return [
         {
-            "verb": "cluster_graph",
+            "verb": "create_base_entity_graph",
             "args": {
-                **clustering_config,
-                "column": "entity_graph",
-                "to": "clustered_graph",
-                "level_to": "level",
+                "clustering_config": clustering_config,
+                "graphml_snapshot_enabled": graphml_snapshot_enabled,
+                "embed_graph_enabled": embed_graph_enabled,
+                "embedding_config": embed_graph_config,
             },
             "input": ({"source": "workflow:create_summarized_entities"}),
         },
-        {
-            "verb": "snapshot_rows",
-            "enabled": graphml_snapshot_enabled,
-            "args": {
-                "base_name": "clustered_graph",
-                "column": "clustered_graph",
-                "formats": [{"format": "text", "extension": "graphml"}],
-            },
-        },
-        {
-            "verb": "embed_graph",
-            "enabled": embed_graph_enabled,
-            "args": {
-                "column": "clustered_graph",
-                "to": "embeddings",
-                **embed_graph_config,
-            },
-        },
-        {
-            "verb": "snapshot_rows",
-            "enabled": graphml_snapshot_enabled,
-            "args": {
-                "base_name": "embedded_graph",
-                "column": "entity_graph",
-                "formats": [{"format": "text", "extension": "graphml"}],
-            },
-        },
-        {
-            "verb": "select",
-            "args": {
-                # only selecting for documentation sake, so we know what is contained in
-                # this workflow
-                "columns": (
-                    ["level", "clustered_graph", "embeddings"]
-                    if embed_graph_enabled
-                    else ["level", "clustered_graph"]
-                ),
-            },
-        },
     ]
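Aside: after this change build_steps emits a single step instead of five. Assuming default config values, the emitted step is shaped roughly like this (a sketch, not verbatim output):

step = {
    "verb": "create_base_entity_graph",
    "args": {
        "clustering_config": {"strategy": {"type": "leiden"}},
        "graphml_snapshot_enabled": False,
        "embed_graph_enabled": False,
        "embedding_config": {},
    },
    "input": {"source": "workflow:create_summarized_entities"},
}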
graphrag/index/workflows/v1/subflows/__init__.py

@@ -4,6 +4,7 @@
 """The Indexing Engine workflows -> subflows package root."""
 
 from .create_base_documents import create_base_documents
+from .create_base_entity_graph import create_base_entity_graph
 from .create_base_text_units import create_base_text_units
 from .create_final_communities import create_final_communities
 from .create_final_community_reports import create_final_community_reports
@@ -18,6 +19,7 @@ from .create_final_text_units import create_final_text_units
 
 __all__ = [
     "create_base_documents",
+    "create_base_entity_graph",
     "create_base_text_units",
     "create_final_communities",
     "create_final_community_reports",
graphrag/index/workflows/v1/subflows/create_base_entity_graph.py (new file)

@@ -0,0 +1,85 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""All the steps to create the base entity graph."""
+
+from typing import Any, cast
+
+import pandas as pd
+from datashaper import (
+    Table,
+    VerbCallbacks,
+    VerbInput,
+    verb,
+)
+from datashaper.table_store.types import VerbResult, create_verb_result
+
+from graphrag.index.storage import PipelineStorage
+from graphrag.index.verbs.graph.clustering.cluster_graph import cluster_graph_df
+from graphrag.index.verbs.graph.embed.embed_graph import embed_graph_df
+from graphrag.index.verbs.snapshot_rows import snapshot_rows_df
+
+
+@verb(
+    name="create_base_entity_graph",
+    treats_input_tables_as_immutable=True,
+)
+async def create_base_entity_graph(
+    input: VerbInput,
+    callbacks: VerbCallbacks,
+    storage: PipelineStorage,
+    clustering_config: dict[str, Any],
+    embedding_config: dict[str, Any],
+    graphml_snapshot_enabled: bool = False,
+    embed_graph_enabled: bool = False,
+    **_kwargs: dict,
+) -> VerbResult:
+    """All the steps to create the base entity graph."""
+    source = cast(pd.DataFrame, input.get_input())
+
+    clustering_strategy = clustering_config.get("strategy", {"type": "leiden"})
+
+    clustered = cluster_graph_df(
+        source,
+        callbacks,
+        column="entity_graph",
+        strategy=clustering_strategy,
+        to="clustered_graph",
+        level_to="level",
+    )
+
+    if graphml_snapshot_enabled:
+        await snapshot_rows_df(
+            clustered,
+            column="clustered_graph",
+            base_name="clustered_graph",
+            storage=storage,
+            formats=[{"format": "text", "extension": "graphml"}],
+        )
+
+    embedding_strategy = embedding_config.get("strategy")
+    if embed_graph_enabled and embedding_strategy:
+        clustered = await embed_graph_df(
+            clustered,
+            callbacks,
+            column="clustered_graph",
+            strategy=embedding_strategy,
+            to="embeddings",
+        )
+
+    # take second snapshot after embedding
+    # todo: this could be skipped if embedding isn't performed, otherwise it is a copy of the regular graph?
+    if graphml_snapshot_enabled:
+        await snapshot_rows_df(
+            clustered,
+            column="entity_graph",
+            base_name="embedded_graph",
+            storage=storage,
+            formats=[{"format": "text", "extension": "graphml"}],
+        )
+
+    final_columns = ["level", "clustered_graph"]
+    if embed_graph_enabled:
+        final_columns.append("embeddings")
+
+    return create_verb_result(cast(Table, clustered[final_columns]))
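Aside: note the clustering fallback in the new subflow; if the config carries no strategy, Leiden is used. In isolation:

clustering_config: dict = {}
clustering_strategy = clustering_config.get("strategy", {"type": "leiden"})
print(clustering_strategy)  # {'type': 'leiden'}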
tests/fixtures/min-csv/config.json (vendored, 2 changed lines)
@@ -31,7 +31,7 @@
             1,
             2000
         ],
-        "subworkflows": 2,
+        "subworkflows": 1,
         "max_runtime": 10
     },
     "create_final_entities": {
tests/fixtures/text/config.json (vendored, 2 changed lines)
@@ -48,7 +48,7 @@
             1,
             2000
         ],
-        "subworkflows": 2,
+        "subworkflows": 1,
         "max_runtime": 10
     },
     "create_final_entities": {
tests/verbs/test_create_base_entity_graph.py (new file, 114 lines)
@@ -0,0 +1,114 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+import networkx as nx
+
+from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
+from graphrag.index.workflows.v1.create_base_entity_graph import (
+    build_steps,
+    workflow_name,
+)
+
+from .util import (
+    get_config_for_workflow,
+    get_workflow_output,
+    load_expected,
+    load_input_tables,
+    remove_disabled_steps,
+)
+
+
+async def test_create_base_entity_graph():
+    input_tables = load_input_tables([
+        "workflow:create_summarized_entities",
+    ])
+    expected = load_expected(workflow_name)
+
+    config = get_config_for_workflow(workflow_name)
+
+    steps = remove_disabled_steps(build_steps(config))
+
+    actual = await get_workflow_output(
+        input_tables,
+        {
+            "steps": steps,
+        },
+    )
+
+    # the serialization of the graph may differ, so we can't assert the dataframes directly
+    assert actual.shape == expected.shape, "Graph dataframe shapes differ"
+
+    # let's parse a sample of the raw graphml
+    actual_graphml_0 = actual["clustered_graph"][:1][0]
+    actual_graph_0 = nx.parse_graphml(actual_graphml_0)
+
+    expected_graphml_0 = expected["clustered_graph"][:1][0]
+    expected_graph_0 = nx.parse_graphml(expected_graphml_0)
+
+    assert (
+        actual_graph_0.number_of_nodes() == expected_graph_0.number_of_nodes()
+    ), "Graphml node count differs"
+    assert (
+        actual_graph_0.number_of_edges() == expected_graph_0.number_of_edges()
+    ), "Graphml edge count differs"
+
+
+async def test_create_base_entity_graph_with_embeddings():
+    input_tables = load_input_tables([
+        "workflow:create_summarized_entities",
+    ])
+    expected = load_expected(workflow_name)
+
+    config = get_config_for_workflow(workflow_name)
+
+    config["embed_graph_enabled"] = True
+
+    steps = remove_disabled_steps(build_steps(config))
+
+    actual = await get_workflow_output(
+        input_tables,
+        {
+            "steps": steps,
+        },
+    )
+
+    assert (
+        len(actual.columns) == len(expected.columns) + 1
+    ), "Graph dataframe missing embedding column"
+    assert "embeddings" in actual.columns, "Graph dataframe missing embedding column"
+
+
+async def test_create_base_entity_graph_with_snapshots():
+    input_tables = load_input_tables([
+        "workflow:create_summarized_entities",
+    ])
+    expected = load_expected(workflow_name)
+
+    storage = MemoryPipelineStorage()
+
+    config = get_config_for_workflow(workflow_name)
+
+    config["graphml_snapshot"] = True
+
+    steps = remove_disabled_steps(build_steps(config))
+
+    actual = await get_workflow_output(
+        input_tables,
+        {
+            "steps": steps,
+        },
+        storage=storage,
+    )
+
+    assert actual.shape == expected.shape, "Graph dataframe shapes differ"
+
+    assert storage.keys() == [
+        "clustered_graph.0.graphml",
+        "clustered_graph.1.graphml",
+        "clustered_graph.2.graphml",
+        "clustered_graph.3.graphml",
+        "embedded_graph.0.graphml",
+        "embedded_graph.1.graphml",
+        "embedded_graph.2.graphml",
+        "embedded_graph.3.graphml",
+    ], "Graph snapshot keys differ"
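Aside: the tests compare parsed graphs rather than raw strings because graphml serialization can differ run to run. The round-trip they rely on looks like this (standalone networkx example):

import networkx as nx

g = nx.Graph()
g.add_edge("a", "b")
graphml = "\n".join(nx.generate_graphml(g))  # serialize to a graphml string
parsed = nx.parse_graphml(graphml)           # what the tests call on each row
assert parsed.number_of_nodes() == 2
assert parsed.number_of_edges() == 1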
tests/verbs/util.py

@@ -14,6 +14,7 @@ from graphrag.index import (
     create_pipeline_config,
 )
 from graphrag.index.run.utils import _create_run_context
+from graphrag.index.storage.typing import PipelineStorage
 
 pd.set_option("display.max_columns", None)
 
@@ -57,7 +58,9 @@ def get_config_for_workflow(name: str) -> PipelineWorkflowConfig:
 
 
 async def get_workflow_output(
-    input_tables: dict[str, pd.DataFrame], schema: dict
+    input_tables: dict[str, pd.DataFrame],
+    schema: dict,
+    storage: PipelineStorage | None = None,
 ) -> pd.DataFrame:
     """Pass in the input tables, the schema, and the output name"""
 
@@ -67,7 +70,7 @@ async def get_workflow_output(
         input_tables=input_tables,
     )
 
-    context = _create_run_context(None, None, None)
+    context = _create_run_context(storage, None, None)
 
     await workflow.run(context=context)