Replace n2v with gee

Nathan Evans 2025-09-08 12:13:51 -07:00
parent e1ae9bde00
commit d779346369
12 changed files with 533 additions and 34 deletions

View File

@ -103,6 +103,8 @@ isin
nocache
nbconvert
levelno
toarray
tsvd
# HTML
nbsp

View File

@ -0,0 +1,196 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "2b5f42c2",
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6b2aff68",
"metadata": {},
"outputs": [],
"source": [
"PROJECT_DIRECTORY = \"<PROJECT_DIRECTORY>\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7abffd6f",
"metadata": {},
"outputs": [],
"source": [
"entities = pd.read_parquet(f\"{PROJECT_DIRECTORY}/output/entities.parquet\")\n",
"print(len(entities))\n",
"entities.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a8d5897",
"metadata": {},
"outputs": [],
"source": [
"relationships = pd.read_parquet(f\"{PROJECT_DIRECTORY}/output/relationships.parquet\")\n",
"print(len(relationships))\n",
"relationships.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "00236b6b",
"metadata": {},
"outputs": [],
"source": [
"communities = pd.read_parquet(f\"{PROJECT_DIRECTORY}/output/communities.parquet\")\n",
"print(len(communities))\n",
"communities.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dc324f92",
"metadata": {},
"outputs": [],
"source": [
"from graphrag.index.operations.create_graph import create_graph\n",
"\n",
"graph = create_graph(relationships, edge_attr=[\"weight\"])\n",
"print(graph.nodes)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8eb41087",
"metadata": {},
"outputs": [],
"source": [
"from graphrag.index.operations.embed_graph.embed_node2vec import embed_node2vec\n",
"from graphrag.index.operations.layout_graph.umap import run as run_umap\n",
"\n",
"start = time.time()\n",
"n2v = embed_node2vec(\n",
" graph,\n",
")\n",
"end = time.time()\n",
"print(\"n2v time:\", end - start)\n",
"n_embeddings = dict(zip(n2v.nodes, n2v.embeddings))\n",
"\n",
"\n",
"n_umap = run_umap(graph, n_embeddings, lambda x: x)\n",
"n_umap_list = [{\"title\": p.label, \"x_n2v\": p.x, \"y_n2v\": p.y} for p in n_umap]\n",
"\n",
"n_df = pd.DataFrame(n_umap_list)\n",
"\n",
"n_df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ae7b6da1",
"metadata": {},
"outputs": [],
"source": [
"from graphrag.config.models.embed_graph_config import EmbedGraphConfig\n",
"from graphrag.index.operations.embed_graph.embed_graph import embed_graph\n",
"\n",
"pipeline_embeddings = embed_graph(graph, entities, communities, EmbedGraphConfig())\n",
"p_umap = run_umap(graph, pipeline_embeddings, lambda x: x)\n",
"\n",
"p_umap_list = [{\"title\": p.label, \"x_gee_p\": p.x, \"y_gee_p\": p.y} for p in p_umap]\n",
"\n",
"p_df = pd.DataFrame(p_umap_list)\n",
"\n",
"p_df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ba9ab829",
"metadata": {},
"outputs": [],
"source": [
"merged_entities = entities.merge(n_df, left_on=\"title\", right_on=\"title\", how=\"left\")\n",
"merged_entities = merged_entities.merge(\n",
" p_df, left_on=\"title\", right_on=\"title\", how=\"left\"\n",
")\n",
"community_labels = communities.explode(\"entity_ids\")[[\"community\", \"entity_ids\", \"level\"]]\n",
"merged_entities = merged_entities.merge(community_labels, left_on=\"id\", right_on=\"entity_ids\", how=\"left\")\n",
"merged_entities = merged_entities[merged_entities[\"level\"] == 0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fde33384",
"metadata": {},
"outputs": [],
"source": [
"merged_entities.plot(\n",
" x=\"x_n2v\",\n",
" y=\"y_n2v\",\n",
" s=5,\n",
" kind=\"scatter\",\n",
" c=\"community\",\n",
" cmap=\"tab20\",\n",
" title=\"n2v\",\n",
" figsize=(12, 10),\n",
" xticks=[],\n",
" yticks=[],\n",
" xlabel=\"\",\n",
" ylabel=\"\",\n",
")\n",
"merged_entities.plot(\n",
" x=\"x_gee_p\",\n",
" y=\"y_gee_p\",\n",
" s=5,\n",
" kind=\"scatter\",\n",
" c=\"community\",\n",
" cmap=\"tab20\",\n",
" title=\"workflow\",\n",
" figsize=(12, 10),\n",
" xticks=[],\n",
" yticks=[],\n",
" xlabel=\"\",\n",
" ylabel=\"\",\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "graphrag",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -44,7 +44,7 @@ class WorkflowCallbacksManager(WorkflowCallbacks):
for callback in self._callbacks:
if hasattr(callback, "workflow_end"):
callback.workflow_end(name, instance)
def workflow_error(self, name: str) -> None:
"""Execute this callback when a workflow has an error."""
for callback in self._callbacks:

View File

@ -0,0 +1,279 @@
# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
# Licensed under the MIT license. See LICENSE file in the project.
#
import logging
from dataclasses import dataclass
import networkx as nx
import numpy as np
from scipy import sparse
from scipy.sparse import linalg
from sklearn.decomposition import TruncatedSVD
from sklearn.preprocessing import normalize
logger = logging.getLogger(__name__)
# invalid divide results will be nan
np.seterr(divide="ignore", invalid="ignore")
@dataclass
class NodeEmbeddings:
"""Node embeddings class definition."""
nodes: list[str]
embeddings: np.ndarray
def _basic(x, y, n):
"""
Graph embedding basic function.
Input X is sparse csr matrix of adjacency matrix
-- if there is a connection between node i and node j:
---- X(i,j) = 1, no edge weight
---- X(i,j) = edge weight.
-- if there is no connection between node i and node j:
---- X(i,j) = 0,
---- note there is no storage for this in sparse matrix.
---- No storage means 0 in sparse matrix.
input Y is numpy array with size (n,1):
-- value -1 indicate no label
-- value >=0 indicate real label
input train_idx: a list of indices of input X for training set
"""
# assign k to the max label along the first column
# note: labels and Python indices both start at 0, so the number of classes k is max + 1
k = y[:, 0].max() + 1
# nk: 1*k array containing the number of observations in each class
nk = np.zeros((1, k))
for i in range(k):
nk[0, i] = np.count_nonzero(y[:, 0] == i)
# W: sparse encoder matrix. w[i,k] = 1/nk if y_i == k, otherwise 0
w = sparse.dok_matrix((n, k), dtype=np.float32)
for i in range(n):
k_i = y[i, 0]
if k_i >= 0:
w[i, k_i] = 1 / nk[0, k_i]
w = sparse.csr_matrix(w)
return x.dot(w)
def _diagonal(x, n):
"""
Graph embedding diagonal function.
Input X is sparse csr matrix of adjacency matrix
return a sparse csr matrix of X matrix with 1s on the diagonal
"""
i = sparse.identity(n)
return x + i
def _laplacian(x):
"""
Graph embedding Laplacian function.
Input X is sparse csr matrix of adjacency matrix
return a sparse csr matrix of Laplacian normalization of X matrix
"""
x_sparse = sparse.csr_matrix(x)
# get an array of degrees
dig = x_sparse.sum(axis=0).A1
# diagonal sparse matrix of D
d = sparse.diags(dig, 0)
d_pow = d.power(-0.5)
return d_pow.dot(x_sparse.dot(d_pow))
def _correlation(z):
"""
Graph embedding correlation function.
Input Z is sparse csr matrix of embedding matrix from the basic function
return normalized Z sparse matrix
Calculation:
Calculate each row's 2-norm (Euclidean norm),
e.g. row_x = [ele_i, ele_j, ele_k]: norm2 = sqrt(ele_i^2 + ele_j^2 + ele_k^2),
then divide each element by its row norm,
e.g. [ele_i/norm2, ele_j/norm2, ele_k/norm2]
"""
# 2-norm
row_norm = linalg.norm(z, axis=1)
# row division to get the normalized Z
diag = np.nan_to_num(1 / row_norm)
n = sparse.diags(diag, 0)
return n.dot(z)
def _edge_list_size(x):
"""
Get the edge list size flag.
The default edge list size is S3 (source, target, weight).
If x has only 2 columns (source, target),
return the flag "S2" indicating an unweighted edge list.
"""
# check the first row to see whether it has 2 or 3 columns
if len(x[0]) == 2:
return "S2"
return "S3"
def _edge_to_sparse(x, n, size_flag):
"""
Convert edge list to sparse matrix.
input X is an edge list.
For an S2 edge list (node_i, node_j per row), assign weight 1 to every connection.
Return a sparse CSR matrix built from the edge list.
"""
# Build an empty sparse matrix.
x_new = sparse.dok_matrix((n, n), dtype=np.float32)
for row in x:
if size_flag == "S2":
[node_i, node_j] = [int(row[0]), int(row[1])]
x_new[node_i, node_j] = 1
else:
[node_i, node_j, weight] = [int(row[0]), int(row[1]), float(row[2])]
x_new[node_i, node_j] = weight
return sparse.csr_matrix(x_new)
def _get_edge_list(graph, node_list):
"""Generate a list of edges with weights for existing nodes."""
node_to_ix = {node: i for i, node in enumerate(node_list)}
return [
[node_to_ix[s], node_to_ix[t], w]
for s, t, w in graph.edges(data="weight")
if s in node_list and t in node_list
]
def _run_gee(
x,
y,
n,
check_edge_list=False,
diag_a=True,
laplacian=False,
correlation=True,
):
if check_edge_list:
size_flag = _edge_list_size(x)
x = _edge_to_sparse(x, n, size_flag)
if diag_a:
x = _diagonal(x, n)
if laplacian:
x = _laplacian(x)
z = _basic(x, y, n)
if correlation:
z = _correlation(z)
return z
def embed_gee(
graph: nx.Graph, node_to_label, correlation, diag_a, laplacian, max_level
) -> np.ndarray:
"""Generate embeddings using graph encoder embedder."""
node_list = sorted(node_to_label.keys())
edge_list = _get_edge_list(graph, node_list)
num_nodes = len(node_list)
node_to_ix = {node: i for i, node in enumerate(node_list)}
# Note: this function relies on the incoming node_to_label dictionary being complete - every node MUST have a label for EVERY level of the hierarchy.
# When a node doesn't natively exist at a level, it must be filled in with its leaf-level parent's label for the logic below to work.
level_embeddings = {}
# For each level
for level in range(max_level + 1):
labels = np.array([
node_to_label[node][level] if node in node_to_label else -1
for node in node_list
]).reshape((
num_nodes,
1,
))
vectors = _run_gee(
edge_list,
labels,
num_nodes,
check_edge_list=True,
laplacian=laplacian,
diag_a=diag_a,
correlation=correlation,
)
level_embeddings[level] = vectors
# Now create a joint embedding across all levels
# get the length of the root-level vectors - this is the minimum dimensionality to reduce to,
# and all deeper-level vectors are resized to this same length.
embedding_length = level_embeddings[0].shape[1]
normalized_vectors = {}
for level in range(max_level + 1):
# First check whether this level needs to be reduced to the standard (root-level) dimensionality.
# Level 0 already matches, so nothing needs to be done there.
if level_embeddings[level].shape[1] == embedding_length:
# at the root level, just copy the vectors over
normalized_vectors[level] = normalize(level_embeddings[level].toarray())
else:
# ideally we would run a full PCA, but that doesn't scale - so we use a truncated SVD instead
tsvd = TruncatedSVD(n_components=embedding_length)
tsvd.fit(level_embeddings[level].toarray())
normalized_vectors[level] = normalize(
tsvd.transform(level_embeddings[level].toarray())
)
concat_vectors = {}
# Next, ONLY copy over the nodes that actually exist at a given level
for node in node_list:
for level in range(max_level + 1):
if level not in concat_vectors:
concat_vectors[level] = {}
# Check whether the node natively exists at this level of the hierarchy - otherwise take alternative logic, such as zeroing out this part of the vector.
if level == 0:
# First, all nodes exist at 0
concat_vectors[level][node] = normalized_vectors[level][
node_to_ix[node]
]
else:
# Deeper than level 0, we have to check whether the node natively exists at this level
if node_to_label[node][level - 1] != node_to_label[node][level]:
# the label differs from the parent level, so the node natively exists at this depth of the hierarchy
concat_vectors[level][node] = normalized_vectors[level][
node_to_ix[node]
]
else:
# if the node has the SAME cluster ID as its parent level, then we know it doesn't natively exist at this level of the hierarchy
# So zero out the vector for this level - alternatively, we could keep the embedding using the cluster membership from the tiers above
concat_vectors[level][node] = [0] * embedding_length
node_vectors = []
# next - concat all the vectors together for all layers of the hierarchy
for node in node_list:
node_vector = []
for level in range(max_level + 1):
node_vector = np.append(node_vector, concat_vectors[level][node])
node_vectors.append(np.array(node_vector))
node_array = np.vstack(node_vectors)
return normalize(node_array)
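
To make the core transform above concrete, here is a minimal sketch (editorial, not part of the commit) of what _basic followed by _correlation computes for a toy 3-node graph, skipping the optional _diagonal and _laplacian steps: the encoder matrix W one-hot encodes each node's cluster label scaled by 1/(cluster size), so Z = X·W averages each node's adjacency per cluster, and the correlation step scales each row of Z to unit 2-norm. All numbers are assumed for illustration.

import numpy as np
from scipy import sparse

# toy 3-node adjacency matrix (node 0 connects to nodes 1 and 2)
x = sparse.csr_matrix(
    np.array(
        [
            [0, 1, 1],
            [1, 0, 0],
            [1, 0, 0],
        ],
        dtype=np.float32,
    )
)
# cluster labels: nodes 0 and 1 in cluster 0, node 2 in cluster 1
y = np.array([[0], [0], [1]])

# encoder matrix W: w[i, k] = 1 / |cluster k| if node i belongs to cluster k
# cluster sizes are 2 and 1, so W = [[0.5, 0], [0.5, 0], [0, 1]]
w = np.array([[0.5, 0.0], [0.5, 0.0], [0.0, 1.0]], dtype=np.float32)

# Z = X @ W averages each node's adjacency per cluster:
# Z = [[0.5, 1.0], [0.5, 0.0], [0.5, 0.0]]
z = np.asarray(x.dot(w))

# _correlation scales each row to unit 2-norm, e.g. row 0 -> [0.5, 1.0] / sqrt(1.25)
z_normalized = z / np.linalg.norm(z, axis=1, keepdims=True)
print(z_normalized)

In embed_gee, one such row is produced per hierarchy level and the rows are concatenated per node, with deeper levels first reduced by TruncatedSVD to the root-level width, so the final vector length is (max_level + 1) times the number of root-level clusters.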

View File

@ -4,9 +4,10 @@
"""A module containing embed_graph and run_embeddings methods definition."""
import networkx as nx
import pandas as pd
from graphrag.config.models.embed_graph_config import EmbedGraphConfig
from graphrag.index.operations.embed_graph.embed_node2vec import embed_node2vec
from graphrag.index.operations.embed_graph.embed_gee import embed_gee
from graphrag.index.operations.embed_graph.typing import (
NodeEmbeddings,
)
@ -15,6 +16,8 @@ from graphrag.index.utils.stable_lcc import stable_largest_connected_component
def embed_graph(
graph: nx.Graph,
entities: pd.DataFrame,
communities: pd.DataFrame,
config: EmbedGraphConfig,
) -> NodeEmbeddings:
"""
@ -33,18 +36,36 @@ def embed_graph(
if config.use_lcc:
graph = stable_largest_connected_component(graph)
# create graph embedding using node2vec
embeddings = embed_node2vec(
# gee requires a cluster label for each entity
clusters = communities.explode("entity_ids")
labeled = entities.merge(
clusters[["entity_ids", "community", "level"]],
left_on="id",
right_on="entity_ids",
how="left",
)
labeled = labeled[labeled["community"].notna()]
labeled["community"] = labeled["community"].astype(int)
labeled["level"] = labeled["level"].astype(int)
# gee needs a complete hierarchy for the clusters - we'll "fill down" using the parent community if a node is missing at deeper levels
max_level = labeled["level"].max()
node_to_label = {}
for node in labeled.itertuples():
for level in range(node.level, max_level + 1):
node_labels = node_to_label.get(node.title, {})
node_labels[level] = node.community
node_to_label[node.title] = node_labels
vectors = embed_gee(
graph=graph,
dimensions=config.dimensions,
num_walks=config.num_walks,
walk_length=config.walk_length,
window_size=config.window_size,
iterations=config.iterations,
random_seed=config.random_seed,
node_to_label=node_to_label,
correlation=True,
diag_a=True,
laplacian=True,
max_level=max_level,
)
pairs = zip(embeddings.nodes, embeddings.embeddings.tolist(), strict=True)
sorted_pairs = sorted(pairs, key=lambda x: x[0])
return dict(sorted_pairs)
node_list = sorted(node_to_label.keys())
return dict(zip(node_list, vectors, strict=True))
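
For orientation, a small sketch (editorial; the graph, titles, and community ids are made up) of the node_to_label structure that the fill-down loop above produces and how embed_gee consumes it, mirroring the call in embed_graph: every node carries a label for every level, and a node that stops at level 0 has its deeper labels filled from the parent.

import networkx as nx

from graphrag.index.operations.embed_graph.embed_gee import embed_gee

# hypothetical two-level hierarchy: "A" and "B" move into community 3 at level 1,
# while "C" stops at level 0, so its level-1 label is filled down from level 0
node_to_label = {
    "A": {0: 0, 1: 3},
    "B": {0: 0, 1: 3},
    "C": {0: 1, 1: 1},  # same label as the parent level == filled down
}

graph = nx.Graph()
graph.add_weighted_edges_from([("A", "B", 1.0), ("A", "C", 2.0)])

vectors = embed_gee(
    graph=graph,
    node_to_label=node_to_label,
    correlation=True,
    diag_a=True,
    laplacian=True,
    max_level=1,
)
# rows follow sorted(node_to_label.keys()), matching the dict(zip(...)) above
embeddings = dict(zip(sorted(node_to_label), vectors, strict=True))
print(embeddings["A"].shape)

Because "C" is filled down, its level-1 slice is zeroed out inside embed_gee, while "A" and "B" keep a level-1 component.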

View File

@ -5,6 +5,7 @@
import logging
from typing import Any
from uuid import uuid4
import pandas as pd
@ -77,7 +78,11 @@ async def extract_graph(
relationship_dfs.append(pd.DataFrame(result[1]))
entities = _merge_entities(entity_dfs)
entities = entities.loc[entities["title"].notna()].reset_index()
entities["id"] = entities.apply(lambda _x: str(uuid4()), axis=1)
relationships = _merge_relationships(relationship_dfs)
relationships["id"] = relationships.apply(lambda _x: str(uuid4()), axis=1)
return (entities, relationships)

View File

@ -3,8 +3,6 @@
"""All the steps to transform final entities."""
from uuid import uuid4
import pandas as pd
from graphrag.config.models.embed_graph_config import EmbedGraphConfig
@ -18,6 +16,7 @@ from graphrag.index.operations.layout_graph.layout_graph import layout_graph
def finalize_entities(
entities: pd.DataFrame,
relationships: pd.DataFrame,
communities: pd.DataFrame,
embed_config: EmbedGraphConfig | None = None,
layout_enabled: bool = False,
) -> pd.DataFrame:
@ -27,8 +26,11 @@ def finalize_entities(
if embed_config is not None and embed_config.enabled:
graph_embeddings = embed_graph(
graph,
entities,
communities,
embed_config,
)
layout = layout_graph(
graph,
layout_enabled,
@ -40,14 +42,10 @@ def finalize_entities(
.merge(degrees, on="title", how="left")
.drop_duplicates(subset="title")
)
final_entities = final_entities.loc[entities["title"].notna()].reset_index()
# disconnected nodes and those with no community even at level 0 can be missing degree
final_entities["degree"] = final_entities["degree"].fillna(0).astype(int)
final_entities.reset_index(inplace=True)
final_entities["human_readable_id"] = final_entities.index
final_entities["id"] = final_entities["human_readable_id"].apply(
lambda _x: str(uuid4())
)
return final_entities.loc[
:,
ENTITIES_FINAL_COLUMNS,

View File

@ -3,8 +3,6 @@
"""All the steps to transform final relationships."""
from uuid import uuid4
import pandas as pd
from graphrag.data_model.schemas import RELATIONSHIPS_FINAL_COLUMNS
@ -34,9 +32,6 @@ def finalize_relationships(
final_relationships.reset_index(inplace=True)
final_relationships["human_readable_id"] = final_relationships.index
final_relationships["id"] = final_relationships["human_readable_id"].apply(
lambda _x: str(uuid4())
)
return final_relationships.loc[
:,

View File

@ -16,10 +16,6 @@ from graphrag.index.operations.layout_graph.typing import (
)
from graphrag.index.typing.error_handler import ErrorHandlerFn
# TODO: This could be handled more elegantly, like what columns to use
# for "size" or "cluster"
# We could also have a boolean to indicate to use node sizes or clusters
logger = logging.getLogger(__name__)

View File

@ -4,6 +4,7 @@
"""A module containing run_workflow method definition."""
import logging
from uuid import uuid4
import pandas as pd
@ -67,6 +68,9 @@ async def extract_graph_nlp(
# add in any other columns required by downstream workflows
extracted_nodes["type"] = "NOUN PHRASE"
extracted_nodes["description"] = ""
extracted_nodes["id"] = extracted_nodes.apply(lambda _x: str(uuid4()), axis=1)
extracted_edges["description"] = ""
extracted_edges["id"] = extracted_edges.apply(lambda _x: str(uuid4()), axis=1)
return (extracted_nodes, extracted_edges)

View File

@ -53,9 +53,10 @@ _standard_workflows = [
"create_base_text_units",
"create_final_documents",
"extract_graph",
# communities need to exist before finalizing the graph for labeled embeddings
"create_communities",
"finalize_graph",
"extract_covariates",
"create_communities",
"create_final_text_units",
"create_community_reports",
"generate_text_embeddings",
@ -65,8 +66,8 @@ _fast_workflows = [
"create_final_documents",
"extract_graph_nlp",
"prune_graph",
"finalize_graph",
"create_communities",
"finalize_graph",
"create_final_text_units",
"create_community_reports_text",
"generate_text_embeddings",

View File

@ -30,10 +30,12 @@ async def run_workflow(
relationships = await load_table_from_storage(
"relationships", context.output_storage
)
communities = await load_table_from_storage("communities", context.output_storage)
final_entities, final_relationships = finalize_graph(
entities,
relationships,
communities,
embed_config=config.embed_graph,
layout_enabled=config.umap.enabled,
)
@ -44,7 +46,6 @@ async def run_workflow(
)
if config.snapshots.graphml:
# todo: extract graphs at each level, and add in meta like descriptions
graph = create_graph(final_relationships, edge_attr=["weight"])
await snapshot_graphml(
@ -65,12 +66,13 @@ async def run_workflow(
def finalize_graph(
entities: pd.DataFrame,
relationships: pd.DataFrame,
communities: pd.DataFrame,
embed_config: EmbedGraphConfig | None = None,
layout_enabled: bool = False,
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""All the steps to finalize the entity and relationship formats."""
final_entities = finalize_entities(
entities, relationships, embed_config, layout_enabled
entities, relationships, communities, embed_config, layout_enabled
)
final_relationships = finalize_relationships(relationships)
return (final_entities, final_relationships)