mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 09:07:20 +08:00
Collapse create final entities (#1220)
* Collapse create_final_entities * Update smoke tests * Semver * Remove prints * Update embedding assertions
This commit is contained in:
parent
3217013019
commit
ce71bcf7fb
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "Collapse create-final-entities."
|
||||
}
|
||||
@ -24,110 +24,16 @@ def build_steps(
|
||||
)
|
||||
skip_name_embedding = config.get("skip_name_embedding", False)
|
||||
skip_description_embedding = config.get("skip_description_embedding", False)
|
||||
is_using_vector_store = (
|
||||
entity_name_embed_config.get("strategy", {}).get("vector_store", None)
|
||||
is not None
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"verb": "unpack_graph",
|
||||
"verb": "create_final_entities",
|
||||
"args": {
|
||||
"column": "clustered_graph",
|
||||
"type": "nodes",
|
||||
"skip_name_embedding": skip_name_embedding,
|
||||
"skip_description_embedding": skip_description_embedding,
|
||||
"name_text_embed": entity_name_embed_config,
|
||||
"description_text_embed": entity_name_description_embed_config,
|
||||
},
|
||||
"input": {"source": "workflow:create_base_entity_graph"},
|
||||
},
|
||||
{"verb": "rename", "args": {"columns": {"label": "title"}}},
|
||||
{
|
||||
"verb": "select",
|
||||
"args": {
|
||||
"columns": [
|
||||
"id",
|
||||
"title",
|
||||
"type",
|
||||
"description",
|
||||
"human_readable_id",
|
||||
"graph_embedding",
|
||||
"source_id",
|
||||
],
|
||||
},
|
||||
},
|
||||
{
|
||||
# create_base_entity_graph has multiple levels of clustering, which means there are multiple graphs with the same entities
|
||||
# this dedupes the entities so that there is only one of each entity
|
||||
"verb": "dedupe",
|
||||
"args": {"columns": ["id"]},
|
||||
},
|
||||
{"verb": "rename", "args": {"columns": {"title": "name"}}},
|
||||
{
|
||||
# ELIMINATE EMPTY NAMES
|
||||
"verb": "filter",
|
||||
"args": {
|
||||
"column": "name",
|
||||
"criteria": [
|
||||
{
|
||||
"type": "value",
|
||||
"operator": "is not empty",
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
{
|
||||
"verb": "text_split",
|
||||
"args": {"separator": ",", "column": "source_id", "to": "text_unit_ids"},
|
||||
},
|
||||
{"verb": "drop", "args": {"columns": ["source_id"]}},
|
||||
{
|
||||
"verb": "text_embed",
|
||||
"enabled": not skip_name_embedding,
|
||||
"args": {
|
||||
"embedding_name": "entity_name",
|
||||
"column": "name",
|
||||
"to": "name_embedding",
|
||||
**entity_name_embed_config,
|
||||
},
|
||||
},
|
||||
{
|
||||
"verb": "merge",
|
||||
"enabled": not skip_description_embedding,
|
||||
"args": {
|
||||
"strategy": "concat",
|
||||
"columns": ["name", "description"],
|
||||
"to": "name_description",
|
||||
"delimiter": ":",
|
||||
"preserveSource": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
"verb": "text_embed",
|
||||
"enabled": not skip_description_embedding,
|
||||
"args": {
|
||||
"embedding_name": "entity_name_description",
|
||||
"column": "name_description",
|
||||
"to": "description_embedding",
|
||||
**entity_name_description_embed_config,
|
||||
},
|
||||
},
|
||||
{
|
||||
"verb": "drop",
|
||||
"enabled": not skip_description_embedding,
|
||||
"args": {
|
||||
"columns": ["name_description"],
|
||||
},
|
||||
},
|
||||
{
|
||||
# ELIMINATE EMPTY DESCRIPTION EMBEDDINGS
|
||||
"verb": "filter",
|
||||
"enabled": not skip_description_embedding and not is_using_vector_store,
|
||||
"args": {
|
||||
"column": "description_embedding",
|
||||
"criteria": [
|
||||
{
|
||||
"type": "value",
|
||||
"operator": "is not empty",
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
@ -8,6 +8,7 @@ from .create_base_text_units import create_base_text_units
|
||||
from .create_final_communities import create_final_communities
|
||||
from .create_final_covariates import create_final_covariates
|
||||
from .create_final_documents import create_final_documents
|
||||
from .create_final_entities import create_final_entities
|
||||
from .create_final_nodes import create_final_nodes
|
||||
from .create_final_relationships import (
|
||||
create_final_relationships,
|
||||
@ -20,6 +21,7 @@ __all__ = [
|
||||
"create_final_communities",
|
||||
"create_final_covariates",
|
||||
"create_final_documents",
|
||||
"create_final_entities",
|
||||
"create_final_nodes",
|
||||
"create_final_relationships",
|
||||
"create_final_text_units",
|
||||
|
||||
105
graphrag/index/workflows/v1/subflows/create_final_entities.py
Normal file
105
graphrag/index/workflows/v1/subflows/create_final_entities.py
Normal file
@ -0,0 +1,105 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""All the steps to transform final entities."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pandas as pd
|
||||
from datashaper import (
|
||||
Table,
|
||||
VerbCallbacks,
|
||||
VerbInput,
|
||||
verb,
|
||||
)
|
||||
from datashaper.table_store.types import VerbResult, create_verb_result
|
||||
|
||||
from graphrag.index.cache import PipelineCache
|
||||
from graphrag.index.verbs.graph.unpack import unpack_graph_df
|
||||
from graphrag.index.verbs.text.embed.text_embed import text_embed_df
|
||||
from graphrag.index.verbs.text.split import text_split_df
|
||||
|
||||
|
||||
@verb(
    name="create_final_entities",
    treats_input_tables_as_immutable=True,
)
async def create_final_entities(
    input: VerbInput,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    name_text_embed: dict,
    description_text_embed: dict,
    skip_name_embedding: bool = False,
    skip_description_embedding: bool = False,
    **_kwargs: dict,
) -> VerbResult:
    """Build the final entities table from the clustered entity graph.

    The nodes of the ``clustered_graph`` column are unpacked, reduced to the
    entity columns, de-duplicated by ``id``, stripped of rows without a name,
    and given a ``text_unit_ids`` column split out of ``source_id``. Name and
    name+description embeddings are then added unless they are skipped.

    Args:
        input: Verb input whose table holds the ``clustered_graph`` column.
        callbacks: Callbacks forwarded to the unpack and embed helpers.
        cache: Pipeline cache forwarded to the embed helper.
        name_text_embed: Embedding config for the name; its ``"strategy"``
            entry is passed to the embed helper.
        description_text_embed: Embedding config for the name+description
            text; its ``"strategy"`` entry is passed to the embed helper.
        skip_name_embedding: When true, no ``name_embedding`` column is made.
        skip_description_embedding: When true, no ``description_embedding``
            column is made.
        **_kwargs: Ignored extra verb arguments.

    Returns:
        A verb result wrapping the final entities table.
    """
    table = cast(pd.DataFrame, input.get_input())

    nodes = unpack_graph_df(table, callbacks, "clustered_graph", "nodes")
    nodes = nodes.rename(columns={"label": "name"})

    nodes = cast(
        pd.DataFrame,
        nodes[
            [
                "id",
                "name",
                "type",
                "description",
                "human_readable_id",
                "graph_embedding",
                "source_id",
            ]
        ],
    )

    # create_base_entity_graph has multiple levels of clustering, which means there are multiple graphs with the same entities
    # this dedupes the entities so that there is only one of each entity
    # (reassign rather than mutate in place: `nodes` is a column selection of
    # another frame, and in-place edits on it can raise SettingWithCopyWarning)
    nodes = nodes.drop_duplicates(subset="id")

    # eliminate empty names
    filtered = cast(pd.DataFrame, nodes[nodes["name"].notna()].reset_index(drop=True))

    with_ids = text_split_df(
        filtered, column="source_id", separator=",", to="text_unit_ids"
    )
    with_ids = with_ids.drop(columns=["source_id"])

    embedded = with_ids

    if not skip_name_embedding:
        embedded = await text_embed_df(
            embedded,
            callbacks,
            cache,
            column="name",
            strategy=name_text_embed["strategy"],
            to="name_embedding",
            embedding_name="entity_name",
        )

    if not skip_description_embedding:
        # description embedding is a concat of the name + description, so we'll create a temporary column
        embedded["name_description"] = embedded["name"] + ":" + embedded["description"]
        embedded = await text_embed_df(
            embedded,
            callbacks,
            cache,
            column="name_description",
            strategy=description_text_embed["strategy"],
            to="description_embedding",
            embedding_name="entity_name_description",
        )
        # the temporary concat column is not part of the output
        embedded = embedded.drop(columns=["name_description"])
        is_using_vector_store = (
            description_text_embed.get("strategy", {}).get("vector_store", None)
            is not None
        )
        # rows without a description embedding are only removed when no
        # vector store is configured on the description strategy
        if not is_using_vector_store:
            embedded = embedded[embedded["description_embedding"].notna()].reset_index(
                drop=True
            )

    return create_verb_result(cast(Table, embedded))
|
||||
2
tests/fixtures/min-csv/config.json
vendored
2
tests/fixtures/min-csv/config.json
vendored
@ -44,7 +44,7 @@
|
||||
"description",
|
||||
"graph_embedding"
|
||||
],
|
||||
"subworkflows": 11,
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 300
|
||||
},
|
||||
"create_final_relationships": {
|
||||
|
||||
2
tests/fixtures/text/config.json
vendored
2
tests/fixtures/text/config.json
vendored
@ -61,7 +61,7 @@
|
||||
"description",
|
||||
"graph_embedding"
|
||||
],
|
||||
"subworkflows": 11,
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 300
|
||||
},
|
||||
"create_final_relationships": {
|
||||
|
||||
@ -63,4 +63,4 @@ async def test_create_final_documents_with_embeddings():
|
||||
assert "raw_content_embedding" in actual.columns
|
||||
assert len(actual.columns) == len(expected.columns) + 1
|
||||
# the mock impl returns an array of 3 floats for each embedding
|
||||
assert len(actual["raw_content_embedding"][0]) == 3
|
||||
assert len(actual["raw_content_embedding"][:1][0]) == 3
|
||||
|
||||
133
tests/verbs/test_create_final_entities.py
Normal file
133
tests/verbs/test_create_final_entities.py
Normal file
@ -0,0 +1,133 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from graphrag.index.workflows.v1.create_final_entities import (
|
||||
build_steps,
|
||||
workflow_name,
|
||||
)
|
||||
|
||||
from .util import (
|
||||
compare_outputs,
|
||||
get_config_for_workflow,
|
||||
get_workflow_output,
|
||||
load_expected,
|
||||
load_input_tables,
|
||||
remove_disabled_steps,
|
||||
)
|
||||
|
||||
|
||||
async def test_create_final_entities():
    """Run the workflow with both embeddings skipped and compare to the stored output."""
    tables = load_input_tables(["workflow:create_base_entity_graph"])
    baseline = load_expected(workflow_name)

    workflow_config = get_config_for_workflow(workflow_name)
    for flag in ("skip_name_embedding", "skip_description_embedding"):
        workflow_config[flag] = True

    enabled_steps = remove_disabled_steps(build_steps(workflow_config))
    result = await get_workflow_output(tables, {"steps": enabled_steps})

    # ignore the description_embedding column, which is included in the expected output due to default config
    compared_columns = [
        "id",
        "name",
        "type",
        "description",
        "human_readable_id",
        "graph_embedding",
        "text_unit_ids",
    ]
    compare_outputs(result, baseline, columns=compared_columns)

    # the stored output carries exactly one column more than this run
    assert len(result.columns) == len(baseline.columns) - 1
|
||||
|
||||
|
||||
async def test_create_final_entities_with_name_embeddings():
    """Run the workflow with only the name embedding enabled, using the mock strategy."""
    tables = load_input_tables(["workflow:create_base_entity_graph"])
    baseline = load_expected(workflow_name)

    workflow_config = get_config_for_workflow(workflow_name)
    workflow_config["skip_name_embedding"] = False
    workflow_config["skip_description_embedding"] = True
    workflow_config["entity_name_embed"]["strategy"]["type"] = "mock"

    enabled_steps = remove_disabled_steps(build_steps(workflow_config))
    result = await get_workflow_output(tables, {"steps": enabled_steps})

    assert "name_embedding" in result.columns
    assert len(result.columns) == len(baseline.columns)

    # the mock impl returns an array of 3 floats for each embedding
    first_embedding = result["name_embedding"][:1][0]
    assert len(first_embedding) == 3
|
||||
|
||||
|
||||
async def test_create_final_entities_with_description_embeddings():
    """Run the workflow with only the description embedding enabled, using the mock strategy."""
    tables = load_input_tables(["workflow:create_base_entity_graph"])
    baseline = load_expected(workflow_name)

    workflow_config = get_config_for_workflow(workflow_name)
    workflow_config["skip_name_embedding"] = True
    workflow_config["skip_description_embedding"] = False
    workflow_config["entity_name_description_embed"]["strategy"]["type"] = "mock"

    enabled_steps = remove_disabled_steps(build_steps(workflow_config))
    result = await get_workflow_output(tables, {"steps": enabled_steps})

    assert "description_embedding" in result.columns
    assert len(result.columns) == len(baseline.columns)

    first_embedding = result["description_embedding"][:1][0]
    assert len(first_embedding) == 3
|
||||
|
||||
|
||||
async def test_create_final_entities_with_name_and_description_embeddings():
    """Run the workflow with both embeddings enabled, using the mock strategy.

    Both embedding columns are checked: the name embedding is enabled here as
    well, so its presence and shape are asserted alongside the description one.
    """
    input_tables = load_input_tables([
        "workflow:create_base_entity_graph",
    ])
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)

    config["skip_name_embedding"] = False
    config["skip_description_embedding"] = False
    config["entity_name_description_embed"]["strategy"]["type"] = "mock"
    config["entity_name_embed"]["strategy"]["type"] = "mock"

    steps = remove_disabled_steps(build_steps(config))

    actual = await get_workflow_output(
        input_tables,
        {
            "steps": steps,
        },
    )

    assert "name_embedding" in actual.columns
    assert "description_embedding" in actual.columns
    # the stored output has only one of the two embedding columns
    assert len(actual.columns) == len(expected.columns) + 1
    # the mock impl returns an array of 3 floats for each embedding
    assert len(actual["name_embedding"][:1][0]) == 3
    assert len(actual["description_embedding"][:1][0]) == 3
|
||||
@ -65,4 +65,4 @@ async def test_create_final_relationships_with_embeddings():
|
||||
assert "description_embedding" in actual.columns
|
||||
assert len(actual.columns) == len(expected.columns) + 1
|
||||
# the mock impl returns an array of 3 floats for each embedding
|
||||
assert len(actual["description_embedding"][0]) == 3
|
||||
assert len(actual["description_embedding"][:1][0]) == 3
|
||||
|
||||
@ -102,4 +102,4 @@ async def test_create_final_text_units_with_embeddings():
|
||||
assert "text_embedding" in actual.columns
|
||||
assert len(actual.columns) == len(expected.columns) + 1
|
||||
# the mock impl returns an array of 3 floats for each embedding
|
||||
assert len(actual["text_embedding"][0]) == 3
|
||||
assert len(actual["text_embedding"][:1][0]) == 3
|
||||
|
||||
@ -88,12 +88,14 @@ def compare_outputs(
|
||||
assert column in actual.columns
|
||||
try:
|
||||
# dtypes can differ since the test data is read from parquet and our workflow runs in memory
|
||||
assert_series_equal(actual[column], expected[column], check_dtype=False)
|
||||
assert_series_equal(
|
||||
actual[column], expected[column], check_dtype=False, check_index=False
|
||||
)
|
||||
except AssertionError:
|
||||
print("Expected:")
|
||||
print(expected[column])
|
||||
print("Actual:")
|
||||
print(actual[columns])
|
||||
print(actual[column])
|
||||
raise
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user