Collapse create final entities (#1220)

* Collapse create_final_entities

* Update smoke tests

* Semver

* Remove prints

* Update embedding assertions
This commit is contained in:
Nathan Evans 2024-09-25 17:35:44 -07:00 committed by GitHub
parent 3217013019
commit ce71bcf7fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 258 additions and 106 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Collapse create-final-entities."
}

View File

@ -24,110 +24,16 @@ def build_steps(
)
skip_name_embedding = config.get("skip_name_embedding", False)
skip_description_embedding = config.get("skip_description_embedding", False)
is_using_vector_store = (
entity_name_embed_config.get("strategy", {}).get("vector_store", None)
is not None
)
return [
{
"verb": "unpack_graph",
"verb": "create_final_entities",
"args": {
"column": "clustered_graph",
"type": "nodes",
"skip_name_embedding": skip_name_embedding,
"skip_description_embedding": skip_description_embedding,
"name_text_embed": entity_name_embed_config,
"description_text_embed": entity_name_description_embed_config,
},
"input": {"source": "workflow:create_base_entity_graph"},
},
{"verb": "rename", "args": {"columns": {"label": "title"}}},
{
"verb": "select",
"args": {
"columns": [
"id",
"title",
"type",
"description",
"human_readable_id",
"graph_embedding",
"source_id",
],
},
},
{
# create_base_entity_graph has multiple levels of clustering, which means there are multiple graphs with the same entities
# this dedupes the entities so that there is only one of each entity
"verb": "dedupe",
"args": {"columns": ["id"]},
},
{"verb": "rename", "args": {"columns": {"title": "name"}}},
{
# ELIMINATE EMPTY NAMES
"verb": "filter",
"args": {
"column": "name",
"criteria": [
{
"type": "value",
"operator": "is not empty",
}
],
},
},
{
"verb": "text_split",
"args": {"separator": ",", "column": "source_id", "to": "text_unit_ids"},
},
{"verb": "drop", "args": {"columns": ["source_id"]}},
{
"verb": "text_embed",
"enabled": not skip_name_embedding,
"args": {
"embedding_name": "entity_name",
"column": "name",
"to": "name_embedding",
**entity_name_embed_config,
},
},
{
"verb": "merge",
"enabled": not skip_description_embedding,
"args": {
"strategy": "concat",
"columns": ["name", "description"],
"to": "name_description",
"delimiter": ":",
"preserveSource": True,
},
},
{
"verb": "text_embed",
"enabled": not skip_description_embedding,
"args": {
"embedding_name": "entity_name_description",
"column": "name_description",
"to": "description_embedding",
**entity_name_description_embed_config,
},
},
{
"verb": "drop",
"enabled": not skip_description_embedding,
"args": {
"columns": ["name_description"],
},
},
{
# ELIMINATE EMPTY DESCRIPTION EMBEDDINGS
"verb": "filter",
"enabled": not skip_description_embedding and not is_using_vector_store,
"args": {
"column": "description_embedding",
"criteria": [
{
"type": "value",
"operator": "is not empty",
}
],
},
},
]

View File

@ -8,6 +8,7 @@ from .create_base_text_units import create_base_text_units
from .create_final_communities import create_final_communities
from .create_final_covariates import create_final_covariates
from .create_final_documents import create_final_documents
from .create_final_entities import create_final_entities
from .create_final_nodes import create_final_nodes
from .create_final_relationships import (
create_final_relationships,
@ -20,6 +21,7 @@ __all__ = [
"create_final_communities",
"create_final_covariates",
"create_final_documents",
"create_final_entities",
"create_final_nodes",
"create_final_relationships",
"create_final_text_units",

View File

@ -0,0 +1,105 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All the steps to transform final entities."""
from typing import cast
import pandas as pd
from datashaper import (
Table,
VerbCallbacks,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result
from graphrag.index.cache import PipelineCache
from graphrag.index.verbs.graph.unpack import unpack_graph_df
from graphrag.index.verbs.text.embed.text_embed import text_embed_df
from graphrag.index.verbs.text.split import text_split_df
@verb(
    name="create_final_entities",
    treats_input_tables_as_immutable=True,
)
async def create_final_entities(
    input: VerbInput,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    name_text_embed: dict,
    description_text_embed: dict,
    skip_name_embedding: bool = False,
    skip_description_embedding: bool = False,
    **_kwargs: dict,
) -> VerbResult:
    """Build the final entities table from the clustered entity graph.

    Steps, in order: unpack the nodes of the ``clustered_graph`` column,
    rename ``label`` to ``name``, keep the entity columns, dedupe on ``id``,
    drop rows with a null ``name``, split ``source_id`` into
    ``text_unit_ids``, and optionally embed the name and the
    ``name:description`` text.

    Args:
        input: Verb input whose table carries the ``clustered_graph`` column.
        callbacks: Callbacks forwarded to the unpack and embedding helpers.
        cache: Pipeline cache forwarded to the embedding helper.
        name_text_embed: Embedding config for names; its ``"strategy"`` entry
            is required when name embedding is enabled.
        description_text_embed: Embedding config for descriptions; its
            ``"strategy"`` entry is required when description embedding is
            enabled.
        skip_name_embedding: When true, no ``name_embedding`` column is added.
        skip_description_embedding: When true, no ``description_embedding``
            column is added.
        **_kwargs: Extra arguments, accepted and ignored.

    Returns:
        A verb result wrapping the final entities table.
    """
    table = cast(pd.DataFrame, input.get_input())

    nodes = unpack_graph_df(table, callbacks, "clustered_graph", "nodes")
    # Reassign rather than mutate in place: the column selection below yields
    # a derived frame, and in-place edits on it can raise pandas'
    # SettingWithCopyWarning.
    nodes = nodes.rename(columns={"label": "name"})
    nodes = cast(
        pd.DataFrame,
        nodes[
            [
                "id",
                "name",
                "type",
                "description",
                "human_readable_id",
                "graph_embedding",
                "source_id",
            ]
        ],
    )

    # create_base_entity_graph has multiple levels of clustering, which means there are multiple graphs with the same entities
    # this dedupes the entities so that there is only one of each entity
    nodes = nodes.drop_duplicates(subset="id")

    # eliminate empty (null) names
    filtered = cast(pd.DataFrame, nodes[nodes["name"].notna()].reset_index(drop=True))

    with_ids = text_split_df(
        filtered, column="source_id", separator=",", to="text_unit_ids"
    )
    embedded = with_ids.drop(columns=["source_id"])

    if not skip_name_embedding:
        embedded = await text_embed_df(
            embedded,
            callbacks,
            cache,
            column="name",
            strategy=name_text_embed["strategy"],
            to="name_embedding",
            embedding_name="entity_name",
        )

    if not skip_description_embedding:
        description_strategy = description_text_embed["strategy"]

        # description embedding is a concat of the name + description, so we'll create a temporary column
        embedded["name_description"] = embedded["name"] + ":" + embedded["description"]
        embedded = await text_embed_df(
            embedded,
            callbacks,
            cache,
            column="name_description",
            strategy=description_strategy,
            to="description_embedding",
            embedding_name="entity_name_description",
        )
        embedded = embedded.drop(columns=["name_description"])

        # Rows without a description embedding are only dropped when no
        # vector store is configured in the strategy.
        is_using_vector_store = description_strategy.get("vector_store") is not None
        if not is_using_vector_store:
            embedded = embedded[embedded["description_embedding"].notna()].reset_index(
                drop=True
            )

    return create_verb_result(cast(Table, embedded))

View File

@ -44,7 +44,7 @@
"description",
"graph_embedding"
],
"subworkflows": 11,
"subworkflows": 1,
"max_runtime": 300
},
"create_final_relationships": {

View File

@ -61,7 +61,7 @@
"description",
"graph_embedding"
],
"subworkflows": 11,
"subworkflows": 1,
"max_runtime": 300
},
"create_final_relationships": {

View File

@ -63,4 +63,4 @@ async def test_create_final_documents_with_embeddings():
assert "raw_content_embedding" in actual.columns
assert len(actual.columns) == len(expected.columns) + 1
# the mock impl returns an array of 3 floats for each embedding
assert len(actual["raw_content_embedding"][0]) == 3
assert len(actual["raw_content_embedding"][:1][0]) == 3

View File

@ -0,0 +1,133 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.index.workflows.v1.create_final_entities import (
build_steps,
workflow_name,
)
from .util import (
compare_outputs,
get_config_for_workflow,
get_workflow_output,
load_expected,
load_input_tables,
remove_disabled_steps,
)
async def test_create_final_entities():
    """Run the workflow with both embeddings skipped and compare to the fixture."""
    tables = load_input_tables([
        "workflow:create_base_entity_graph",
    ])
    expected = load_expected(workflow_name)

    workflow_config = get_config_for_workflow(workflow_name)
    workflow_config["skip_name_embedding"] = True
    workflow_config["skip_description_embedding"] = True

    enabled_steps = remove_disabled_steps(build_steps(workflow_config))
    actual = await get_workflow_output(tables, {"steps": enabled_steps})

    # ignore the description_embedding column, which is included in the expected output due to default config
    compared_columns = [
        "id",
        "name",
        "type",
        "description",
        "human_readable_id",
        "graph_embedding",
        "text_unit_ids",
    ]
    compare_outputs(actual, expected, columns=compared_columns)

    # one column fewer than the fixture: no description_embedding was produced
    assert len(actual.columns) == len(expected.columns) - 1
async def test_create_final_entities_with_name_embeddings():
    """Run the workflow with only the (mocked) name embedding enabled."""
    tables = load_input_tables([
        "workflow:create_base_entity_graph",
    ])
    expected = load_expected(workflow_name)

    workflow_config = get_config_for_workflow(workflow_name)
    workflow_config["skip_name_embedding"] = False
    workflow_config["skip_description_embedding"] = True
    workflow_config["entity_name_embed"]["strategy"]["type"] = "mock"

    enabled_steps = remove_disabled_steps(build_steps(workflow_config))
    actual = await get_workflow_output(tables, {"steps": enabled_steps})

    assert "name_embedding" in actual.columns
    assert len(actual.columns) == len(expected.columns)

    # the mock impl returns an array of 3 floats for each embedding
    first_embedding = actual["name_embedding"][:1][0]
    assert len(first_embedding) == 3
async def test_create_final_entities_with_description_embeddings():
    """Run the workflow with only the (mocked) description embedding enabled."""
    tables = load_input_tables([
        "workflow:create_base_entity_graph",
    ])
    expected = load_expected(workflow_name)

    workflow_config = get_config_for_workflow(workflow_name)
    workflow_config["skip_name_embedding"] = True
    workflow_config["skip_description_embedding"] = False
    workflow_config["entity_name_description_embed"]["strategy"]["type"] = "mock"

    enabled_steps = remove_disabled_steps(build_steps(workflow_config))
    actual = await get_workflow_output(tables, {"steps": enabled_steps})

    assert "description_embedding" in actual.columns
    assert len(actual.columns) == len(expected.columns)

    # each embedding in the first row is expected to hold 3 values
    first_embedding = actual["description_embedding"][:1][0]
    assert len(first_embedding) == 3
async def test_create_final_entities_with_name_and_description_embeddings():
    """Run the workflow with both (mocked) embeddings enabled.

    Both embedding columns are checked, so a regression that silently drops
    either one fails this test.
    """
    input_tables = load_input_tables([
        "workflow:create_base_entity_graph",
    ])
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)
    config["skip_name_embedding"] = False
    config["skip_description_embedding"] = False
    config["entity_name_description_embed"]["strategy"]["type"] = "mock"
    config["entity_name_embed"]["strategy"]["type"] = "mock"

    steps = remove_disabled_steps(build_steps(config))
    actual = await get_workflow_output(
        input_tables,
        {
            "steps": steps,
        },
    )

    assert "name_embedding" in actual.columns
    assert "description_embedding" in actual.columns
    # one more column than the fixture (presumably the added name_embedding)
    assert len(actual.columns) == len(expected.columns) + 1

    # the mock impl returns an array of 3 floats for each embedding
    assert len(actual["name_embedding"][:1][0]) == 3
    assert len(actual["description_embedding"][:1][0]) == 3

View File

@ -65,4 +65,4 @@ async def test_create_final_relationships_with_embeddings():
assert "description_embedding" in actual.columns
assert len(actual.columns) == len(expected.columns) + 1
# the mock impl returns an array of 3 floats for each embedding
assert len(actual["description_embedding"][0]) == 3
assert len(actual["description_embedding"][:1][0]) == 3

View File

@ -102,4 +102,4 @@ async def test_create_final_text_units_with_embeddings():
assert "text_embedding" in actual.columns
assert len(actual.columns) == len(expected.columns) + 1
# the mock impl returns an array of 3 floats for each embedding
assert len(actual["text_embedding"][0]) == 3
assert len(actual["text_embedding"][:1][0]) == 3

View File

@ -88,12 +88,14 @@ def compare_outputs(
assert column in actual.columns
try:
# dtypes can differ since the test data is read from parquet and our workflow runs in memory
assert_series_equal(actual[column], expected[column], check_dtype=False)
assert_series_equal(
actual[column], expected[column], check_dtype=False, check_index=False
)
except AssertionError:
print("Expected:")
print(expected[column])
print("Actual:")
print(actual[columns])
print(actual[column])
raise