mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 09:07:20 +08:00
Remove aggregate_df from final communities and final text units (#1179)
* Remove aggregate_df from final communities and final text units * Semver * Ruff and format * Format * Format * Fix tests, ruff and checks * Remove some leftover prints * Removed _final_join method
This commit is contained in:
parent
fbc483e4e5
commit
be7d3eb189
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "Remove aggregate_df from final coomunities and final text units"
|
||||
}
|
||||
@ -15,7 +15,6 @@ from datashaper import (
|
||||
from datashaper.table_store.types import VerbResult, create_verb_result
|
||||
|
||||
from graphrag.index.verbs.graph.unpack import unpack_graph_df
|
||||
from graphrag.index.verbs.overrides.aggregate import aggregate_df
|
||||
|
||||
|
||||
@verb(name="create_final_communities", treats_input_tables_as_immutable=True)
|
||||
@ -30,54 +29,35 @@ def create_final_communities(
|
||||
graph_nodes = unpack_graph_df(table, callbacks, "clustered_graph", "nodes")
|
||||
graph_edges = unpack_graph_df(table, callbacks, "clustered_graph", "edges")
|
||||
|
||||
# Merge graph_nodes with graph_edges for both source and target matches
|
||||
source_clusters = graph_nodes.merge(
|
||||
graph_edges,
|
||||
left_on="label",
|
||||
right_on="source",
|
||||
how="inner",
|
||||
graph_edges, left_on="label", right_on="source", how="inner"
|
||||
)
|
||||
|
||||
target_clusters = graph_nodes.merge(
|
||||
graph_edges,
|
||||
left_on="label",
|
||||
right_on="target",
|
||||
how="inner",
|
||||
graph_edges, left_on="label", right_on="target", how="inner"
|
||||
)
|
||||
|
||||
concatenated_clusters = pd.concat(
|
||||
[source_clusters, target_clusters], ignore_index=True
|
||||
)
|
||||
# Concatenate the source and target clusters
|
||||
clusters = pd.concat([source_clusters, target_clusters], ignore_index=True)
|
||||
|
||||
# level_x is the left side of the join
|
||||
# level_y is the right side of the join
|
||||
# we only want to keep the clusters that are the same on both sides
|
||||
combined_clusters = concatenated_clusters[
|
||||
concatenated_clusters["level_x"] == concatenated_clusters["level_y"]
|
||||
# Keep only rows where level_x == level_y
|
||||
combined_clusters = clusters[
|
||||
clusters["level_x"] == clusters["level_y"]
|
||||
].reset_index(drop=True)
|
||||
|
||||
cluster_relationships = aggregate_df(
|
||||
cast(Table, combined_clusters),
|
||||
aggregations=[
|
||||
{
|
||||
"column": "id_y", # this is the id of the edge from the join steps above
|
||||
"to": "relationship_ids",
|
||||
"operation": "array_agg_distinct",
|
||||
},
|
||||
{
|
||||
"column": "source_id_x",
|
||||
"to": "text_unit_ids",
|
||||
"operation": "array_agg_distinct",
|
||||
},
|
||||
],
|
||||
groupby=[
|
||||
"cluster",
|
||||
"level_x", # level_x is the left side of the join
|
||||
],
|
||||
cluster_relationships = (
|
||||
combined_clusters.groupby(["cluster", "level_x"], sort=False)
|
||||
.agg(
|
||||
relationship_ids=("id_y", "unique"), text_unit_ids=("source_id_x", "unique")
|
||||
)
|
||||
.reset_index()
|
||||
)
|
||||
|
||||
all_clusters = aggregate_df(
|
||||
graph_nodes,
|
||||
aggregations=[{"column": "cluster", "to": "id", "operation": "any"}],
|
||||
groupby=["cluster", "level"],
|
||||
all_clusters = (
|
||||
graph_nodes.groupby(["cluster", "level"], sort=False)
|
||||
.agg(id=("cluster", "first"))
|
||||
.reset_index()
|
||||
)
|
||||
|
||||
joined = all_clusters.merge(
|
||||
@ -94,14 +74,15 @@ def create_final_communities(
|
||||
return create_verb_result(
|
||||
cast(
|
||||
Table,
|
||||
filtered[
|
||||
filtered.loc[
|
||||
:,
|
||||
[
|
||||
"id",
|
||||
"title",
|
||||
"level",
|
||||
"relationship_ids",
|
||||
"text_unit_ids",
|
||||
]
|
||||
],
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
@ -5,12 +5,11 @@
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pandas as pd
|
||||
from datashaper.engine.verbs.verb_input import VerbInput
|
||||
from datashaper.engine.verbs.verbs_mapping import verb
|
||||
from datashaper.table_store.types import Table, VerbResult, create_verb_result
|
||||
|
||||
from graphrag.index.verbs.overrides.aggregate import aggregate_df
|
||||
|
||||
|
||||
@verb(
|
||||
name="create_final_text_units_pre_embedding", treats_input_tables_as_immutable=True
|
||||
@ -21,15 +20,15 @@ def create_final_text_units_pre_embedding(
|
||||
**_kwargs: dict,
|
||||
) -> VerbResult:
|
||||
"""All the steps to transform before we embed the text units."""
|
||||
table = input.get_input()
|
||||
table = cast(pd.DataFrame, input.get_input())
|
||||
others = input.get_others()
|
||||
|
||||
selected = cast(Table, table[["id", "chunk", "document_ids", "n_tokens"]]).rename(
|
||||
selected = table.loc[:, ["id", "chunk", "document_ids", "n_tokens"]].rename(
|
||||
columns={"chunk": "text"}
|
||||
)
|
||||
|
||||
final_entities = others[0]
|
||||
final_relationships = others[1]
|
||||
final_entities = cast(pd.DataFrame, others[0])
|
||||
final_relationships = cast(pd.DataFrame, others[1])
|
||||
entity_join = _entities(final_entities)
|
||||
relationship_join = _relationships(final_relationships)
|
||||
|
||||
@ -38,116 +37,47 @@ def create_final_text_units_pre_embedding(
|
||||
final_joined = relationship_joined
|
||||
|
||||
if covariates_enabled:
|
||||
final_covariates = others[2]
|
||||
final_covariates = cast(pd.DataFrame, others[2])
|
||||
covariate_join = _covariates(final_covariates)
|
||||
final_joined = _join(relationship_joined, covariate_join)
|
||||
|
||||
aggregated = _final_aggregation(final_joined, covariates_enabled)
|
||||
aggregated = final_joined.groupby("id", sort=False).agg("first").reset_index()
|
||||
|
||||
return create_verb_result(aggregated)
|
||||
return create_verb_result(cast(Table, aggregated))
|
||||
|
||||
|
||||
def _final_aggregation(table, covariates_enabled):
|
||||
aggregations = [
|
||||
{
|
||||
"column": "text",
|
||||
"operation": "any",
|
||||
"to": "text",
|
||||
},
|
||||
{
|
||||
"column": "n_tokens",
|
||||
"operation": "any",
|
||||
"to": "n_tokens",
|
||||
},
|
||||
{
|
||||
"column": "document_ids",
|
||||
"operation": "any",
|
||||
"to": "document_ids",
|
||||
},
|
||||
{
|
||||
"column": "entity_ids",
|
||||
"operation": "any",
|
||||
"to": "entity_ids",
|
||||
},
|
||||
{
|
||||
"column": "relationship_ids",
|
||||
"operation": "any",
|
||||
"to": "relationship_ids",
|
||||
},
|
||||
]
|
||||
if covariates_enabled:
|
||||
aggregations.append({
|
||||
"column": "covariate_ids",
|
||||
"operation": "any",
|
||||
"to": "covariate_ids",
|
||||
})
|
||||
return aggregate_df(
|
||||
table,
|
||||
aggregations,
|
||||
["id"],
|
||||
def _entities(df: pd.DataFrame) -> pd.DataFrame:
|
||||
selected = df.loc[:, ["id", "text_unit_ids"]]
|
||||
unrolled = selected.explode(["text_unit_ids"]).reset_index(drop=True)
|
||||
|
||||
return (
|
||||
unrolled.groupby("text_unit_ids", sort=False)
|
||||
.agg(entity_ids=("id", "unique"))
|
||||
.reset_index()
|
||||
.rename(columns={"text_unit_ids": "id"})
|
||||
)
|
||||
|
||||
|
||||
def _entities(table):
|
||||
selected = cast(Table, table[["id", "text_unit_ids"]])
|
||||
unrolled = selected.explode("text_unit_ids").reset_index(drop=True)
|
||||
return aggregate_df(
|
||||
unrolled,
|
||||
[
|
||||
{
|
||||
"column": "id",
|
||||
"operation": "array_agg_distinct",
|
||||
"to": "entity_ids",
|
||||
},
|
||||
{
|
||||
"column": "text_unit_ids",
|
||||
"operation": "any",
|
||||
"to": "id",
|
||||
},
|
||||
],
|
||||
["text_unit_ids"],
|
||||
def _relationships(df: pd.DataFrame) -> pd.DataFrame:
|
||||
selected = df.loc[:, ["id", "text_unit_ids"]]
|
||||
unrolled = selected.explode(["text_unit_ids"]).reset_index(drop=True)
|
||||
|
||||
return (
|
||||
unrolled.groupby("text_unit_ids", sort=False)
|
||||
.agg(relationship_ids=("id", "unique"))
|
||||
.reset_index()
|
||||
.rename(columns={"text_unit_ids": "id"})
|
||||
)
|
||||
|
||||
|
||||
def _relationships(table):
|
||||
selected = cast(Table, table[["id", "text_unit_ids"]])
|
||||
unrolled = selected.explode("text_unit_ids").reset_index(drop=True)
|
||||
aggregated = aggregate_df(
|
||||
unrolled,
|
||||
[
|
||||
{
|
||||
"column": "id",
|
||||
"operation": "array_agg_distinct",
|
||||
"to": "relationship_ids",
|
||||
},
|
||||
{
|
||||
"column": "text_unit_ids",
|
||||
"operation": "any",
|
||||
"to": "id",
|
||||
},
|
||||
],
|
||||
["text_unit_ids"],
|
||||
)
|
||||
return aggregated[["id", "relationship_ids"]]
|
||||
def _covariates(df: pd.DataFrame) -> pd.DataFrame:
|
||||
selected = df.loc[:, ["id", "text_unit_id"]]
|
||||
|
||||
|
||||
def _covariates(table):
|
||||
selected = cast(Table, table[["id", "text_unit_id"]])
|
||||
return aggregate_df(
|
||||
selected,
|
||||
[
|
||||
{
|
||||
"column": "id",
|
||||
"operation": "array_agg_distinct",
|
||||
"to": "covariate_ids",
|
||||
},
|
||||
{
|
||||
"column": "text_unit_id",
|
||||
"operation": "any",
|
||||
"to": "id",
|
||||
},
|
||||
],
|
||||
["text_unit_id"],
|
||||
return (
|
||||
selected.groupby("text_unit_id", sort=False)
|
||||
.agg(covariate_ids=("id", "unique"))
|
||||
.reset_index()
|
||||
.rename(columns={"text_unit_id": "id"})
|
||||
)
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user