Replace n2v with gee
parent e1ae9bde00 · commit d779346369
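Summary: this change replaces the node2vec graph embedding with a graph encoder embedding (GEE) driven by the community hierarchy. embed_graph now takes the entities and communities tables, builds a node-to-community label map for every hierarchy level, and calls the new embed_gee. With all options enabled as in the pipeline call, each level is encoded as Z = D^(-1/2)(A + I)D^(-1/2) W, where W[i,k] = 1/n_k when node i carries label k; Z is then row-normalized, deeper levels are reduced to the root dimensionality with a truncated SVD, and the per-level vectors are concatenated. Because the embedding now depends on community labels, create_communities moves ahead of finalize_graph in both workflow pipelines, entity/relationship id generation moves upstream into the extract workflows, and a comparison notebook (docs/examples_notebooks/gee.ipynb) is added.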
@@ -103,6 +103,8 @@ isin
 nocache
 nbconvert
 levelno
+toarray
+tsvd
 
 # HTML
 nbsp
docs/examples_notebooks/gee.ipynb (new file, 196 lines)
@@ -0,0 +1,196 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b5f42c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b2aff68",
   "metadata": {},
   "outputs": [],
   "source": [
    "PROJECT_DIRECTORY = \"<PROJECT_DIRECTORY>\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7abffd6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "entities = pd.read_parquet(f\"{PROJECT_DIRECTORY}/output/entities.parquet\")\n",
    "print(len(entities))\n",
    "entities.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a8d5897",
   "metadata": {},
   "outputs": [],
   "source": [
    "relationships = pd.read_parquet(f\"{PROJECT_DIRECTORY}/output/relationships.parquet\")\n",
    "print(len(relationships))\n",
    "relationships.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00236b6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "communities = pd.read_parquet(f\"{PROJECT_DIRECTORY}/output/communities.parquet\")\n",
    "print(len(communities))\n",
    "communities.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc324f92",
   "metadata": {},
   "outputs": [],
   "source": [
    "from graphrag.index.operations.create_graph import create_graph\n",
    "\n",
    "graph = create_graph(relationships, edge_attr=[\"weight\"])\n",
    "print(graph.nodes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8eb41087",
   "metadata": {},
   "outputs": [],
   "source": [
    "from graphrag.index.operations.embed_graph.embed_node2vec import embed_node2vec\n",
    "from graphrag.index.operations.layout_graph.umap import run as run_umap\n",
    "\n",
    "start = time.time()\n",
    "n2v = embed_node2vec(\n",
    "    graph,\n",
    ")\n",
    "end = time.time()\n",
    "print(\"n2v time:\", end - start)\n",
    "n_embeddings = dict(zip(n2v.nodes, n2v.embeddings))\n",
    "\n",
    "\n",
    "n_umap = run_umap(graph, n_embeddings, lambda x: x)\n",
    "n_umap_list = [{\"title\": p.label, \"x_n2v\": p.x, \"y_n2v\": p.y} for p in n_umap]\n",
    "\n",
    "n_df = pd.DataFrame(n_umap_list)\n",
    "\n",
    "n_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae7b6da1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from graphrag.config.models.embed_graph_config import EmbedGraphConfig\n",
    "from graphrag.index.operations.embed_graph.embed_graph import embed_graph\n",
    "\n",
    "pipeline_embeddings = embed_graph(graph, entities, communities, EmbedGraphConfig())\n",
    "p_umap = run_umap(graph, pipeline_embeddings, lambda x: x)\n",
    "\n",
    "p_umap_list = [{\"title\": p.label, \"x_gee_p\": p.x, \"y_gee_p\": p.y} for p in p_umap]\n",
    "\n",
    "p_df = pd.DataFrame(p_umap_list)\n",
    "\n",
    "p_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba9ab829",
   "metadata": {},
   "outputs": [],
   "source": [
    "merged_entities = entities.merge(n_df, left_on=\"title\", right_on=\"title\", how=\"left\")\n",
    "merged_entities = merged_entities.merge(\n",
    "    p_df, left_on=\"title\", right_on=\"title\", how=\"left\"\n",
    ")\n",
    "community_labels = communities.explode(\"entity_ids\")[[\"community\", \"entity_ids\", \"level\"]]\n",
    "merged_entities = merged_entities.merge(community_labels, left_on=\"id\", right_on=\"entity_ids\", how=\"left\")\n",
    "merged_entities = merged_entities[merged_entities[\"level\"] == 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fde33384",
   "metadata": {},
   "outputs": [],
   "source": [
    "merged_entities.plot(\n",
    "    x=\"x_n2v\",\n",
    "    y=\"y_n2v\",\n",
    "    s=5,\n",
    "    kind=\"scatter\",\n",
    "    c=\"community\",\n",
    "    cmap=\"tab20\",\n",
    "    title=\"n2v\",\n",
    "    figsize=(12, 10),\n",
    "    xticks=[],\n",
    "    yticks=[],\n",
    "    xlabel=\"\",\n",
    "    ylabel=\"\",\n",
    ")\n",
    "merged_entities.plot(\n",
    "    x=\"x_gee_p\",\n",
    "    y=\"y_gee_p\",\n",
    "    s=5,\n",
    "    kind=\"scatter\",\n",
    "    c=\"community\",\n",
    "    cmap=\"tab20\",\n",
    "    title=\"workflow\",\n",
    "    figsize=(12, 10),\n",
    "    xticks=[],\n",
    "    yticks=[],\n",
    "    xlabel=\"\",\n",
    "    ylabel=\"\",\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "graphrag",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
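The notebook assumes a completed index run under PROJECT_DIRECTORY: it reads entities.parquet, relationships.parquet, and communities.parquet from the output folder, embeds the relationship graph with both node2vec and the pipeline's embed_graph (now GEE), and plots UMAP projections of each, colored by level-0 community. To open it locally (assuming a Jupyter install):

jupyter lab docs/examples_notebooks/gee.ipynb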
@@ -44,7 +44,7 @@ class WorkflowCallbacksManager(WorkflowCallbacks):
         for callback in self._callbacks:
             if hasattr(callback, "workflow_end"):
                 callback.workflow_end(name, instance)
 
     def workflow_error(self, name: str) -> None:
         """Execute this callback when a workflow has an error."""
         for callback in self._callbacks:
graphrag/index/operations/embed_graph/embed_gee.py (new file, 279 lines)
@@ -0,0 +1,279 @@
# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
# Licensed under the MIT license. See LICENSE file in the project.
#
import logging
from dataclasses import dataclass

import networkx as nx
import numpy as np
from scipy import sparse
from scipy.sparse import linalg
from sklearn.decomposition import TruncatedSVD
from sklearn.preprocessing import normalize

logger = logging.getLogger(__name__)

# invalid divide results will be nan
np.seterr(divide="ignore", invalid="ignore")


@dataclass
class NodeEmbeddings:
    """Node embeddings class definition."""

    nodes: list[str]
    embeddings: np.ndarray


def _basic(x, y, n):
    """
    Graph embedding basic function.

    Input x is the sparse CSR adjacency matrix:
    -- if there is a connection between node i and node j:
    ---- x(i,j) = 1 when there is no edge weight,
    ---- x(i,j) = the edge weight otherwise.
    -- if there is no connection between node i and node j:
    ---- x(i,j) = 0; the sparse matrix stores nothing for these entries,
    ---- so absent storage means 0.
    Input y is a numpy array of shape (n, 1):
    -- value -1 indicates no label,
    -- value >= 0 indicates a real label.
    """
    # Labels are 0-based, so the number of classes k is the max label + 1.
    k = y[:, 0].max() + 1

    # nk: 1*k array containing the number of observations in each class
    nk = np.zeros((1, k))
    for i in range(k):
        nk[0, i] = np.count_nonzero(y[:, 0] == i)

    # w: sparse encoder matrix. w[i,k] = 1/nk if yi == k, otherwise 0
    w = sparse.dok_matrix((n, k), dtype=np.float32)

    for i in range(n):
        k_i = y[i, 0]
        if k_i >= 0:
            w[i, k_i] = 1 / nk[0, k_i]

    w = sparse.csr_matrix(w)
    return x.dot(w)
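# Worked example for _basic (illustrative): with n = 3 and labels
# y = [[0], [0], [1]], nk = [2, 1] and the encoder matrix is
#   W = [[1/2, 0],
#        [1/2, 0],
#        [0,   1]]
# so row i of Z = x.dot(W) holds node i's average adjacency weight into
# each of the two clusters.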

def _diagonal(x, n):
    """
    Graph embedding diagonal function.

    Input x is the sparse CSR adjacency matrix.
    Return a sparse CSR matrix of x with 1s added on the diagonal.
    """
    i = sparse.identity(n)
    return x + i


def _laplacian(x):
    """
    Graph embedding Laplacian function.

    Input x is the sparse CSR adjacency matrix.
    Return the Laplacian normalization D^(-1/2) x D^(-1/2) as a sparse CSR matrix.
    """
    x_sparse = sparse.csr_matrix(x)
    # get an array of degrees
    dig = x_sparse.sum(axis=0).A1
    # diagonal sparse matrix of D
    d = sparse.diags(dig, 0)
    d_pow = d.power(-0.5)
    return d_pow.dot(x_sparse.dot(d_pow))


def _correlation(z):
    """
    Graph embedding correlation function.

    Input z is the sparse CSR embedding matrix from the basic function.
    Return z with every row rescaled to unit length.
    Calculation:
    compute each row's 2-norm (Euclidean length),
    e.g. row_x = [ele_i, ele_j, ele_k] has norm2 = sqrt(ele_i^2 + ele_j^2 + ele_k^2),
    then divide each element by its row norm,
    e.g. [ele_i/norm2, ele_j/norm2, ele_k/norm2].
    """
    # 2-norm
    row_norm = linalg.norm(z, axis=1)

    # row division to get the normalized z
    diag = np.nan_to_num(1 / row_norm)
    n = sparse.diags(diag, 0)
    return n.dot(z)
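# Illustrative, for _correlation above: a row [3, 4] has 2-norm 5 and is
# rescaled to [0.6, 0.8]; an all-zero row stays zero, since scaling a zero
# row leaves it unchanged.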

def _edge_list_size(x):
    """
    Get the edge list size flag.

    The default edge list size is S3 (source, target, weight).
    If x only has 2 columns,
    return the flag "S2" to indicate an unweighted S2 edge list.
    """
    # check the first entry to see if it has 2 or 3 columns
    if len(x[0]) == 2:
        return "S2"

    return "S3"


def _edge_to_sparse(x, n, size_flag):
    """
    Convert an edge list to a sparse matrix.

    Input x is an edge list.
    For an S2 edge list (e.g. node_i, node_j per row), assign weight one to all connections.
    Return a sparse CSR matrix of the adjacency described by the edge list.
    """
    # Build an empty sparse matrix.
    x_new = sparse.dok_matrix((n, n), dtype=np.float32)

    for row in x:
        if size_flag == "S2":
            [node_i, node_j] = [int(row[0]), int(row[1])]
            x_new[node_i, node_j] = 1
        else:
            [node_i, node_j, weight] = [int(row[0]), int(row[1]), float(row[2])]
            x_new[node_i, node_j] = weight

    return sparse.csr_matrix(x_new)
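# Illustrative, for _edge_to_sparse above: an S3 row looks like [0, 2, 1.5]
# (source index, target index, weight); an S2 row looks like [0, 2] and is
# stored with weight 1.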

def _get_edge_list(graph, node_list):
    """Generate a list of [source, target, weight] edges for known nodes."""
    node_to_ix = {node: i for i, node in enumerate(node_list)}
    return [
        [node_to_ix[s], node_to_ix[t], w]
        for s, t, w in graph.edges(data="weight")
        if s in node_to_ix and t in node_to_ix
    ]


def _run_gee(
    x,
    y,
    n,
    check_edge_list=False,
    diag_a=True,
    laplacian=False,
    correlation=True,
):
    """Run the encoder for one set of labels, with optional pre/post transforms."""
    if check_edge_list:
        size_flag = _edge_list_size(x)
        x = _edge_to_sparse(x, n, size_flag)

    if diag_a:
        x = _diagonal(x, n)

    if laplacian:
        x = _laplacian(x)

    z = _basic(x, y, n)

    if correlation:
        z = _correlation(z)

    return z


def embed_gee(
    graph: nx.Graph, node_to_label, correlation, diag_a, laplacian, max_level
) -> np.ndarray:
    """Generate embeddings using the graph encoder embedder."""
    node_list = sorted(node_to_label.keys())
    edge_list = _get_edge_list(graph, node_list)
    num_nodes = len(node_list)

    node_to_ix = {node: i for i, node in enumerate(node_list)}

    # Note that this function relies upon the incoming node_to_label dictionary being FULL
    # and complete: every node MUST be defined for EVERY level in the hierarchy.
    # When a node doesn't technically exist at a level, make sure it is populated with its
    # leaf-level parent for the logic below to work.

    level_embeddings = {}
    # For each level
    for level in range(max_level + 1):
        labels = np.array([
            node_to_label[node][level] if node in node_to_label else -1
            for node in node_list
        ]).reshape((
            num_nodes,
            1,
        ))
        vectors = _run_gee(
            edge_list,
            labels,
            num_nodes,
            check_edge_list=True,
            laplacian=laplacian,
            diag_a=diag_a,
            correlation=correlation,
        )
        level_embeddings[level] = vectors

    # Now create a joint embedding across all levels.
    # Get the length of any vector at the root level - this is the minimal number of
    # dimensions to reduce to - and resize all vectors to this same length.
    embedding_length = level_embeddings[0].shape[1]

    normalized_vectors = {}

    for level in range(max_level + 1):
        # First check whether this level should be reduced to the standard dimensionality.
        # Obviously for root level 0, nothing needs to be done.
        if level_embeddings[level].shape[1] == embedding_length:
            # at the root level, just copy the vectors over
            normalized_vectors[level] = normalize(level_embeddings[level].toarray())
        else:
            # ideally we would run a PCA, but that doesn't scale - so we use a truncated SVD instead
            tsvd = TruncatedSVD(n_components=embedding_length)

            tsvd.fit(level_embeddings[level].toarray())

            normalized_vectors[level] = normalize(
                tsvd.transform(level_embeddings[level].toarray())
            )

    concat_vectors = {}
    # Next, ONLY copy over the nodes that actually exist at a given level
    for node in node_list:
        for level in range(max_level + 1):
            if level not in concat_vectors:
                concat_vectors[level] = {}
            # Check whether the node natively existed at this level of the hierarchy -
            # otherwise we take alternative logic, like zeroing out this part of the vector.
            if level == 0:
                # First, all nodes exist at level 0
                concat_vectors[level][node] = normalized_vectors[level][
                    node_to_ix[node]
                ]
            else:
                # Deeper than level 0, we have to check whether the node actually exists at this level
                if node_to_label[node][level - 1] != node_to_label[node][level]:
                    # the node existed at this depth of the hierarchy
                    concat_vectors[level][node] = normalized_vectors[level][
                        node_to_ix[node]
                    ]
                else:
                    # If the node has the SAME cluster ID as its parent, we know it doesn't actually
                    # exist at this level in the hierarchy, so zero out this part of the vector.
                    # Alternatively, we could keep the cluster membership from the tiers above and
                    # retain the embedding.
                    concat_vectors[level][node] = [0] * embedding_length

    node_vectors = []
    # next, concatenate the vectors together across all layers of the hierarchy
    for node in node_list:
        node_vector = []
        for level in range(max_level + 1):
            node_vector = np.append(node_vector, concat_vectors[level][node])
        node_vectors.append(np.array(node_vector))
    node_array = np.vstack(node_vectors)

    return normalize(node_array)
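For readers who want to exercise the new embedder outside the pipeline, a minimal sketch (the toy graph and label values are illustrative; it assumes this commit's module path, weighted edges, and a label for every node at every level, per the note in embed_gee):

import networkx as nx

from graphrag.index.operations.embed_graph.embed_gee import embed_gee

# toy weighted graph; weights are required because _get_edge_list reads
# graph.edges(data="weight")
g = nx.Graph()
g.add_edge("a", "b", weight=1.0)
g.add_edge("b", "c", weight=2.0)
g.add_edge("c", "d", weight=1.0)

# one community id per node per level; a node that repeats its parent's id
# at a deeper level is treated as absent there and zero-filled
node_to_label = {
    "a": {0: 0, 1: 0},
    "b": {0: 0, 1: 1},
    "c": {0: 1, 1: 2},
    "d": {0: 1, 1: 3},
}

vectors = embed_gee(
    graph=g,
    node_to_label=node_to_label,
    correlation=True,
    diag_a=True,
    laplacian=True,
    max_level=1,
)
print(vectors.shape)  # (4, 4): a 2-d level-0 block plus a 2-d reduced level-1 block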
@@ -4,9 +4,10 @@
 """A module containing embed_graph and run_embeddings methods definition."""
 
 import networkx as nx
+import pandas as pd
 
 from graphrag.config.models.embed_graph_config import EmbedGraphConfig
-from graphrag.index.operations.embed_graph.embed_node2vec import embed_node2vec
+from graphrag.index.operations.embed_graph.embed_gee import embed_gee
 from graphrag.index.operations.embed_graph.typing import (
     NodeEmbeddings,
 )
@@ -15,6 +16,8 @@ from graphrag.index.utils.stable_lcc import stable_largest_connected_component
 
 def embed_graph(
     graph: nx.Graph,
+    entities: pd.DataFrame,
+    communities: pd.DataFrame,
     config: EmbedGraphConfig,
 ) -> NodeEmbeddings:
     """
@@ -33,18 +36,36 @@ def embed_graph(
     if config.use_lcc:
         graph = stable_largest_connected_component(graph)
 
-    # create graph embedding using node2vec
-    embeddings = embed_node2vec(
+    # gee requires a cluster label for each entity
+    clusters = communities.explode("entity_ids")
+    labeled = entities.merge(
+        clusters[["entity_ids", "community", "level"]],
+        left_on="id",
+        right_on="entity_ids",
+        how="left",
+    )
+    labeled = labeled[labeled["community"].notna()]
+    labeled["community"] = labeled["community"].astype(int)
+    labeled["level"] = labeled["level"].astype(int)
+
+    # gee needs a complete hierarchy for the clusters - we'll "fill down" using parent if a node is missing at lower levels
+    max_level = labeled["level"].max()
+
+    node_to_label = {}
+    for node in labeled.itertuples():
+        for level in range(node.level, max_level + 1):
+            node_labels = node_to_label.get(node.title, {})
+            node_labels[level] = node.community
+            node_to_label[node.title] = node_labels
+
+    vectors = embed_gee(
         graph=graph,
-        dimensions=config.dimensions,
-        num_walks=config.num_walks,
-        walk_length=config.walk_length,
-        window_size=config.window_size,
-        iterations=config.iterations,
-        random_seed=config.random_seed,
+        node_to_label=node_to_label,
+        correlation=True,
+        diag_a=True,
+        laplacian=True,
+        max_level=max_level,
     )
 
-    pairs = zip(embeddings.nodes, embeddings.embeddings.tolist(), strict=True)
-    sorted_pairs = sorted(pairs, key=lambda x: x[0])
-
-    return dict(sorted_pairs)
+    node_list = sorted(node_to_label.keys())
+    return dict(zip(node_list, vectors, strict=True))
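The "fill down" above is the contract embed_gee depends on: every node must carry a label at every level. A hypothetical trace for one entity (illustrative community ids, assuming its labeled rows arrive ordered shallow-to-deep):

# max_level = 2; entity "X" is in community 4 at level 0 and community 12 at level 1
# row(level=0, community=4)  -> node_to_label["X"] = {0: 4, 1: 4, 2: 4}
# row(level=1, community=12) -> node_to_label["X"] = {0: 4, 1: 12, 2: 12}
# every level is populated even though "X" has no level-2 community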
@@ -5,6 +5,7 @@
 
 import logging
 from typing import Any
+from uuid import uuid4
 
 import pandas as pd
 
@@ -77,7 +78,11 @@ async def extract_graph(
         relationship_dfs.append(pd.DataFrame(result[1]))
 
     entities = _merge_entities(entity_dfs)
+    entities = entities.loc[entities["title"].notna()].reset_index()
+    entities["id"] = entities.apply(lambda _x: str(uuid4()), axis=1)
+
     relationships = _merge_relationships(relationship_dfs)
+    relationships["id"] = relationships.apply(lambda _x: str(uuid4()), axis=1)
 
     return (entities, relationships)
@@ -3,8 +3,6 @@
 
 """All the steps to transform final entities."""
 
-from uuid import uuid4
-
 import pandas as pd
 
 from graphrag.config.models.embed_graph_config import EmbedGraphConfig
@@ -18,6 +16,7 @@ from graphrag.index.operations.layout_graph.layout_graph import layout_graph
 def finalize_entities(
     entities: pd.DataFrame,
     relationships: pd.DataFrame,
+    communities: pd.DataFrame,
     embed_config: EmbedGraphConfig | None = None,
     layout_enabled: bool = False,
 ) -> pd.DataFrame:
@@ -27,8 +26,11 @@ def finalize_entities(
     if embed_config is not None and embed_config.enabled:
         graph_embeddings = embed_graph(
             graph,
+            entities,
+            communities,
             embed_config,
         )
 
     layout = layout_graph(
         graph,
         layout_enabled,
@@ -40,14 +42,10 @@ def finalize_entities(
         .merge(degrees, on="title", how="left")
         .drop_duplicates(subset="title")
     )
-    final_entities = final_entities.loc[entities["title"].notna()].reset_index()
+    # disconnected nodes and those with no community even at level 0 can be missing degree
     final_entities["degree"] = final_entities["degree"].fillna(0).astype(int)
     final_entities.reset_index(inplace=True)
     final_entities["human_readable_id"] = final_entities.index
-    final_entities["id"] = final_entities["human_readable_id"].apply(
-        lambda _x: str(uuid4())
-    )
 
     return final_entities.loc[
         :,
         ENTITIES_FINAL_COLUMNS,
@@ -3,8 +3,6 @@
 
 """All the steps to transform final relationships."""
 
-from uuid import uuid4
-
 import pandas as pd
 
 from graphrag.data_model.schemas import RELATIONSHIPS_FINAL_COLUMNS
@@ -34,9 +32,6 @@ def finalize_relationships(
 
     final_relationships.reset_index(inplace=True)
     final_relationships["human_readable_id"] = final_relationships.index
-    final_relationships["id"] = final_relationships["human_readable_id"].apply(
-        lambda _x: str(uuid4())
-    )
 
     return final_relationships.loc[
         :,
@@ -16,10 +16,6 @@ from graphrag.index.operations.layout_graph.typing import (
 )
 from graphrag.index.typing.error_handler import ErrorHandlerFn
 
-# TODO: This could be handled more elegantly, like what columns to use
-# for "size" or "cluster"
-# We could also have a boolean to indicate to use node sizes or clusters
-
 logger = logging.getLogger(__name__)
@@ -4,6 +4,7 @@
 """A module containing run_workflow method definition."""
 
 import logging
+from uuid import uuid4
 
 import pandas as pd
 
@@ -67,6 +68,9 @@ async def extract_graph_nlp(
     # add in any other columns required by downstream workflows
     extracted_nodes["type"] = "NOUN PHRASE"
     extracted_nodes["description"] = ""
+    extracted_nodes["id"] = extracted_nodes.apply(lambda _x: str(uuid4()), axis=1)
+
     extracted_edges["description"] = ""
+    extracted_edges["id"] = extracted_edges.apply(lambda _x: str(uuid4()), axis=1)
 
     return (extracted_nodes, extracted_edges)
@@ -53,9 +53,10 @@ _standard_workflows = [
     "create_base_text_units",
     "create_final_documents",
     "extract_graph",
+    # communities need to exist before finalizing the graph for labeled embeddings
+    "create_communities",
     "finalize_graph",
     "extract_covariates",
-    "create_communities",
     "create_final_text_units",
     "create_community_reports",
     "generate_text_embeddings",
@@ -65,8 +66,8 @@ _fast_workflows = [
     "create_final_documents",
     "extract_graph_nlp",
     "prune_graph",
-    "finalize_graph",
     "create_communities",
+    "finalize_graph",
     "create_final_text_units",
     "create_community_reports_text",
     "generate_text_embeddings",
@@ -30,10 +30,12 @@ async def run_workflow(
     relationships = await load_table_from_storage(
         "relationships", context.output_storage
     )
+    communities = await load_table_from_storage("communities", context.output_storage)
 
     final_entities, final_relationships = finalize_graph(
         entities,
         relationships,
+        communities,
         embed_config=config.embed_graph,
         layout_enabled=config.umap.enabled,
     )
@@ -44,7 +46,6 @@ async def run_workflow(
     )
 
     if config.snapshots.graphml:
-        # todo: extract graphs at each level, and add in meta like descriptions
         graph = create_graph(final_relationships, edge_attr=["weight"])
 
         await snapshot_graphml(
@@ -65,12 +66,13 @@ async def run_workflow(
 def finalize_graph(
     entities: pd.DataFrame,
     relationships: pd.DataFrame,
+    communities: pd.DataFrame,
     embed_config: EmbedGraphConfig | None = None,
     layout_enabled: bool = False,
 ) -> tuple[pd.DataFrame, pd.DataFrame]:
     """All the steps to finalize the entity and relationship formats."""
     final_entities = finalize_entities(
-        entities, relationships, embed_config, layout_enabled
+        entities, relationships, communities, embed_config, layout_enabled
     )
     final_relationships = finalize_relationships(relationships)
     return (final_entities, final_relationships)