Revisit create final text units (#1216)

* Add embeddings to collapsed subflow

* Semver

* Fix smoke tests
This commit is contained in:
Nathan Evans 2024-09-25 16:55:27 -07:00 committed by GitHub
parent 73e709b686
commit 3217013019
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 91 additions and 52 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Add embeddings to subflow."
}

View File

@ -21,12 +21,8 @@ def build_steps(
"""
base_text_embed = config.get("text_embed", {})
text_unit_text_embed_config = config.get("text_unit_text_embed", base_text_embed)
covariates_enabled = config.get("covariates_enabled", False)
skip_text_unit_embedding = config.get("skip_text_unit_embedding", False)
is_using_vector_store = (
text_unit_text_embed_config.get("strategy", {}).get("vector_store", None)
is not None
)
covariates_enabled = config.get("covariates_enabled", False)
others = [
"workflow:create_final_entities",
@ -37,8 +33,10 @@ def build_steps(
return [
{
"verb": "create_final_text_units_pre_embedding",
"verb": "create_final_text_units",
"args": {
"skip_embedding": skip_text_unit_embedding,
"text_embed": text_unit_text_embed_config,
"covariates_enabled": covariates_enabled,
},
"input": {
@ -46,35 +44,4 @@ def build_steps(
"others": others,
},
},
# Text-Embed after final aggregations
{
"id": "embedded_text_units",
"verb": "text_embed",
"enabled": not skip_text_unit_embedding,
"args": {
"column": config.get("column", "text"),
"to": config.get("to", "text_embedding"),
**text_unit_text_embed_config,
},
},
{
"verb": "select",
"args": {
# Final select to get output in the correct shape
"columns": [
"id",
"text",
*(
[]
if (skip_text_unit_embedding or is_using_vector_store)
else ["text_embedding"]
),
"n_tokens",
"document_ids",
"entity_ids",
"relationship_ids",
*([] if not covariates_enabled else ["covariate_ids"]),
],
},
},
]

View File

@ -12,7 +12,7 @@ from .create_final_nodes import create_final_nodes
from .create_final_relationships import (
create_final_relationships,
)
from .create_final_text_units_pre_embedding import create_final_text_units_pre_embedding
from .create_final_text_units import create_final_text_units
__all__ = [
"create_base_documents",
@ -22,5 +22,5 @@ __all__ = [
"create_final_documents",
"create_final_nodes",
"create_final_relationships",
"create_final_text_units_pre_embedding",
"create_final_text_units",
]

View File

@ -1,25 +1,35 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All the steps to transform before we embed the text units."""
"""All the steps to transform the text units."""
from typing import cast
import pandas as pd
from datashaper.engine.verbs.verb_input import VerbInput
from datashaper.engine.verbs.verbs_mapping import verb
from datashaper.table_store.types import Table, VerbResult, create_verb_result
@verb(
name="create_final_text_units_pre_embedding", treats_input_tables_as_immutable=True
from datashaper import (
Table,
VerbCallbacks,
VerbInput,
VerbResult,
create_verb_result,
verb,
)
def create_final_text_units_pre_embedding(
from graphrag.index.cache import PipelineCache
from graphrag.index.verbs.text.embed.text_embed import text_embed_df
@verb(name="create_final_text_units", treats_input_tables_as_immutable=True)
async def create_final_text_units(
input: VerbInput,
callbacks: VerbCallbacks,
cache: PipelineCache,
text_embed: dict,
skip_embedding: bool = False,
covariates_enabled: bool = False,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to transform before we embed the text units."""
"""All the steps to transform the text units."""
table = cast(pd.DataFrame, input.get_input())
others = input.get_others()
@ -43,7 +53,33 @@ def create_final_text_units_pre_embedding(
aggregated = final_joined.groupby("id", sort=False).agg("first").reset_index()
return create_verb_result(cast(Table, aggregated))
if not skip_embedding:
aggregated = await text_embed_df(
aggregated,
callbacks,
cache,
column="text",
strategy=text_embed["strategy"],
to="text_embedding",
)
is_using_vector_store = (
text_embed.get("strategy", {}).get("vector_store", None) is not None
)
final = aggregated[
[
"id",
"text",
*([] if (skip_embedding or is_using_vector_store) else ["text_embedding"]),
"n_tokens",
"document_ids",
"entity_ids",
"relationship_ids",
*([] if not covariates_enabled else ["covariate_ids"]),
]
]
return create_verb_result(cast(Table, final))
def _entities(df: pd.DataFrame) -> pd.DataFrame:

View File

@ -105,7 +105,7 @@
"relationship_ids",
"entity_ids"
],
"subworkflows": 2,
"subworkflows": 1,
"max_runtime": 100
},
"create_base_documents": {

View File

@ -122,7 +122,7 @@
"relationship_ids",
"entity_ids"
],
"subworkflows": 2,
"subworkflows": 1,
"max_runtime": 100
},
"create_base_documents": {

View File

@ -71,3 +71,35 @@ async def test_create_final_text_units_no_covariates():
expected,
["id", "text", "n_tokens", "document_ids", "entity_ids", "relationship_ids"],
)
async def test_create_final_text_units_with_embeddings():
    """Run the workflow with embedding turned on and a mock embed strategy.

    Checks that the output gains exactly one column, ``text_embedding``,
    compared with the stored expected table, and that the first embedding
    holds three values.
    """
    tables = load_input_tables([
        "workflow:create_base_text_units",
        "workflow:create_final_entities",
        "workflow:create_final_relationships",
        "workflow:create_final_covariates",
    ])
    baseline = load_expected(workflow_name)

    workflow_config = get_config_for_workflow(workflow_name)
    workflow_config["covariates_enabled"] = True
    workflow_config["skip_text_unit_embedding"] = False
    # The default config already carries a detailed standard embed config;
    # only the strategy type is switched to "mock" so that every other
    # required parameter stays in place.
    workflow_config["text_unit_text_embed"]["strategy"]["type"] = "mock"

    workflow = {"steps": remove_disabled_steps(build_steps(workflow_config))}
    result = await get_workflow_output(tables, workflow)

    assert "text_embedding" in result.columns
    # One extra column relative to the expected table: the embedding itself.
    assert len(result.columns) - len(baseline.columns) == 1
    # The test expects 3 floats per embedding from the mock strategy.
    first_embedding = result["text_embedding"][0]
    assert len(first_embedding) == 3