mirror of
https://github.com/microsoft/graphrag.git
synced 2026-02-18 00:35:44 +08:00
Collapse relationship embeddings (#1199)
* Merge text_embed into a single relationships subflow * Update smoke tests * Semver * Spelling
This commit is contained in:
parent
1755afbdec
commit
f518c8b80b
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "Merge text_embed into create-final-relationships subflow."
|
||||
}
|
||||
@ -79,6 +79,23 @@ async def text_embed(
|
||||
<...>
|
||||
```
|
||||
"""
|
||||
input_df = cast(pd.DataFrame, input.get_input())
|
||||
result_df = await text_embed_df(
|
||||
input_df, callbacks, cache, column, strategy, **kwargs
|
||||
)
|
||||
return TableContainer(table=result_df)
|
||||
|
||||
|
||||
# TODO: this ultimately just creates a new column, so our embed function could just generate a series instead of updating the dataframe
|
||||
async def text_embed_df(
|
||||
input: pd.DataFrame,
|
||||
callbacks: VerbCallbacks,
|
||||
cache: PipelineCache,
|
||||
column: str,
|
||||
strategy: dict,
|
||||
**kwargs,
|
||||
):
|
||||
"""Embed a piece of text into a vector space."""
|
||||
vector_store_config = strategy.get("vector_store")
|
||||
|
||||
if vector_store_config:
|
||||
@ -113,28 +130,28 @@ async def text_embed(
|
||||
|
||||
|
||||
async def _text_embed_in_memory(
|
||||
input: VerbInput,
|
||||
input: pd.DataFrame,
|
||||
callbacks: VerbCallbacks,
|
||||
cache: PipelineCache,
|
||||
column: str,
|
||||
strategy: dict,
|
||||
to: str,
|
||||
):
|
||||
output_df = cast(pd.DataFrame, input.get_input())
|
||||
output_df = input
|
||||
strategy_type = strategy["type"]
|
||||
strategy_exec = load_strategy(strategy_type)
|
||||
strategy_args = {**strategy}
|
||||
input_table = input.get_input()
|
||||
input_table = input
|
||||
|
||||
texts: list[str] = input_table[column].to_numpy().tolist()
|
||||
result = await strategy_exec(texts, callbacks, cache, strategy_args)
|
||||
|
||||
output_df[to] = result.embeddings
|
||||
return TableContainer(table=output_df)
|
||||
return output_df
|
||||
|
||||
|
||||
async def _text_embed_with_vector_store(
|
||||
input: VerbInput,
|
||||
input: pd.DataFrame,
|
||||
callbacks: VerbCallbacks,
|
||||
cache: PipelineCache,
|
||||
column: str,
|
||||
@ -144,7 +161,7 @@ async def _text_embed_with_vector_store(
|
||||
store_in_table: bool = False,
|
||||
to: str = "",
|
||||
):
|
||||
output_df = cast(pd.DataFrame, input.get_input())
|
||||
output_df = input
|
||||
strategy_type = strategy["type"]
|
||||
strategy_exec = load_strategy(strategy_type)
|
||||
strategy_args = {**strategy}
|
||||
@ -179,10 +196,8 @@ async def _text_embed_with_vector_store(
|
||||
|
||||
all_results = []
|
||||
|
||||
while insert_batch_size * i < input.get_input().shape[0]:
|
||||
batch = input.get_input().iloc[
|
||||
insert_batch_size * i : insert_batch_size * (i + 1)
|
||||
]
|
||||
while insert_batch_size * i < input.shape[0]:
|
||||
batch = input.iloc[insert_batch_size * i : insert_batch_size * (i + 1)]
|
||||
texts: list[str] = batch[column].to_numpy().tolist()
|
||||
titles: list[str] = batch[title_column].to_numpy().tolist()
|
||||
ids: list[str] = batch[id_column].to_numpy().tolist()
|
||||
@ -218,7 +233,7 @@ async def _text_embed_with_vector_store(
|
||||
if store_in_table:
|
||||
output_df[to] = all_results
|
||||
|
||||
return TableContainer(table=output_df)
|
||||
return output_df
|
||||
|
||||
|
||||
def _create_vector_store(
|
||||
|
||||
@ -23,30 +23,15 @@ def build_steps(
|
||||
"relationship_description_embed", base_text_embed
|
||||
)
|
||||
skip_description_embedding = config.get("skip_description_embedding", False)
|
||||
|
||||
return [
|
||||
{
|
||||
"id": "pre_embedding",
|
||||
"verb": "create_final_relationships_pre_embedding",
|
||||
"input": {"source": "workflow:create_base_entity_graph"},
|
||||
},
|
||||
{
|
||||
"id": "description_embedding",
|
||||
"verb": "text_embed",
|
||||
"enabled": not skip_description_embedding,
|
||||
"verb": "create_final_relationships",
|
||||
"args": {
|
||||
"embedding_name": "relationship_description",
|
||||
"column": "description",
|
||||
"to": "description_embedding",
|
||||
**relationship_description_embed_config,
|
||||
"skip_embedding": skip_description_embedding,
|
||||
"text_embed": relationship_description_embed_config,
|
||||
},
|
||||
},
|
||||
{
|
||||
"verb": "create_final_relationships_post_embedding",
|
||||
"input": {
|
||||
"source": "pre_embedding"
|
||||
if skip_description_embedding
|
||||
else "description_embedding",
|
||||
"source": "workflow:create_base_entity_graph",
|
||||
"nodes": "workflow:create_final_nodes",
|
||||
},
|
||||
},
|
||||
|
||||
@ -7,11 +7,8 @@ from .create_base_documents import create_base_documents
|
||||
from .create_base_text_units import create_base_text_units
|
||||
from .create_final_communities import create_final_communities
|
||||
from .create_final_nodes import create_final_nodes
|
||||
from .create_final_relationships_post_embedding import (
|
||||
create_final_relationships_post_embedding,
|
||||
)
|
||||
from .create_final_relationships_pre_embedding import (
|
||||
create_final_relationships_pre_embedding,
|
||||
from .create_final_relationships import (
|
||||
create_final_relationships,
|
||||
)
|
||||
from .create_final_text_units_pre_embedding import create_final_text_units_pre_embedding
|
||||
|
||||
@ -20,7 +17,6 @@ __all__ = [
|
||||
"create_base_text_units",
|
||||
"create_final_communities",
|
||||
"create_final_nodes",
|
||||
"create_final_relationships_post_embedding",
|
||||
"create_final_relationships_pre_embedding",
|
||||
"create_final_relationships",
|
||||
"create_final_text_units_pre_embedding",
|
||||
]
|
||||
|
||||
@ -1,37 +1,64 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""All the steps to transform final relationships after they are embedded."""
|
||||
"""All the steps to transform final relationships before they are embedded."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import pandas as pd
|
||||
from datashaper import (
|
||||
Table,
|
||||
VerbCallbacks,
|
||||
VerbInput,
|
||||
verb,
|
||||
)
|
||||
from datashaper.table_store.types import VerbResult, create_verb_result
|
||||
|
||||
from graphrag.index.cache import PipelineCache
|
||||
from graphrag.index.utils.ds_util import get_required_input_table
|
||||
from graphrag.index.verbs.graph.compute_edge_combined_degree import (
|
||||
compute_edge_combined_degree_df,
|
||||
)
|
||||
from graphrag.index.verbs.graph.unpack import unpack_graph_df
|
||||
from graphrag.index.verbs.text.embed.text_embed import text_embed_df
|
||||
|
||||
|
||||
@verb(
|
||||
name="create_final_relationships_post_embedding",
|
||||
name="create_final_relationships",
|
||||
treats_input_tables_as_immutable=True,
|
||||
)
|
||||
def create_final_relationships_post_embedding(
|
||||
async def create_final_relationships(
|
||||
input: VerbInput,
|
||||
callbacks: VerbCallbacks,
|
||||
cache: PipelineCache,
|
||||
text_embed: dict,
|
||||
skip_embedding: bool = False,
|
||||
**_kwargs: dict,
|
||||
) -> VerbResult:
|
||||
"""All the steps to transform final relationships after they are embedded."""
|
||||
"""All the steps to transform final relationships before they are embedded."""
|
||||
table = cast(pd.DataFrame, input.get_input())
|
||||
nodes = cast(pd.DataFrame, get_required_input_table(input, "nodes").table)
|
||||
|
||||
pruned_edges = table.drop(columns=["level"])
|
||||
graph_edges = unpack_graph_df(table, callbacks, "clustered_graph", "edges")
|
||||
|
||||
graph_edges.rename(columns={"source_id": "text_unit_ids"}, inplace=True)
|
||||
|
||||
filtered = cast(
|
||||
pd.DataFrame, graph_edges[graph_edges["level"] == 0].reset_index(drop=True)
|
||||
)
|
||||
|
||||
if not skip_embedding:
|
||||
filtered = await text_embed_df(
|
||||
filtered,
|
||||
callbacks,
|
||||
cache,
|
||||
column="description",
|
||||
strategy=text_embed["strategy"],
|
||||
to="description_embedding",
|
||||
embedding_name="relationship_description",
|
||||
)
|
||||
|
||||
pruned_edges = filtered.drop(columns=["level"])
|
||||
|
||||
filtered_nodes = cast(
|
||||
pd.DataFrame,
|
||||
@ -1,38 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""All the steps to transform final relationships before they are embedded."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pandas as pd
|
||||
from datashaper import (
|
||||
Table,
|
||||
VerbCallbacks,
|
||||
VerbInput,
|
||||
verb,
|
||||
)
|
||||
from datashaper.table_store.types import VerbResult, create_verb_result
|
||||
|
||||
from graphrag.index.verbs.graph.unpack import unpack_graph_df
|
||||
|
||||
|
||||
@verb(
|
||||
name="create_final_relationships_pre_embedding",
|
||||
treats_input_tables_as_immutable=True,
|
||||
)
|
||||
def create_final_relationships_pre_embedding(
|
||||
input: VerbInput,
|
||||
callbacks: VerbCallbacks,
|
||||
**_kwargs: dict,
|
||||
) -> VerbResult:
|
||||
"""All the steps to transform final relationships before they are embedded."""
|
||||
table = cast(pd.DataFrame, input.get_input())
|
||||
|
||||
graph_edges = unpack_graph_df(table, callbacks, "clustered_graph", "edges")
|
||||
|
||||
graph_edges.rename(columns={"source_id": "text_unit_ids"}, inplace=True)
|
||||
|
||||
filtered = graph_edges[graph_edges["level"] == 0].reset_index(drop=True)
|
||||
|
||||
return create_verb_result(cast(Table, filtered))
|
||||
2
tests/fixtures/min-csv/config.json
vendored
2
tests/fixtures/min-csv/config.json
vendored
@ -52,7 +52,7 @@
|
||||
1,
|
||||
2000
|
||||
],
|
||||
"subworkflows": 2,
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 100
|
||||
},
|
||||
"create_final_nodes": {
|
||||
|
||||
2
tests/fixtures/text/config.json
vendored
2
tests/fixtures/text/config.json
vendored
@ -71,7 +71,7 @@
|
||||
1,
|
||||
2000
|
||||
],
|
||||
"subworkflows": 2,
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 100
|
||||
},
|
||||
"create_final_nodes": {
|
||||
|
||||
@ -37,3 +37,32 @@ async def test_create_final_relationships():
|
||||
)
|
||||
|
||||
compare_outputs(actual, expected)
|
||||
|
||||
|
||||
async def test_create_final_relationships_with_embeddings():
|
||||
input_tables = load_input_tables([
|
||||
"workflow:create_base_entity_graph",
|
||||
"workflow:create_final_nodes",
|
||||
])
|
||||
expected = load_expected(workflow_name)
|
||||
|
||||
config = get_config_for_workflow(workflow_name)
|
||||
|
||||
config["skip_description_embedding"] = False
|
||||
# default config has a detailed standard embed config
|
||||
# just override the strategy to mock so the rest of the required parameters are in place
|
||||
config["relationship_description_embed"]["strategy"]["type"] = "mock"
|
||||
|
||||
steps = remove_disabled_steps(build_steps(config))
|
||||
|
||||
actual = await get_workflow_output(
|
||||
input_tables,
|
||||
{
|
||||
"steps": steps,
|
||||
},
|
||||
)
|
||||
|
||||
assert "description_embedding" in actual.columns
|
||||
assert len(actual.columns) == len(expected.columns) + 1
|
||||
# the mock impl returns an array of 3 floats for each embedding
|
||||
assert len(actual["description_embedding"][0]) == 3
|
||||
|
||||
@ -13,6 +13,7 @@ from graphrag.index import (
|
||||
PipelineWorkflowStep,
|
||||
create_pipeline_config,
|
||||
)
|
||||
from graphrag.index.run.utils import _create_run_context
|
||||
|
||||
|
||||
def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]:
|
||||
@ -61,7 +62,9 @@ async def get_workflow_output(
|
||||
input_tables=input_tables,
|
||||
)
|
||||
|
||||
await workflow.run()
|
||||
context = _create_run_context(None, None, None)
|
||||
|
||||
await workflow.run(context=context)
|
||||
|
||||
# if there's only one output, it is the default here, no name required
|
||||
return cast(pd.DataFrame, workflow.output())
|
||||
|
||||
Loading…
Reference in New Issue
Block a user