Covariate collapse (#1142)

* Setup basic verb test runner

* Replace join_text_units_to_entity_ids with subflow

* Update comments

* Replace join_text_units_to_relationship_ids subflow

* Roll in final select

* Reuse assertion util

* Small fix + format

* Format/typing

* Semver

* Format/typing

* Semver

* Revert format changes

* Fix smoke test subworkflow count

* Edit subworkflows for another smoke test

* Update test parquets for covariates

* Collapse covariate join

* Rework subtasks for per-flow customization

* Format

* Semver

* Fix smoke test
This commit is contained in:
Nathan Evans 2024-09-16 12:35:45 -07:00 committed by GitHub
parent 2de302ff0d
commit d22c0e7836
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 126 additions and 13 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Covariate verb collapse."
}

View File

@ -19,15 +19,11 @@ def build_steps(
"""
return [
{
"verb": "select",
"args": {"columns": ["id", "text_unit_id"]},
"input": {"source": "workflow:create_final_covariates"},
},
{
"verb": "aggregate_override",
"verb": "join_text_units_to_covariate_ids",
"args": {
"groupby": ["text_unit_id"],
"aggregations": [
"select_columns": ["id", "text_unit_id"],
"aggregate_groupby": ["text_unit_id"],
"aggregate_aggregations": [
{
"column": "id",
"operation": "array_agg_distinct",
@ -40,5 +36,6 @@ def build_steps(
},
],
},
"input": {"source": "workflow:create_final_covariates"},
},
]

View File

@ -19,7 +19,7 @@ def build_steps(
"""
return [
{
"verb": "join_text_units",
"verb": "join_text_units_to_entity_ids",
"args": {
"select_columns": ["id", "text_unit_ids"],
"unroll_column": "text_unit_ids",

View File

@ -19,7 +19,7 @@ def build_steps(
"""
return [
{
"verb": "join_text_units",
"verb": "join_text_units_to_relationship_ids",
"args": {
"select_columns": ["id", "text_unit_ids"],
"unroll_column": "text_unit_ids",

View File

@ -3,8 +3,12 @@
"""The Indexing Engine workflows -> subflows package root."""
from .join_text_units import join_text_units
from .join_text_units_to_covariate_ids import join_text_units_to_covariate_ids
from .join_text_units_to_entity_ids import join_text_units_to_entity_ids
from .join_text_units_to_relationship_ids import join_text_units_to_relationship_ids
__all__ = [
"join_text_units",
"join_text_units_to_covariate_ids",
"join_text_units_to_entity_ids",
"join_text_units_to_relationship_ids",
]

View File

@ -0,0 +1,27 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""join_text_units_to_covariate_ids verb (subtask)."""
from typing import Any, cast
from datashaper.engine.verbs.verb_input import VerbInput
from datashaper.engine.verbs.verbs_mapping import verb
from datashaper.table_store.types import Table, VerbResult, create_verb_result
from graphrag.index.verbs.overrides.aggregate import aggregate_df
@verb(name="join_text_units_to_covariate_ids", treats_input_tables_as_immutable=True)
def join_text_units_to_covariate_ids(
input: VerbInput,
select_columns: list[str],
aggregate_aggregations: list[dict[str, Any]],
aggregate_groupby: list[str] | None = None,
**_kwargs: dict,
) -> VerbResult:
"""Subtask to select and unroll items using an id."""
table = input.get_input()
selected = cast(Table, table[select_columns])
aggregated = aggregate_df(selected, aggregate_aggregations, aggregate_groupby)
return create_verb_result(aggregated)

View File

@ -0,0 +1,29 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""join_text_units_to_entity_ids verb (subtask)."""
from typing import Any, cast
from datashaper.engine.verbs.verb_input import VerbInput
from datashaper.engine.verbs.verbs_mapping import verb
from datashaper.table_store.types import Table, VerbResult, create_verb_result
from graphrag.index.verbs.overrides.aggregate import aggregate_df
@verb(name="join_text_units_to_entity_ids", treats_input_tables_as_immutable=True)
def join_text_units_to_entity_ids(
input: VerbInput,
select_columns: list[str],
unroll_column: str,
aggregate_aggregations: list[dict[str, Any]],
aggregate_groupby: list[str] | None = None,
**_kwargs: dict,
) -> VerbResult:
"""Subtask to select and unroll items using an id."""
table = input.get_input()
selected = cast(Table, table[select_columns])
unrolled = selected.explode(unroll_column).reset_index(drop=True)
aggregated = aggregate_df(unrolled, aggregate_aggregations, aggregate_groupby)
return create_verb_result(aggregated)

View File

@ -0,0 +1,30 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""join_text_units_to_relationship_ids verb (subtask)."""
from typing import Any, cast
from datashaper.engine.verbs.verb_input import VerbInput
from datashaper.engine.verbs.verbs_mapping import verb
from datashaper.table_store.types import Table, VerbResult, create_verb_result
from graphrag.index.verbs.overrides.aggregate import aggregate_df
@verb(name="join_text_units_to_relationship_ids", treats_input_tables_as_immutable=True)
def join_text_units_to_relationship_ids(
input: VerbInput,
select_columns: list[str],
unroll_column: str,
aggregate_aggregations: list[dict[str, Any]],
aggregate_groupby: list[str] | None = None,
final_select_columns: list[str] | None = None,
**_kwargs: dict,
) -> VerbResult:
"""Subtask to select and unroll items using an id."""
table = input.get_input()
selected = cast(Table, table[select_columns])
unrolled = selected.explode(unroll_column).reset_index(drop=True)
aggregated = aggregate_df(unrolled, aggregate_aggregations, aggregate_groupby)
return create_verb_result(cast(Table, aggregated[final_select_columns]))

View File

@ -50,7 +50,7 @@
1,
2000
],
"subworkflows": 2,
"subworkflows": 1,
"max_runtime": 10
},
"create_base_entity_graph": {

Binary file not shown.

View File

@ -0,0 +1,22 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.index.workflows.v1.join_text_units_to_covariate_ids import build_steps
from .util import compare_outputs, get_workflow_output, load_expected, load_input_tables
async def test_join_text_units_to_covariate_ids():
input_tables = load_input_tables([
"workflow:create_final_covariates",
])
expected = load_expected("join_text_units_to_covariate_ids")
actual = await get_workflow_output(
input_tables,
{
"steps": build_steps({}),
},
)
compare_outputs(actual, expected)