NLP graph parity (#1888)

* Update stopwords config

* Minor edits

* Update PMI

* Format

* Perf improvements

* Semver

* Remove edge collection apply

* Remove source/target apply

* Add edge weight to graph snapshot

* Revert breaking optimizations

* Add perf fixes back in

* Format/types

* Update defaults

* Fix source/target ordering

* Fix test
Nathan Evans 2025-04-25 16:09:06 -07:00 committed by GitHub
parent 25b605b6cd
commit 56e0fad218
8 changed files with 46 additions and 51 deletions

View File

@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Brings parity with our latest NLP extraction approaches."
}

View File

@@ -251,7 +251,7 @@ Parameters for manual graph pruning. This can be used to optimize the modularity
- max_node_freq_std **float | None** - The maximum standard deviation of node frequency to allow.
- min_node_degree **int** - The minimum node degree to allow.
- max_node_degree_std **float | None** - The maximum standard deviation of node degree to allow.
- min_edge_weight_pct **int** - The minimum edge weight percentile to allow.
- min_edge_weight_pct **float** - The minimum edge weight percentile to allow.
- remove_ego_nodes **bool** - Remove ego nodes.
- lcc_only **bool** - Only use largest connected component.

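For illustration, a minimal sketch of how a percentile-based edge cutoff like `min_edge_weight_pct` can be applied to a networkx graph. The function name and the assumption that every edge carries a `weight` attribute are mine, not the library's implementation.

```python
import networkx as nx
import numpy as np

def prune_by_edge_weight_pct(graph: nx.Graph, min_edge_weight_pct: float) -> nx.Graph:
    """Drop edges whose weight falls below the given percentile (illustrative sketch)."""
    weights = [w for _, _, w in graph.edges(data="weight")]  # assumes all edges have weights
    cutoff = np.percentile(weights, min_edge_weight_pct)
    pruned = graph.copy()
    pruned.remove_edges_from(
        [(u, v) for u, v, w in graph.edges(data="weight") if w < cutoff]
    )
    return pruned
```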
View File

@@ -19,6 +19,9 @@ from graphrag.config.enums import (
ReportingType,
TextEmbeddingTarget,
)
from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import (
EN_STOP_WORDS,
)
from graphrag.vector_stores.factory import VectorStoreType
DEFAULT_OUTPUT_BASE_DIR = "output"
@@ -186,7 +189,7 @@ class TextAnalyzerDefaults:
max_word_length: int = 15
word_delimiter: str = " "
include_named_entities: bool = True
exclude_nouns: None = None
exclude_nouns: list[str] = field(default_factory=lambda: EN_STOP_WORDS)
exclude_entity_tags: list[str] = field(default_factory=lambda: ["DATE"])
exclude_pos_tags: list[str] = field(
default_factory=lambda: ["DET", "PRON", "INTJ", "X"]
@@ -317,8 +320,8 @@ class PruneGraphDefaults:
max_node_freq_std: None = None
min_node_degree: int = 1
max_node_degree_std: None = None
min_edge_weight_pct: int = 40
remove_ego_nodes: bool = False
min_edge_weight_pct: float = 40.0
remove_ego_nodes: bool = True
lcc_only: bool = False

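The `exclude_nouns` default moves from `None` to the packaged English stop word list. Because lists are mutable, dataclass fields have to route them through `field(default_factory=...)`; a standalone sketch of the pattern (the stop word list here is a stand-in, not the real `EN_STOP_WORDS`):

```python
from dataclasses import dataclass, field

EN_STOP_WORDS = ["the", "a", "an", "of"]  # illustrative stand-in for the packaged list

@dataclass
class TextAnalyzerDefaults:
    # A bare `exclude_nouns: list[str] = EN_STOP_WORDS` would raise
    # "mutable default ... is not allowed" at class definition time.
    exclude_nouns: list[str] = field(default_factory=lambda: EN_STOP_WORDS)
    exclude_entity_tags: list[str] = field(default_factory=lambda: ["DATE"])
```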
View File

@@ -3,8 +3,9 @@
"""Graph extraction using NLP."""
import math
from itertools import combinations
import numpy as np
import pandas as pd
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
@@ -30,7 +31,6 @@ async def build_noun_graph(
text_units, text_analyzer, num_threads=num_threads, cache=cache
)
edges_df = _extract_edges(nodes_df, normalize_edge_weights=normalize_edge_weights)
return (nodes_df, edges_df)
@@ -69,7 +69,7 @@ async def _extract_nodes(
noun_node_df = text_unit_df.explode("noun_phrases")
noun_node_df = noun_node_df.rename(
columns={"noun_phrases": "title", "id": "text_unit_id"}
).drop_duplicates()
)
# group by title and count the number of text units
grouped_node_df = (
@@ -94,45 +94,37 @@ def _extract_edges(
"""
text_units_df = nodes_df.explode("text_unit_ids")
text_units_df = text_units_df.rename(columns={"text_unit_ids": "text_unit_id"})
text_units_df = (
text_units_df.groupby("text_unit_id").agg({"title": list}).reset_index()
)
text_units_df["edges"] = text_units_df["title"].apply(
lambda x: _create_relationships(x)
)
edge_df = text_units_df.explode("edges").loc[:, ["edges", "text_unit_id"]]
edge_df["source"] = edge_df["edges"].apply(
lambda x: x[0] if isinstance(x, tuple) else None
text_units_df = (
text_units_df.groupby("text_unit_id")
.agg({"title": lambda x: list(x) if len(x) > 1 else np.nan})
.reset_index()
)
edge_df["target"] = edge_df["edges"].apply(
lambda x: x[1] if isinstance(x, tuple) else None
text_units_df = text_units_df.dropna()
titles = text_units_df["title"].tolist()
all_edges: list[list[tuple[str, str]]] = [list(combinations(t, 2)) for t in titles]
text_units_df = text_units_df.assign(edges=all_edges) # type: ignore
edge_df = text_units_df.explode("edges")[["edges", "text_unit_id"]]
edge_df[["source", "target"]] = edge_df.loc[:, "edges"].to_list()
edge_df["min_source"] = edge_df[["source", "target"]].min(axis=1)
edge_df["max_target"] = edge_df[["source", "target"]].max(axis=1)
edge_df = edge_df.drop(columns=["source", "target"]).rename(
columns={"min_source": "source", "max_target": "target"} # type: ignore
)
edge_df = edge_df[(edge_df.source.notna()) & (edge_df.target.notna())]
edge_df = edge_df.drop(columns=["edges"])
# make sure source is always smaller than target
edge_df["source"], edge_df["target"] = zip(
*edge_df.apply(
lambda x: (x["source"], x["target"])
if x["source"] < x["target"]
else (x["target"], x["source"]),
axis=1,
),
strict=False,
)
# group by source and target, count the number of text units and collect their ids
# group by source and target, count the number of text units
grouped_edge_df = (
edge_df.groupby(["source", "target"]).agg({"text_unit_id": list}).reset_index()
)
grouped_edge_df = grouped_edge_df.rename(columns={"text_unit_id": "text_unit_ids"})
grouped_edge_df["weight"] = grouped_edge_df["text_unit_ids"].apply(len)
grouped_edge_df = grouped_edge_df.loc[
:, ["source", "target", "weight", "text_unit_ids"]
]
if normalize_edge_weights:
# use PMI weight instead of raw weight
grouped_edge_df = _calculate_pmi_edge_weights(nodes_df, grouped_edge_df)
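The rewritten `_extract_edges` replaces the row-wise applies with `itertools.combinations` over each text unit's noun phrases, a vectorized min/max to make edge orientation deterministic, and a groupby to count co-occurrences. A self-contained sketch of the same pattern on toy data (column names follow the diff; the values are made up):

```python
from itertools import combinations

import pandas as pd

# Toy input: one row per (noun phrase, text unit) occurrence.
nodes = pd.DataFrame({
    "title": ["apple", "banana", "cherry", "apple"],
    "text_unit_id": ["t1", "t1", "t1", "t2"],
})

# Collect the phrases seen in each text unit; units with a single phrase yield no edges.
per_unit = nodes.groupby("text_unit_id")["title"].agg(list).reset_index()
per_unit = per_unit[per_unit["title"].str.len() > 1]
per_unit["edges"] = [list(combinations(t, 2)) for t in per_unit["title"]]

edges = per_unit.explode("edges")[["edges", "text_unit_id"]]
edges[["source", "target"]] = edges["edges"].to_list()

# Deterministic orientation: the lexicographically smaller phrase is always the source,
# so (a, b) and (b, a) collapse into one undirected edge before counting.
edges["min_source"] = edges[["source", "target"]].min(axis=1)
edges["max_target"] = edges[["source", "target"]].max(axis=1)
edges = edges.drop(columns=["source", "target", "edges"]).rename(
    columns={"min_source": "source", "max_target": "target"}
)

weights = (
    edges.groupby(["source", "target"])
    .agg({"text_unit_id": list})
    .reset_index()
    .rename(columns={"text_unit_id": "text_unit_ids"})
)
weights["weight"] = weights["text_unit_ids"].str.len()
print(weights)  # three edges from t1, each with weight 1
```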
@@ -140,18 +132,6 @@ def _extract_edges(
return grouped_edge_df
def _create_relationships(
noun_phrases: list[str],
) -> list[tuple[str, str]]:
"""Create a (source, target) tuple pairwise for all noun phrases in a list."""
relationships = []
if len(noun_phrases) >= 2:
for i in range(len(noun_phrases) - 1):
for j in range(i + 1, len(noun_phrases)):
relationships.extend([(noun_phrases[i], noun_phrases[j])])
return relationships
def _calculate_pmi_edge_weights(
nodes_df: pd.DataFrame,
edges_df: pd.DataFrame,
@@ -192,8 +172,7 @@ def _calculate_pmi_edge_weights(
.drop(columns=[node_name_col])
.rename(columns={"prop_occurrence": "target_prop"})
)
edges_df[edge_weight_col] = edges_df.apply(
lambda x: math.log2(x["prop_weight"] / (x["source_prop"] * x["target_prop"])),
axis=1,
edges_df[edge_weight_col] = edges_df["prop_weight"] * np.log2(
edges_df["prop_weight"] / (edges_df["source_prop"] * edges_df["target_prop"])
)
return edges_df.drop(columns=["prop_weight", "source_prop", "target_prop"])

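With `normalize_edge_weights` enabled, the raw co-occurrence weight is replaced by `prop_weight * log2(prop_weight / (source_prop * target_prop))`, i.e. the pair's co-occurrence proportion times its pointwise mutual information, now computed as numpy column arithmetic rather than a row-wise `apply`. A toy sketch (column names from the diff, values invented):

```python
import numpy as np
import pandas as pd

# prop_weight: each pair's share of co-occurrences; source_prop/target_prop: each
# phrase's share of text-unit occurrences (illustrative values).
edges = pd.DataFrame({
    "source": ["apple", "apple"],
    "target": ["banana", "cherry"],
    "prop_weight": [0.4, 0.1],
    "source_prop": [0.5, 0.5],
    "target_prop": [0.4, 0.3],
})

# weight = p(s, t) * log2(p(s, t) / (p(s) * p(t))), vectorized over the whole column.
edges["weight"] = edges["prop_weight"] * np.log2(
    edges["prop_weight"] / (edges["source_prop"] * edges["target_prop"])
)
print(edges[["source", "target", "weight"]])
```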
View File

@@ -21,6 +21,14 @@ def graph_to_dataframes(
edges = nx.to_pandas_edgelist(graph)
# we don't deal in directed graphs, but we do need to ensure consistent ordering for df joins
# nx loses the initial ordering
edges["min_source"] = edges[["source", "target"]].min(axis=1)
edges["max_target"] = edges[["source", "target"]].max(axis=1)
edges = edges.drop(columns=["source", "target"]).rename(
columns={"min_source": "source", "max_target": "target"} # type: ignore
)
if node_columns:
nodes = nodes.loc[:, node_columns]

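`nx.to_pandas_edgelist` does not guarantee a stable (source, target) orientation for undirected graphs, so the snapshot path now re-imposes the same lexicographic ordering used during extraction before any dataframe joins. A minimal reproduction of the idea:

```python
import networkx as nx

g = nx.Graph()
g.add_edge("banana", "apple", weight=2.0)  # inserted with the "larger" node first

edges = nx.to_pandas_edgelist(g)

# Re-impose a deterministic orientation: smaller label as source, larger as target.
edges["min_source"] = edges[["source", "target"]].min(axis=1)
edges["max_target"] = edges[["source", "target"]].max(axis=1)
edges = edges.drop(columns=["source", "target"]).rename(
    columns={"min_source": "source", "max_target": "target"}
)
print(edges)  # source="apple", target="banana", weight=2.0
```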
View File

@@ -21,7 +21,7 @@ def prune_graph(
max_node_freq_std: float | None = None,
min_node_degree: int = 1,
max_node_degree_std: float | None = None,
min_edge_weight_pct: float = 0,
min_edge_weight_pct: float = 40,
remove_ego_nodes: bool = False,
lcc_only: bool = False,
) -> nx.Graph:

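The signature default for `min_edge_weight_pct` now matches `PruneGraphDefaults`, so callers that omit it get the 40th-percentile cutoff instead of no cutoff. A hypothetical call for illustration; the leading parameters of `prune_graph` are not shown in this hunk, so the `graph` argument here is an assumption:

```python
pruned = prune_graph(
    graph,                     # assumed: an nx.Graph with "weight" edge attributes
    min_edge_weight_pct=40.0,  # now also the default; previously 0
    remove_ego_nodes=False,
    lcc_only=False,
)
```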
View File

@@ -38,7 +38,8 @@ async def run_workflow(
if config.snapshots.graphml:
# todo: extract graphs at each level, and add in meta like descriptions
graph = create_graph(relationships)
graph = create_graph(final_relationships, edge_attr=["weight"])
await snapshot_graphml(
graph,
name="graph",

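Building the snapshot graph with `edge_attr=["weight"]` means edge weights are carried into the GraphML output instead of being dropped. A rough equivalent using networkx directly (graphrag's `create_graph` helper presumably wraps something like `nx.from_pandas_edgelist`; that mapping is an assumption):

```python
import networkx as nx
import pandas as pd

relationships = pd.DataFrame({
    "source": ["apple", "apple"],
    "target": ["banana", "cherry"],
    "weight": [2.0, 1.0],
})

# Carrying "weight" as an edge attribute lets it survive into the GraphML snapshot.
graph = nx.from_pandas_edgelist(relationships, edge_attr=["weight"])
nx.write_graphml(graph, "graph.graphml")
```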
View File

@@ -28,4 +28,4 @@ async def test_prune_graph():
nodes_actual = await load_table_from_storage("entities", context.storage)
assert len(nodes_actual) == 21
assert len(nodes_actual) == 20