Collapse relationship embeddings (#1199)

* Merge text_embed into a single relationships subflow

* Update smoke tests

* Semver

* Spelling
This commit is contained in:
Nathan Evans 2024-09-24 15:03:26 -07:00 committed by GitHub
parent 1755afbdec
commit f518c8b80b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 104 additions and 83 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Merge text_embed into create-final-relationships subflow."
}

View File

@ -79,6 +79,23 @@ async def text_embed(
<...>
```
"""
input_df = cast(pd.DataFrame, input.get_input())
result_df = await text_embed_df(
input_df, callbacks, cache, column, strategy, **kwargs
)
return TableContainer(table=result_df)
# TODO: this ultimately just creates a new column, so our embed function could just generate a series instead of updating the dataframe
async def text_embed_df(
input: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
column: str,
strategy: dict,
**kwargs,
):
"""Embed a piece of text into a vector space."""
vector_store_config = strategy.get("vector_store")
if vector_store_config:
@ -113,28 +130,28 @@ async def text_embed(
async def _text_embed_in_memory(
    input: pd.DataFrame,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    column: str,
    strategy: dict,
    to: str,
) -> pd.DataFrame:
    """Embed the text in `column` and write the vectors to a new `to` column.

    NOTE(review): this span was diff residue with pre- and post-change lines
    interleaved (duplicate `input` parameters and assignments); reconstructed
    here as the post-change version that accepts a DataFrame directly instead
    of a VerbInput. The redundant `input_table = input` alias was dropped.

    Args:
        input: source table; must contain `column`.
        callbacks: progress/error callbacks forwarded to the strategy.
        cache: pipeline cache forwarded to the strategy.
        column: name of the column holding the text to embed.
        strategy: strategy config; `strategy["type"]` selects the implementation
            and the full dict is passed through as strategy args.
        to: name of the output column to receive the embeddings.

    Returns:
        The input DataFrame with the embeddings added in column `to`
        (mutated in place and returned).
    """
    output_df = input
    strategy_type = strategy["type"]
    strategy_exec = load_strategy(strategy_type)
    strategy_args = {**strategy}
    texts: list[str] = input[column].to_numpy().tolist()
    result = await strategy_exec(texts, callbacks, cache, strategy_args)
    output_df[to] = result.embeddings
    return output_df
async def _text_embed_with_vector_store(
input: VerbInput,
input: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
column: str,
@ -144,7 +161,7 @@ async def _text_embed_with_vector_store(
store_in_table: bool = False,
to: str = "",
):
output_df = cast(pd.DataFrame, input.get_input())
output_df = input
strategy_type = strategy["type"]
strategy_exec = load_strategy(strategy_type)
strategy_args = {**strategy}
@ -179,10 +196,8 @@ async def _text_embed_with_vector_store(
all_results = []
while insert_batch_size * i < input.get_input().shape[0]:
batch = input.get_input().iloc[
insert_batch_size * i : insert_batch_size * (i + 1)
]
while insert_batch_size * i < input.shape[0]:
batch = input.iloc[insert_batch_size * i : insert_batch_size * (i + 1)]
texts: list[str] = batch[column].to_numpy().tolist()
titles: list[str] = batch[title_column].to_numpy().tolist()
ids: list[str] = batch[id_column].to_numpy().tolist()
@ -218,7 +233,7 @@ async def _text_embed_with_vector_store(
if store_in_table:
output_df[to] = all_results
return TableContainer(table=output_df)
return output_df
def _create_vector_store(

View File

@ -23,30 +23,15 @@ def build_steps(
"relationship_description_embed", base_text_embed
)
skip_description_embedding = config.get("skip_description_embedding", False)
return [
{
"id": "pre_embedding",
"verb": "create_final_relationships_pre_embedding",
"input": {"source": "workflow:create_base_entity_graph"},
},
{
"id": "description_embedding",
"verb": "text_embed",
"enabled": not skip_description_embedding,
"verb": "create_final_relationships",
"args": {
"embedding_name": "relationship_description",
"column": "description",
"to": "description_embedding",
**relationship_description_embed_config,
"skip_embedding": skip_description_embedding,
"text_embed": relationship_description_embed_config,
},
},
{
"verb": "create_final_relationships_post_embedding",
"input": {
"source": "pre_embedding"
if skip_description_embedding
else "description_embedding",
"source": "workflow:create_base_entity_graph",
"nodes": "workflow:create_final_nodes",
},
},

View File

@ -7,11 +7,8 @@ from .create_base_documents import create_base_documents
from .create_base_text_units import create_base_text_units
from .create_final_communities import create_final_communities
from .create_final_nodes import create_final_nodes
from .create_final_relationships_post_embedding import (
create_final_relationships_post_embedding,
)
from .create_final_relationships_pre_embedding import (
create_final_relationships_pre_embedding,
from .create_final_relationships import (
create_final_relationships,
)
from .create_final_text_units_pre_embedding import create_final_text_units_pre_embedding
@ -20,7 +17,6 @@ __all__ = [
"create_base_text_units",
"create_final_communities",
"create_final_nodes",
"create_final_relationships_post_embedding",
"create_final_relationships_pre_embedding",
"create_final_relationships",
"create_final_text_units_pre_embedding",
]

View File

@ -1,37 +1,64 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All the steps to transform final relationships after they are embedded."""
"""All the steps to transform final relationships before they are embedded."""
from typing import Any, cast
import pandas as pd
from datashaper import (
Table,
VerbCallbacks,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result
from graphrag.index.cache import PipelineCache
from graphrag.index.utils.ds_util import get_required_input_table
from graphrag.index.verbs.graph.compute_edge_combined_degree import (
compute_edge_combined_degree_df,
)
from graphrag.index.verbs.graph.unpack import unpack_graph_df
from graphrag.index.verbs.text.embed.text_embed import text_embed_df
@verb(
name="create_final_relationships_post_embedding",
name="create_final_relationships",
treats_input_tables_as_immutable=True,
)
def create_final_relationships_post_embedding(
async def create_final_relationships(
input: VerbInput,
callbacks: VerbCallbacks,
cache: PipelineCache,
text_embed: dict,
skip_embedding: bool = False,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to transform final relationships after they are embedded."""
"""All the steps to transform final relationships before they are embedded."""
table = cast(pd.DataFrame, input.get_input())
nodes = cast(pd.DataFrame, get_required_input_table(input, "nodes").table)
pruned_edges = table.drop(columns=["level"])
graph_edges = unpack_graph_df(table, callbacks, "clustered_graph", "edges")
graph_edges.rename(columns={"source_id": "text_unit_ids"}, inplace=True)
filtered = cast(
pd.DataFrame, graph_edges[graph_edges["level"] == 0].reset_index(drop=True)
)
if not skip_embedding:
filtered = await text_embed_df(
filtered,
callbacks,
cache,
column="description",
strategy=text_embed["strategy"],
to="description_embedding",
embedding_name="relationship_description",
)
pruned_edges = filtered.drop(columns=["level"])
filtered_nodes = cast(
pd.DataFrame,

View File

@ -1,38 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All the steps to transform final relationships before they are embedded."""
from typing import cast
import pandas as pd
from datashaper import (
Table,
VerbCallbacks,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result
from graphrag.index.verbs.graph.unpack import unpack_graph_df
@verb(
    name="create_final_relationships_pre_embedding",
    treats_input_tables_as_immutable=True,
)
def create_final_relationships_pre_embedding(
    input: VerbInput,
    callbacks: VerbCallbacks,
    **_kwargs: dict,
) -> VerbResult:
    """All the steps to transform final relationships before they are embedded."""
    source = cast(pd.DataFrame, input.get_input())
    # Extract the edge list from the clustered graph column, then rename
    # source_id to text_unit_ids to match the final relationship schema.
    edges = unpack_graph_df(source, callbacks, "clustered_graph", "edges")
    edges.rename(columns={"source_id": "text_unit_ids"}, inplace=True)
    # Keep only the level-0 edges for the final table.
    level_zero = edges[edges["level"] == 0].reset_index(drop=True)
    return create_verb_result(cast(Table, level_zero))

View File

@ -52,7 +52,7 @@
1,
2000
],
"subworkflows": 2,
"subworkflows": 1,
"max_runtime": 100
},
"create_final_nodes": {

View File

@ -71,7 +71,7 @@
1,
2000
],
"subworkflows": 2,
"subworkflows": 1,
"max_runtime": 100
},
"create_final_nodes": {

View File

@ -37,3 +37,32 @@ async def test_create_final_relationships():
)
compare_outputs(actual, expected)
async def test_create_final_relationships_with_embeddings():
    """Run the workflow with description embedding enabled and check the added column."""
    tables = load_input_tables([
        "workflow:create_base_entity_graph",
        "workflow:create_final_nodes",
    ])
    expected = load_expected(workflow_name)

    workflow_config = get_config_for_workflow(workflow_name)
    workflow_config["skip_description_embedding"] = False
    # The default config carries a fully detailed embed setup; swapping only the
    # strategy type for the mock keeps every other required parameter in place.
    workflow_config["relationship_description_embed"]["strategy"]["type"] = "mock"

    enabled_steps = remove_disabled_steps(build_steps(workflow_config))
    result = await get_workflow_output(tables, {"steps": enabled_steps})

    # Embedding adds exactly one column on top of the non-embedded expectation.
    assert "description_embedding" in result.columns
    assert len(result.columns) == len(expected.columns) + 1
    # The mock strategy returns a 3-float vector for each row.
    assert len(result["description_embedding"][0]) == 3

View File

@ -13,6 +13,7 @@ from graphrag.index import (
PipelineWorkflowStep,
create_pipeline_config,
)
from graphrag.index.run.utils import _create_run_context
def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]:
@ -61,7 +62,9 @@ async def get_workflow_output(
input_tables=input_tables,
)
await workflow.run()
context = _create_run_context(None, None, None)
await workflow.run(context=context)
# if there's only one output, it is the default here, no name required
return cast(pd.DataFrame, workflow.output())