mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 09:07:20 +08:00
Covariate collapse (#1142)
* Setup basic verb test runner * Replace join_text_units_to_entity_ids with subflow * Update comments * Replace join_text_units_to_relationship_ids subflow * Roll in final select * Reuse assertion util * Small fix + format * Format/typing * Semver * Format/typing * Semver * Revert format changes * Fix smoke test subworkflow count * Edit subworkflows for another smoke test * Update test parquets for covariates * Collapse covariate join * Rework subtasks for per-flow customization * Format * Semver * Fix smoke test
This commit is contained in:
parent
2de302ff0d
commit
d22c0e7836
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "Covariate verb collapse."
|
||||
}
|
||||
@ -19,15 +19,11 @@ def build_steps(
|
||||
"""
|
||||
return [
|
||||
{
|
||||
"verb": "select",
|
||||
"args": {"columns": ["id", "text_unit_id"]},
|
||||
"input": {"source": "workflow:create_final_covariates"},
|
||||
},
|
||||
{
|
||||
"verb": "aggregate_override",
|
||||
"verb": "join_text_units_to_covariate_ids",
|
||||
"args": {
|
||||
"groupby": ["text_unit_id"],
|
||||
"aggregations": [
|
||||
"select_columns": ["id", "text_unit_id"],
|
||||
"aggregate_groupby": ["text_unit_id"],
|
||||
"aggregate_aggregations": [
|
||||
{
|
||||
"column": "id",
|
||||
"operation": "array_agg_distinct",
|
||||
@ -40,5 +36,6 @@ def build_steps(
|
||||
},
|
||||
],
|
||||
},
|
||||
"input": {"source": "workflow:create_final_covariates"},
|
||||
},
|
||||
]
|
||||
|
||||
@ -19,7 +19,7 @@ def build_steps(
|
||||
"""
|
||||
return [
|
||||
{
|
||||
"verb": "join_text_units",
|
||||
"verb": "join_text_units_to_entity_ids",
|
||||
"args": {
|
||||
"select_columns": ["id", "text_unit_ids"],
|
||||
"unroll_column": "text_unit_ids",
|
||||
|
||||
@ -19,7 +19,7 @@ def build_steps(
|
||||
"""
|
||||
return [
|
||||
{
|
||||
"verb": "join_text_units",
|
||||
"verb": "join_text_units_to_relationship_ids",
|
||||
"args": {
|
||||
"select_columns": ["id", "text_unit_ids"],
|
||||
"unroll_column": "text_unit_ids",
|
||||
|
||||
@ -3,8 +3,12 @@
|
||||
|
||||
"""The Indexing Engine workflows -> subflows package root."""
|
||||
|
||||
from .join_text_units import join_text_units
|
||||
from .join_text_units_to_covariate_ids import join_text_units_to_covariate_ids
|
||||
from .join_text_units_to_entity_ids import join_text_units_to_entity_ids
|
||||
from .join_text_units_to_relationship_ids import join_text_units_to_relationship_ids
|
||||
|
||||
__all__ = [
|
||||
"join_text_units",
|
||||
"join_text_units_to_covariate_ids",
|
||||
"join_text_units_to_entity_ids",
|
||||
"join_text_units_to_relationship_ids",
|
||||
]
|
||||
|
||||
@ -0,0 +1,27 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""join_text_units_to_covariate_ids verb (subtask)."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from datashaper.engine.verbs.verb_input import VerbInput
|
||||
from datashaper.engine.verbs.verbs_mapping import verb
|
||||
from datashaper.table_store.types import Table, VerbResult, create_verb_result
|
||||
|
||||
from graphrag.index.verbs.overrides.aggregate import aggregate_df
|
||||
|
||||
|
||||
@verb(name="join_text_units_to_covariate_ids", treats_input_tables_as_immutable=True)
def join_text_units_to_covariate_ids(
    input: VerbInput,
    select_columns: list[str],
    aggregate_aggregations: list[dict[str, Any]],
    aggregate_groupby: list[str] | None = None,
    **_kwargs: dict,
) -> VerbResult:
    """Subtask to select columns from the input table and aggregate them.

    Unlike the entity/relationship variants of this subtask, no column is
    unrolled here: the selected columns are handed directly to the aggregation.

    Args:
        input: Verb input wrapper; its primary table is read via ``get_input()``.
        select_columns: Column names to keep before aggregating.
        aggregate_aggregations: Aggregation specs forwarded unchanged to
            ``aggregate_df``.
        aggregate_groupby: Group-by column names forwarded unchanged to
            ``aggregate_df``; defaults to ``None``.
        **_kwargs: Extra workflow arguments; accepted and ignored.

    Returns:
        A verb result wrapping the aggregated table.
    """
    source = input.get_input()
    # Narrow to the requested columns; the cast only informs the type checker.
    subset = cast(Table, source[select_columns])
    return create_verb_result(
        aggregate_df(subset, aggregate_aggregations, aggregate_groupby)
    )
|
||||
@ -0,0 +1,29 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""join_text_units_to_entity_ids verb (subtask)."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from datashaper.engine.verbs.verb_input import VerbInput
|
||||
from datashaper.engine.verbs.verbs_mapping import verb
|
||||
from datashaper.table_store.types import Table, VerbResult, create_verb_result
|
||||
|
||||
from graphrag.index.verbs.overrides.aggregate import aggregate_df
|
||||
|
||||
|
||||
@verb(name="join_text_units_to_entity_ids", treats_input_tables_as_immutable=True)
def join_text_units_to_entity_ids(
    input: VerbInput,
    select_columns: list[str],
    unroll_column: str,
    aggregate_aggregations: list[dict[str, Any]],
    aggregate_groupby: list[str] | None = None,
    **_kwargs: dict,
) -> VerbResult:
    """Subtask to select columns, unroll one of them, and aggregate the result.

    Args:
        input: Verb input wrapper; its primary table is read via ``get_input()``.
        select_columns: Column names to keep before unrolling.
        unroll_column: Name of the column to explode into one row per item.
        aggregate_aggregations: Aggregation specs forwarded unchanged to
            ``aggregate_df``.
        aggregate_groupby: Group-by column names forwarded unchanged to
            ``aggregate_df``; defaults to ``None``.
        **_kwargs: Extra workflow arguments; accepted and ignored.

    Returns:
        A verb result wrapping the aggregated table.
    """
    source = input.get_input()
    # Narrow to the requested columns; the cast only informs the type checker.
    subset = cast(Table, source[select_columns])
    # Explode the unroll column, then renumber rows so the index is contiguous.
    exploded = subset.explode(unroll_column)
    flattened = exploded.reset_index(drop=True)
    return create_verb_result(
        aggregate_df(flattened, aggregate_aggregations, aggregate_groupby)
    )
|
||||
@ -0,0 +1,30 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""join_text_units_to_relationship_ids verb (subtask)."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from datashaper.engine.verbs.verb_input import VerbInput
|
||||
from datashaper.engine.verbs.verbs_mapping import verb
|
||||
from datashaper.table_store.types import Table, VerbResult, create_verb_result
|
||||
|
||||
from graphrag.index.verbs.overrides.aggregate import aggregate_df
|
||||
|
||||
|
||||
@verb(name="join_text_units_to_relationship_ids", treats_input_tables_as_immutable=True)
def join_text_units_to_relationship_ids(
    input: VerbInput,
    select_columns: list[str],
    unroll_column: str,
    aggregate_aggregations: list[dict[str, Any]],
    aggregate_groupby: list[str] | None = None,
    final_select_columns: list[str] | None = None,
    **_kwargs: dict,
) -> VerbResult:
    """Subtask to select columns, unroll one, aggregate, and optionally re-select.

    Args:
        input: Verb input wrapper; its primary table is read via ``get_input()``.
        select_columns: Column names to keep before unrolling.
        unroll_column: Name of the column to explode into one row per item.
        aggregate_aggregations: Aggregation specs forwarded unchanged to
            ``aggregate_df``.
        aggregate_groupby: Group-by column names forwarded unchanged to
            ``aggregate_df``; defaults to ``None``.
        final_select_columns: Column names to keep from the aggregated table.
            When ``None`` (the default), the aggregated table is returned
            with all of its columns.
        **_kwargs: Extra workflow arguments; accepted and ignored.

    Returns:
        A verb result wrapping the aggregated (and optionally narrowed) table.
    """
    table = input.get_input()
    selected = cast(Table, table[select_columns])
    # Explode the unroll column, then renumber rows so the index is contiguous.
    unrolled = selected.explode(unroll_column).reset_index(drop=True)
    aggregated = aggregate_df(unrolled, aggregate_aggregations, aggregate_groupby)
    # Indexing with the None default would raise KeyError, so only apply the
    # final projection when the caller actually asked for one.
    if final_select_columns is None:
        return create_verb_result(aggregated)
    return create_verb_result(cast(Table, aggregated[final_select_columns]))
|
||||
2
tests/fixtures/text/config.json
vendored
2
tests/fixtures/text/config.json
vendored
@ -50,7 +50,7 @@
|
||||
1,
|
||||
2000
|
||||
],
|
||||
"subworkflows": 2,
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 10
|
||||
},
|
||||
"create_base_entity_graph": {
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
tests/verbs/data/create_final_covariates.parquet
Normal file
BIN
tests/verbs/data/create_final_covariates.parquet
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
tests/verbs/data/join_text_units_to_covariate_ids.parquet
Normal file
BIN
tests/verbs/data/join_text_units_to_covariate_ids.parquet
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
22
tests/verbs/test_join_text_units_to_covariate_ids.py
Normal file
22
tests/verbs/test_join_text_units_to_covariate_ids.py
Normal file
@ -0,0 +1,22 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from graphrag.index.workflows.v1.join_text_units_to_covariate_ids import build_steps
|
||||
|
||||
from .util import compare_outputs, get_workflow_output, load_expected, load_input_tables
|
||||
|
||||
|
||||
async def test_join_text_units_to_covariate_ids():
    """Run the workflow steps on the stored inputs and compare to the stored output."""
    tables = load_input_tables(["workflow:create_final_covariates"])
    expected = load_expected("join_text_units_to_covariate_ids")

    # Build the workflow definition from the default (empty) config.
    workflow = {"steps": build_steps({})}
    actual = await get_workflow_output(tables, workflow)

    compare_outputs(actual, expected)
|
||||
Loading…
Reference in New Issue
Block a user