mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 09:07:20 +08:00
Revisit create final text units (#1216)
* Add embeddings to collapsed subflow * Semver * Fix smoke tests
This commit is contained in:
parent
73e709b686
commit
3217013019
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "Add embeddings to subflow."
|
||||
}
|
||||
@ -21,12 +21,8 @@ def build_steps(
|
||||
"""
|
||||
base_text_embed = config.get("text_embed", {})
|
||||
text_unit_text_embed_config = config.get("text_unit_text_embed", base_text_embed)
|
||||
covariates_enabled = config.get("covariates_enabled", False)
|
||||
skip_text_unit_embedding = config.get("skip_text_unit_embedding", False)
|
||||
is_using_vector_store = (
|
||||
text_unit_text_embed_config.get("strategy", {}).get("vector_store", None)
|
||||
is not None
|
||||
)
|
||||
covariates_enabled = config.get("covariates_enabled", False)
|
||||
|
||||
others = [
|
||||
"workflow:create_final_entities",
|
||||
@ -37,8 +33,10 @@ def build_steps(
|
||||
|
||||
return [
|
||||
{
|
||||
"verb": "create_final_text_units_pre_embedding",
|
||||
"verb": "create_final_text_units",
|
||||
"args": {
|
||||
"skip_embedding": skip_text_unit_embedding,
|
||||
"text_embed": text_unit_text_embed_config,
|
||||
"covariates_enabled": covariates_enabled,
|
||||
},
|
||||
"input": {
|
||||
@ -46,35 +44,4 @@ def build_steps(
|
||||
"others": others,
|
||||
},
|
||||
},
|
||||
# Text-Embed after final aggregations
|
||||
{
|
||||
"id": "embedded_text_units",
|
||||
"verb": "text_embed",
|
||||
"enabled": not skip_text_unit_embedding,
|
||||
"args": {
|
||||
"column": config.get("column", "text"),
|
||||
"to": config.get("to", "text_embedding"),
|
||||
**text_unit_text_embed_config,
|
||||
},
|
||||
},
|
||||
{
|
||||
"verb": "select",
|
||||
"args": {
|
||||
# Final select to get output in the correct shape
|
||||
"columns": [
|
||||
"id",
|
||||
"text",
|
||||
*(
|
||||
[]
|
||||
if (skip_text_unit_embedding or is_using_vector_store)
|
||||
else ["text_embedding"]
|
||||
),
|
||||
"n_tokens",
|
||||
"document_ids",
|
||||
"entity_ids",
|
||||
"relationship_ids",
|
||||
*([] if not covariates_enabled else ["covariate_ids"]),
|
||||
],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
@ -12,7 +12,7 @@ from .create_final_nodes import create_final_nodes
|
||||
from .create_final_relationships import (
|
||||
create_final_relationships,
|
||||
)
|
||||
from .create_final_text_units_pre_embedding import create_final_text_units_pre_embedding
|
||||
from .create_final_text_units import create_final_text_units
|
||||
|
||||
__all__ = [
|
||||
"create_base_documents",
|
||||
@ -22,5 +22,5 @@ __all__ = [
|
||||
"create_final_documents",
|
||||
"create_final_nodes",
|
||||
"create_final_relationships",
|
||||
"create_final_text_units_pre_embedding",
|
||||
"create_final_text_units",
|
||||
]
|
||||
|
||||
@ -1,25 +1,35 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""All the steps to transform before we embed the text units."""
|
||||
"""All the steps to transform the text units."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pandas as pd
|
||||
from datashaper.engine.verbs.verb_input import VerbInput
|
||||
from datashaper.engine.verbs.verbs_mapping import verb
|
||||
from datashaper.table_store.types import Table, VerbResult, create_verb_result
|
||||
|
||||
|
||||
@verb(
|
||||
name="create_final_text_units_pre_embedding", treats_input_tables_as_immutable=True
|
||||
from datashaper import (
|
||||
Table,
|
||||
VerbCallbacks,
|
||||
VerbInput,
|
||||
VerbResult,
|
||||
create_verb_result,
|
||||
verb,
|
||||
)
|
||||
def create_final_text_units_pre_embedding(
|
||||
|
||||
from graphrag.index.cache import PipelineCache
|
||||
from graphrag.index.verbs.text.embed.text_embed import text_embed_df
|
||||
|
||||
|
||||
@verb(name="create_final_text_units", treats_input_tables_as_immutable=True)
|
||||
async def create_final_text_units(
|
||||
input: VerbInput,
|
||||
callbacks: VerbCallbacks,
|
||||
cache: PipelineCache,
|
||||
text_embed: dict,
|
||||
skip_embedding: bool = False,
|
||||
covariates_enabled: bool = False,
|
||||
**_kwargs: dict,
|
||||
) -> VerbResult:
|
||||
"""All the steps to transform before we embed the text units."""
|
||||
"""All the steps to transform the text units."""
|
||||
table = cast(pd.DataFrame, input.get_input())
|
||||
others = input.get_others()
|
||||
|
||||
@ -43,7 +53,33 @@ def create_final_text_units_pre_embedding(
|
||||
|
||||
aggregated = final_joined.groupby("id", sort=False).agg("first").reset_index()
|
||||
|
||||
return create_verb_result(cast(Table, aggregated))
|
||||
if not skip_embedding:
|
||||
aggregated = await text_embed_df(
|
||||
aggregated,
|
||||
callbacks,
|
||||
cache,
|
||||
column="text",
|
||||
strategy=text_embed["strategy"],
|
||||
to="text_embedding",
|
||||
)
|
||||
|
||||
is_using_vector_store = (
|
||||
text_embed.get("strategy", {}).get("vector_store", None) is not None
|
||||
)
|
||||
|
||||
final = aggregated[
|
||||
[
|
||||
"id",
|
||||
"text",
|
||||
*([] if (skip_embedding or is_using_vector_store) else ["text_embedding"]),
|
||||
"n_tokens",
|
||||
"document_ids",
|
||||
"entity_ids",
|
||||
"relationship_ids",
|
||||
*([] if not covariates_enabled else ["covariate_ids"]),
|
||||
]
|
||||
]
|
||||
return create_verb_result(cast(Table, final))
|
||||
|
||||
|
||||
def _entities(df: pd.DataFrame) -> pd.DataFrame:
|
||||
2
tests/fixtures/min-csv/config.json
vendored
2
tests/fixtures/min-csv/config.json
vendored
@ -105,7 +105,7 @@
|
||||
"relationship_ids",
|
||||
"entity_ids"
|
||||
],
|
||||
"subworkflows": 2,
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 100
|
||||
},
|
||||
"create_base_documents": {
|
||||
|
||||
2
tests/fixtures/text/config.json
vendored
2
tests/fixtures/text/config.json
vendored
@ -122,7 +122,7 @@
|
||||
"relationship_ids",
|
||||
"entity_ids"
|
||||
],
|
||||
"subworkflows": 2,
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 100
|
||||
},
|
||||
"create_base_documents": {
|
||||
|
||||
@ -71,3 +71,35 @@ async def test_create_final_text_units_no_covariates():
|
||||
expected,
|
||||
["id", "text", "n_tokens", "document_ids", "entity_ids", "relationship_ids"],
|
||||
)
|
||||
|
||||
|
||||
async def test_create_final_text_units_with_embeddings():
|
||||
input_tables = load_input_tables([
|
||||
"workflow:create_base_text_units",
|
||||
"workflow:create_final_entities",
|
||||
"workflow:create_final_relationships",
|
||||
"workflow:create_final_covariates",
|
||||
])
|
||||
expected = load_expected(workflow_name)
|
||||
|
||||
config = get_config_for_workflow(workflow_name)
|
||||
|
||||
config["covariates_enabled"] = True
|
||||
config["skip_text_unit_embedding"] = False
|
||||
# default config has a detailed standard embed config
|
||||
# just override the strategy to mock so the rest of the required parameters are in place
|
||||
config["text_unit_text_embed"]["strategy"]["type"] = "mock"
|
||||
|
||||
steps = remove_disabled_steps(build_steps(config))
|
||||
|
||||
actual = await get_workflow_output(
|
||||
input_tables,
|
||||
{
|
||||
"steps": steps,
|
||||
},
|
||||
)
|
||||
|
||||
assert "text_embedding" in actual.columns
|
||||
assert len(actual.columns) == len(expected.columns) + 1
|
||||
# the mock impl returns an array of 3 floats for each embedding
|
||||
assert len(actual["text_embedding"][0]) == 3
|
||||
|
||||
Loading…
Reference in New Issue
Block a user