mirror of https://github.com/microsoft/graphrag.git
synced 2026-01-14 09:07:20 +08:00
Refactor Create Final Community reports to simplify code (#1456)
* Optimize prep claims
* Optimize community hierarchy restore
* Partial optimization of prepare_community_reports
* More optimization code
* Fix context string generation
* Filter community -1
* Fix cache, add more optimization fixes
* Fix local search community ids
* Cleanup
* Format
* Semver
* Remove perf counter
* Unused import
* Format
* Fix edge addition to reports
* Add edge by edge context creation
* Re-org of the optimization code
* Format
* Ruff
* Some Ruff fixes
* More pyright
* More pyright
* Pyright
* Pyright
* Update tests
This commit is contained in:
parent b00142260d
commit d43124e576
@@ -0,0 +1,4 @@
{
    "type": "patch",
    "description": "Optimize Final Community Reports calculation and stabilize cache"
}
@@ -19,6 +19,7 @@ from graphrag.index.graph.extractors.community_reports.schemas import (
    CLAIM_STATUS,
    CLAIM_SUBJECT,
    CLAIM_TYPE,
    COMMUNITY_ID,
    EDGE_DEGREE,
    EDGE_DESCRIPTION,
    EDGE_DETAILS,
@@ -83,9 +84,7 @@ async def create_final_community_reports(

    community_reports["community"] = community_reports["community"].astype(int)
    community_reports["human_readable_id"] = community_reports["community"]
    community_reports["id"] = community_reports["community"].apply(
        lambda _x: str(uuid4())
    )
    community_reports["id"] = [uuid4().hex for _ in range(len(community_reports))]

    # Merge with communities to add size and period
    merged = community_reports.merge(
@@ -115,45 +114,42 @@ async def create_final_community_reports(


def _prep_nodes(input: pd.DataFrame) -> pd.DataFrame:
    input = input.fillna(value={NODE_DESCRIPTION: "No Description"})
    # merge values of four columns into a map column
    input[NODE_DETAILS] = input.apply(
        lambda x: {
            NODE_ID: x[NODE_ID],
            NODE_NAME: x[NODE_NAME],
            NODE_DESCRIPTION: x[NODE_DESCRIPTION],
            NODE_DEGREE: x[NODE_DEGREE],
        },
        axis=1,
    """Prepare nodes by filtering, filling missing descriptions, and creating NODE_DETAILS."""
    # Filter rows where community is not -1
    input = input.loc[input.loc[:, COMMUNITY_ID] != -1]

    # Fill missing values in NODE_DESCRIPTION
    input.loc[:, NODE_DESCRIPTION] = input.loc[:, NODE_DESCRIPTION].fillna(
        "No Description"
    )

    # Create NODE_DETAILS column
    input[NODE_DETAILS] = input.loc[
        :, [NODE_ID, NODE_NAME, NODE_DESCRIPTION, NODE_DEGREE]
    ].to_dict(orient="records")

    return input


def _prep_edges(input: pd.DataFrame) -> pd.DataFrame:
    input = input.fillna(value={NODE_DESCRIPTION: "No Description"})
    input[EDGE_DETAILS] = input.apply(
        lambda x: {
            EDGE_ID: x[EDGE_ID],
            EDGE_SOURCE: x[EDGE_SOURCE],
            EDGE_TARGET: x[EDGE_TARGET],
            EDGE_DESCRIPTION: x[EDGE_DESCRIPTION],
            EDGE_DEGREE: x[EDGE_DEGREE],
        },
        axis=1,
    )
    # Fill missing NODE_DESCRIPTION
    input.fillna(value={NODE_DESCRIPTION: "No Description"}, inplace=True)

    # Create EDGE_DETAILS column
    input[EDGE_DETAILS] = input.loc[
        :, [EDGE_ID, EDGE_SOURCE, EDGE_TARGET, EDGE_DESCRIPTION, EDGE_DEGREE]
    ].to_dict(orient="records")

    return input


def _prep_claims(input: pd.DataFrame) -> pd.DataFrame:
    input = input.fillna(value={NODE_DESCRIPTION: "No Description"})
    input[CLAIM_DETAILS] = input.apply(
        lambda x: {
            CLAIM_ID: x[CLAIM_ID],
            CLAIM_SUBJECT: x[CLAIM_SUBJECT],
            CLAIM_TYPE: x[CLAIM_TYPE],
            CLAIM_STATUS: x[CLAIM_STATUS],
            CLAIM_DESCRIPTION: x[CLAIM_DESCRIPTION],
        },
        axis=1,
    )
    # Fill missing NODE_DESCRIPTION
    input.fillna(value={NODE_DESCRIPTION: "No Description"}, inplace=True)

    # Create CLAIM_DETAILS column
    input[CLAIM_DETAILS] = input.loc[
        :, [CLAIM_ID, CLAIM_SUBJECT, CLAIM_TYPE, CLAIM_STATUS, CLAIM_DESCRIPTION]
    ].to_dict(orient="records")

    return input
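The hunk above replaces per-row apply lambdas with a single column selection plus to_dict(orient="records"). A minimal sketch of that pattern, using made-up column names rather than the project's schema constants:

import pandas as pd

df = pd.DataFrame({
    "id": [1, 2],
    "name": ["a", "b"],
    "description": [None, "x"],
    "degree": [3, 1],
})

# Old style: one Python-level lambda call per row.
row_wise = df.apply(
    lambda r: {"id": r["id"], "name": r["name"], "degree": r["degree"]}, axis=1
)

# New style: fill missing values, select the columns once, emit one dict per row.
df["description"] = df["description"].fillna("No Description")
df["details"] = df[["id", "name", "description", "degree"]].to_dict(orient="records")
print(df["details"].tolist())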
@@ -14,25 +14,11 @@ from graphrag.index.graph.extractors.community_reports.prep_community_report_con
    prep_community_report_context,
)
from graphrag.index.graph.extractors.community_reports.sort_context import sort_context
from graphrag.index.graph.extractors.community_reports.utils import (
    filter_claims_to_nodes,
    filter_edges_to_nodes,
    filter_nodes_to_level,
    get_levels,
    set_context_exceeds_flag,
    set_context_size,
)

__all__ = [
    "CommunityReportsExtractor",
    "build_mixed_context",
    "filter_claims_to_nodes",
    "filter_edges_to_nodes",
    "filter_nodes_to_level",
    "get_levels",
    "prep_community_report_context",
    "schemas",
    "set_context_exceeds_flag",
    "set_context_size",
    "sort_context",
]
@@ -13,7 +13,6 @@ from graphrag.index.graph.extractors.community_reports.build_mixed_context impor
    build_mixed_context,
)
from graphrag.index.graph.extractors.community_reports.sort_context import sort_context
from graphrag.index.graph.extractors.community_reports.utils import set_context_size
from graphrag.index.utils.dataframes import (
    antijoin,
    drop_columns,
@@ -23,6 +22,7 @@ from graphrag.index.utils.dataframes import (
    union,
    where_column_equals,
)
from graphrag.query.llm.text_utils import num_tokens

log = logging.getLogger(__name__)
@@ -31,7 +31,7 @@ def prep_community_report_context(
    report_df: pd.DataFrame | None,
    community_hierarchy_df: pd.DataFrame,
    local_context_df: pd.DataFrame,
    level: int | str,
    level: int,
    max_tokens: int,
) -> pd.DataFrame:
    """
@@ -44,10 +44,18 @@ def prep_community_report_context(
    if report_df is None:
        report_df = pd.DataFrame()

    level = int(level)
    level_context_df = _at_level(level, local_context_df)
    valid_context_df = _within_context(level_context_df)
    invalid_context_df = _exceeding_context(level_context_df)
    # Filter by community level
    level_context_df = local_context_df.loc[
        local_context_df.loc[:, schemas.COMMUNITY_LEVEL] == level
    ]

    # Filter valid and invalid contexts using boolean logic
    valid_context_df = level_context_df.loc[
        ~level_context_df.loc[:, schemas.CONTEXT_EXCEED_FLAG]
    ]
    invalid_context_df = level_context_df.loc[
        level_context_df.loc[:, schemas.CONTEXT_EXCEED_FLAG]
    ]

    # there is no report to substitute with, so we just trim the local context of the invalid context records
    # this case should only happen at the bottom level of the community hierarchy where there are no sub-communities
@@ -55,11 +63,13 @@ def prep_community_report_context(
        return valid_context_df

    if report_df.empty:
        invalid_context_df[schemas.CONTEXT_STRING] = _sort_and_trim_context(
        invalid_context_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
            invalid_context_df, max_tokens
        )
        set_context_size(invalid_context_df)
        invalid_context_df.loc[:, schemas.CONTEXT_EXCEED_FLAG] = 0
        invalid_context_df[schemas.CONTEXT_SIZE] = invalid_context_df.loc[
            :, schemas.CONTEXT_STRING
        ].map(num_tokens)
        invalid_context_df[schemas.CONTEXT_EXCEED_FLAG] = 0
        return union(valid_context_df, invalid_context_df)

    level_context_df = _antijoin_reports(level_context_df, report_df)
@@ -74,12 +84,13 @@ def prep_community_report_context(
    # handle any remaining invalid records that can't be subsituted with sub-community reports
    # this should be rare, but if it happens, we will just trim the local context to fit the limit
    remaining_df = _antijoin_reports(invalid_context_df, community_df)
    remaining_df[schemas.CONTEXT_STRING] = _sort_and_trim_context(
    remaining_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
        remaining_df, max_tokens
    )

    result = union(valid_context_df, community_df, remaining_df)
    set_context_size(result)
    result[schemas.CONTEXT_SIZE] = result.loc[:, schemas.CONTEXT_STRING].map(num_tokens)

    result[schemas.CONTEXT_EXCEED_FLAG] = 0
    return result

@@ -94,16 +105,6 @@ def _at_level(level: int, df: pd.DataFrame) -> pd.DataFrame:
    return where_column_equals(df, schemas.COMMUNITY_LEVEL, level)


def _exceeding_context(df: pd.DataFrame) -> pd.DataFrame:
    """Return records where the context exceeds the limit."""
    return where_column_equals(df, schemas.CONTEXT_EXCEED_FLAG, 1)


def _within_context(df: pd.DataFrame) -> pd.DataFrame:
    """Return records where the context is within the limit."""
    return where_column_equals(df, schemas.CONTEXT_EXCEED_FLAG, 0)


def _antijoin_reports(df: pd.DataFrame, reports: pd.DataFrame) -> pd.DataFrame:
    """Return records in df that are not in reports."""
    return antijoin(df, reports, schemas.NODE_COMMUNITY)
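The rewritten prep_community_report_context splits records with a boolean CONTEXT_EXCEED_FLAG mask instead of the where_column_equals helpers. A hedged sketch of that split on a toy frame (column names are illustrative only):

import pandas as pd

ctx = pd.DataFrame({
    "community": [1, 2, 3],
    "context_string": ["short", "also short", "a very long context ..."],
    "context_exceed_flag": [False, False, True],
})

valid = ctx.loc[~ctx["context_exceed_flag"]]   # already fits the token budget
invalid = ctx.loc[ctx["context_exceed_flag"]]  # must be trimmed or replaced by sub-community reports
print(len(valid), len(invalid))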
@@ -12,7 +12,6 @@ def sort_context(
    local_context: list[dict],
    sub_community_reports: list[dict] | None = None,
    max_tokens: int | None = None,
    node_id_column: str = schemas.NODE_ID,
    node_name_column: str = schemas.NODE_NAME,
    node_details_column: str = schemas.NODE_DETAILS,
    edge_id_column: str = schemas.EDGE_ID,
@@ -20,14 +19,9 @@ def sort_context(
    edge_degree_column: str = schemas.EDGE_DEGREE,
    edge_source_column: str = schemas.EDGE_SOURCE,
    edge_target_column: str = schemas.EDGE_TARGET,
    claim_id_column: str = schemas.CLAIM_ID,
    claim_details_column: str = schemas.CLAIM_DETAILS,
    community_id_column: str = schemas.COMMUNITY_ID,
) -> str:
    """Sort context by degree in descending order.

    If max tokens is provided, we will return the context string that fits within the token limit.
    """
    """Sort context by degree in descending order, optimizing for performance."""

    def _get_context_string(
        entities: list[dict],
@@ -38,119 +32,123 @@ def sort_context(
        """Concatenate structured data into a context string."""
        contexts = []
        if sub_community_reports:
            sub_community_reports = [
                report
                for report in sub_community_reports
                if community_id_column in report
                and report[community_id_column]
                and str(report[community_id_column]).strip() != ""
            ]
            report_df = pd.DataFrame(sub_community_reports).drop_duplicates()
            report_df = pd.DataFrame(sub_community_reports)
            if not report_df.empty:
                if report_df[community_id_column].dtype == float:
                    report_df[community_id_column] = report_df[
                        community_id_column
                    ].astype(int)
                report_string = (
                contexts.append(
                    f"----Reports-----\n{report_df.to_csv(index=False, sep=',')}"
                )
                contexts.append(report_string)

        entities = [
            entity
            for entity in entities
            if node_id_column in entity
            and entity[node_id_column]
            and str(entity[node_id_column]).strip() != ""
        ]
        entity_df = pd.DataFrame(entities).drop_duplicates()
        if not entity_df.empty:
            if entity_df[node_id_column].dtype == float:
                entity_df[node_id_column] = entity_df[node_id_column].astype(int)
            entity_string = (
                f"-----Entities-----\n{entity_df.to_csv(index=False, sep=',')}"
            )
            contexts.append(entity_string)

        if claims and len(claims) > 0:
            claims = [
                claim
                for claim in claims
                if claim_id_column in claim
                and claim[claim_id_column]
                and str(claim[claim_id_column]).strip() != ""
            ]
            claim_df = pd.DataFrame(claims).drop_duplicates()
            if not claim_df.empty:
                if claim_df[claim_id_column].dtype == float:
                    claim_df[claim_id_column] = claim_df[claim_id_column].astype(int)
                claim_string = (
                    f"-----Claims-----\n{claim_df.to_csv(index=False, sep=',')}"
                )
                contexts.append(claim_string)

        edges = [
            edge
            for edge in edges
            if edge_id_column in edge
            and edge[edge_id_column]
            and str(edge[edge_id_column]).strip() != ""
        ]
        edge_df = pd.DataFrame(edges).drop_duplicates()
        if not edge_df.empty:
            if edge_df[edge_id_column].dtype == float:
                edge_df[edge_id_column] = edge_df[edge_id_column].astype(int)
            edge_string = (
                f"-----Relationships-----\n{edge_df.to_csv(index=False, sep=',')}"
            )
            contexts.append(edge_string)
        for label, data in [
            ("Entities", entities),
            ("Claims", claims),
            ("Relationships", edges),
        ]:
            if data:
                data_df = pd.DataFrame(data)
                if not data_df.empty:
                    contexts.append(
                        f"-----{label}-----\n{data_df.to_csv(index=False, sep=',')}"
                    )

        return "\n\n".join(contexts)

    # sort node details by degree in descending order
    edges = []
    node_details = {}
    claim_details = {}
    # Preprocess local context
    edges = [
        {**e, schemas.EDGE_ID: int(e[schemas.EDGE_ID])}
        for record in local_context
        for e in record.get(edge_details_column, [])
        if isinstance(e, dict)
    ]

    for record in local_context:
        node_name = record[node_name_column]
        record_edges = record.get(edge_details_column, [])
        record_edges = [e for e in record_edges if not pd.isna(e)]
        record_node_details = record[node_details_column]
        record_claims = record.get(claim_details_column, [])
        record_claims = [c for c in record_claims if not pd.isna(c)]
    node_details = {
        record[node_name_column]: {
            **record[node_details_column],
            schemas.NODE_ID: int(record[node_details_column][schemas.NODE_ID]),
        }
        for record in local_context
    }

        edges.extend(record_edges)
        node_details[node_name] = record_node_details
        claim_details[node_name] = record_claims
    claim_details = {
        record[node_name_column]: [
            {**c, schemas.CLAIM_ID: int(c[schemas.CLAIM_ID])}
            for c in record.get(claim_details_column, [])
            if isinstance(c, dict) and c.get(schemas.CLAIM_ID) is not None
        ]
        for record in local_context
        if isinstance(record.get(claim_details_column), list)
    }

    edges = [edge for edge in edges if isinstance(edge, dict)]
    edges = sorted(edges, key=lambda x: x[edge_degree_column], reverse=True)
    # Sort edges by degree (desc) and ID (asc)
    edges.sort(key=lambda x: (-x.get(edge_degree_column, 0), x.get(edge_id_column, "")))

    sorted_edges = []
    sorted_nodes = []
    sorted_claims = []
    # Deduplicate and build context incrementally
    edge_ids, nodes_ids, claims_ids = set(), set(), set()
    sorted_edges, sorted_nodes, sorted_claims = [], [], []
    context_string = ""
    for edge in edges:
        source_details = node_details.get(edge[edge_source_column], {})
        target_details = node_details.get(edge[edge_target_column], {})
        sorted_nodes.extend([source_details, target_details])
        sorted_edges.append(edge)
        source_claims = claim_details.get(edge[edge_source_column], [])
        target_claims = claim_details.get(edge[edge_target_column], [])
        sorted_claims.extend(source_claims if source_claims else [])
        sorted_claims.extend(target_claims if source_claims else [])
        if max_tokens:
            new_context_string = _get_context_string(
                sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
            )
            if num_tokens(new_context_string) > max_tokens:
                break
            context_string = new_context_string

    if context_string == "":
        return _get_context_string(
    for edge in edges:
        source, target = edge[edge_source_column], edge[edge_target_column]

        # Add source and target node details
        for node in [node_details.get(source), node_details.get(target)]:
            if node and node[schemas.NODE_ID] not in nodes_ids:
                nodes_ids.add(node[schemas.NODE_ID])
                sorted_nodes.append(node)

        # Add claims related to source and target
        for claims in [claim_details.get(source), claim_details.get(target)]:
            if claims:
                for claim in claims:
                    if claim[schemas.CLAIM_ID] not in claims_ids:
                        claims_ids.add(claim[schemas.CLAIM_ID])
                        sorted_claims.append(claim)

        # Add the edge
        if edge[schemas.EDGE_ID] not in edge_ids:
            edge_ids.add(edge[schemas.EDGE_ID])
            sorted_edges.append(edge)

        # Generate new context string
        new_context_string = _get_context_string(
            sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
        )
        if max_tokens and num_tokens(new_context_string) > max_tokens:
            break
        context_string = new_context_string

    return context_string
    # Return the final context string
    return context_string or _get_context_string(
        sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
    )


def parallel_sort_context_batch(community_df, max_tokens, parallel=False):
    """Calculate context using parallelization if enabled."""
    if parallel:
        # Use ThreadPoolExecutor for parallel execution
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=None) as executor:
            context_strings = list(
                executor.map(
                    lambda x: sort_context(x, max_tokens=max_tokens),
                    community_df[schemas.ALL_CONTEXT],
                )
            )
        community_df[schemas.CONTEXT_STRING] = context_strings

    else:
        # Assign context strings directly to the DataFrame
        community_df[schemas.CONTEXT_STRING] = community_df[schemas.ALL_CONTEXT].apply(
            lambda context_list: sort_context(context_list, max_tokens=max_tokens)
        )

    # Calculate other columns
    community_df[schemas.CONTEXT_SIZE] = community_df[schemas.CONTEXT_STRING].apply(
        num_tokens
    )
    community_df[schemas.CONTEXT_EXCEED_FLAG] = (
        community_df[schemas.CONTEXT_SIZE] > max_tokens
    )

    return community_df
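The new sort_context body boils down to one pass over degree-sorted edges that deduplicates nodes, claims, and edges by ID and stops growing the context once a token budget would be exceeded. A simplified sketch of that loop; count_tokens is a stand-in for graphrag's num_tokens and the record layout here is made up:

def count_tokens(text: str) -> int:
    # Stand-in tokenizer: word count instead of a real token counter.
    return len(text.split())


def build_context(edges: list[dict], max_tokens: int) -> str:
    edges = sorted(edges, key=lambda e: (-e["degree"], e["id"]))
    seen: set[int] = set()
    selected: list[dict] = []
    context = ""
    for edge in edges:
        if edge["id"] in seen:
            continue  # deduplicate by edge ID
        seen.add(edge["id"])
        selected.append(edge)
        candidate = "\n".join(f"{e['source']} -> {e['target']}" for e in selected)
        if count_tokens(candidate) > max_tokens:
            break  # keep the last context that still fit
        context = candidate
    return context


edges = [
    {"id": 1, "source": "a", "target": "b", "degree": 5},
    {"id": 2, "source": "b", "target": "c", "degree": 3},
    {"id": 1, "source": "a", "target": "b", "degree": 5},  # duplicate, skipped
]
print(build_context(edges, max_tokens=4))  # "a -> b"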
@@ -3,53 +3,13 @@

"""A module containing community report generation utilities."""

from typing import cast

import pandas as pd

import graphrag.index.graph.extractors.community_reports.schemas as schemas
from graphrag.query.llm.text_utils import num_tokens


def set_context_size(df: pd.DataFrame) -> None:
    """Measure the number of tokens in the context."""
    df.loc[:, schemas.CONTEXT_SIZE] = df.loc[:, schemas.CONTEXT_STRING].apply(
        lambda x: num_tokens(x)
    )


def set_context_exceeds_flag(df: pd.DataFrame, max_tokens: int) -> None:
    """Set a flag to indicate if the context exceeds the limit."""
    df.loc[:, schemas.CONTEXT_EXCEED_FLAG] = df.loc[:, schemas.CONTEXT_SIZE].apply(
        lambda x: x > max_tokens
    )


def get_levels(df: pd.DataFrame, level_column: str = schemas.NODE_LEVEL) -> list[int]:
    """Get the levels of the communities."""
    result = sorted(df[level_column].fillna(-1).unique().tolist(), reverse=True)
    return [r for r in result if r != -1]


def filter_nodes_to_level(node_df: pd.DataFrame, level: int) -> pd.DataFrame:
    """Filter nodes to level."""
    return cast(pd.DataFrame, node_df[node_df[schemas.NODE_LEVEL] == level])


def filter_edges_to_nodes(edge_df: pd.DataFrame, nodes: list[str]) -> pd.DataFrame:
    """Filter edges to nodes."""
    return cast(
        pd.DataFrame,
        edge_df[
            edge_df[schemas.EDGE_SOURCE].isin(nodes)
            & edge_df[schemas.EDGE_TARGET].isin(nodes)
        ],
    )


def filter_claims_to_nodes(claims_df: pd.DataFrame, nodes: list[str]) -> pd.DataFrame:
    """Filter edges to nodes."""
    return cast(
        pd.DataFrame,
        claims_df[claims_df[schemas.CLAIM_SUBJECT].isin(nodes)],
    )
    levels = df[level_column].dropna().unique()
    levels = [int(lvl) for lvl in levels if lvl != -1]
    return sorted(levels, reverse=True)
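The rewritten get_levels drops NaNs, casts to int, filters out -1, and sorts descending. A small usage sketch on toy data (assuming a plain "level" column rather than schemas.NODE_LEVEL):

import pandas as pd

nodes = pd.DataFrame({"level": [0, 1, 1, None, -1, 2]})

levels = nodes["level"].dropna().unique()
levels = [int(lvl) for lvl in levels if lvl != -1]
print(sorted(levels, reverse=True))  # [2, 1, 0]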
@@ -4,7 +4,6 @@
"""A module containing create_community_reports and load_strategy methods definition."""

import logging
from typing import cast

import pandas as pd
from datashaper import (
@@ -13,15 +12,10 @@ from datashaper import (
)

import graphrag.index.graph.extractors.community_reports.schemas as schemas
from graphrag.index.graph.extractors.community_reports import (
    filter_claims_to_nodes,
    filter_edges_to_nodes,
    filter_nodes_to_level,
    get_levels,
    set_context_exceeds_flag,
    set_context_size,
    sort_context,
from graphrag.index.graph.extractors.community_reports.sort_context import (
    parallel_sort_context_batch,
)
from graphrag.index.graph.extractors.community_reports.utils import get_levels

log = logging.getLogger(__name__)

@@ -35,12 +29,15 @@ def prepare_community_reports(
):
    """Prep communities for report generation."""
    levels = get_levels(nodes, schemas.NODE_LEVEL)

    dfs = []

    for level in progress_iterable(levels, callbacks.progress, len(levels)):
        communities_at_level_df = _prepare_reports_at_level(
            nodes, edges, claims, level, max_tokens
        )

        communities_at_level_df.loc[:, schemas.COMMUNITY_LEVEL] = level
        dfs.append(communities_at_level_df)

    # build initial local context for all communities
@@ -53,127 +50,121 @@ def _prepare_reports_at_level(
    claim_df: pd.DataFrame | None,
    level: int,
    max_tokens: int = 16_000,
    community_id_column: str = schemas.COMMUNITY_ID,
    node_id_column: str = schemas.NODE_ID,
    node_name_column: str = schemas.NODE_NAME,
    node_details_column: str = schemas.NODE_DETAILS,
    node_level_column: str = schemas.NODE_LEVEL,
    node_degree_column: str = schemas.NODE_DEGREE,
    node_community_column: str = schemas.NODE_COMMUNITY,
    edge_id_column: str = schemas.EDGE_ID,
    edge_source_column: str = schemas.EDGE_SOURCE,
    edge_target_column: str = schemas.EDGE_TARGET,
    edge_degree_column: str = schemas.EDGE_DEGREE,
    edge_details_column: str = schemas.EDGE_DETAILS,
    claim_id_column: str = schemas.CLAIM_ID,
    claim_subject_column: str = schemas.CLAIM_SUBJECT,
    claim_details_column: str = schemas.CLAIM_DETAILS,
):
    def get_edge_details(node_df: pd.DataFrame, edge_df: pd.DataFrame, name_col: str):
        return node_df.merge(
            cast(
                pd.DataFrame,
                edge_df[[name_col, schemas.EDGE_DETAILS]],
            ).rename(columns={name_col: schemas.NODE_NAME}),
            on=schemas.NODE_NAME,
            how="left",
        )

    level_node_df = filter_nodes_to_level(node_df, level)
) -> pd.DataFrame:
    """Prepare reports at a given level."""
    # Filter and prepare node details
    level_node_df = node_df[node_df[schemas.NODE_LEVEL] == level]
    log.info("Number of nodes at level=%s => %s", level, len(level_node_df))
    nodes = level_node_df[node_name_column].tolist()
    nodes_set = set(level_node_df[schemas.NODE_NAME])

    # Filter edges & claims to those containing the target nodes
    level_edge_df = filter_edges_to_nodes(edge_df, nodes)
    level_claim_df = (
        filter_claims_to_nodes(claim_df, nodes) if claim_df is not None else None
    )

    # concat all edge details per node
    merged_node_df = pd.concat(
    # Filter and prepare edge details
    level_edge_df = edge_df[
        edge_df.loc[:, schemas.EDGE_SOURCE].isin(nodes_set)
        & edge_df.loc[:, schemas.EDGE_TARGET].isin(nodes_set)
    ]
    level_edge_df.loc[:, schemas.EDGE_DETAILS] = level_edge_df.loc[
        :,
        [
        get_edge_details(level_node_df, level_edge_df, edge_source_column),
        get_edge_details(level_node_df, level_edge_df, edge_target_column),
            schemas.EDGE_ID,
            schemas.EDGE_SOURCE,
            schemas.EDGE_TARGET,
            schemas.EDGE_DESCRIPTION,
            schemas.EDGE_DEGREE,
        ],
        axis=0,
    )
    merged_node_df = (
        merged_node_df.groupby([
            node_name_column,
            node_community_column,
            node_degree_column,
            node_level_column,
        ])
        .agg({node_details_column: "first", edge_details_column: list})
    ].to_dict(orient="records")

    level_claim_df = pd.DataFrame()
    if claim_df is not None:
        level_claim_df = claim_df[
            claim_df.loc[:, schemas.CLAIM_SUBJECT].isin(nodes_set)
        ]

    # Merge node and edge details
    # Group edge details by node and aggregate into lists
    source_edges = (
        level_edge_df.groupby(schemas.EDGE_SOURCE)
        .agg({schemas.EDGE_DETAILS: "first"})
        .reset_index()
        .rename(columns={schemas.EDGE_SOURCE: schemas.NODE_NAME})
    )

    # concat claim details per node
    if level_claim_df is not None:
        merged_node_df = merged_node_df.merge(
            cast(
                pd.DataFrame,
                level_claim_df[[claim_subject_column, claim_details_column]],
            ).rename(columns={claim_subject_column: node_name_column}),
            on=node_name_column,
            how="left",
        )
    target_edges = (
        level_edge_df.groupby(schemas.EDGE_TARGET)
        .agg({schemas.EDGE_DETAILS: "first"})
        .reset_index()
        .rename(columns={schemas.EDGE_TARGET: schemas.NODE_NAME})
    )

    # Merge aggregated edges into the node DataFrame
    merged_node_df = level_node_df.merge(
        source_edges, on=schemas.NODE_NAME, how="left"
    ).merge(target_edges, on=schemas.NODE_NAME, how="left")

    # Combine source and target edge details into a single column
    merged_node_df.loc[:, schemas.EDGE_DETAILS] = merged_node_df.loc[
        :, f"{schemas.EDGE_DETAILS}_x"
    ].combine_first(merged_node_df.loc[:, f"{schemas.EDGE_DETAILS}_y"])

    # Drop intermediate columns
    merged_node_df.drop(
        columns=[f"{schemas.EDGE_DETAILS}_x", f"{schemas.EDGE_DETAILS}_y"], inplace=True
    )

    # Aggregate node and edge details
    merged_node_df = (
        merged_node_df.groupby([
            node_name_column,
            node_community_column,
            node_level_column,
            node_degree_column,
            schemas.NODE_NAME,
            schemas.NODE_COMMUNITY,
            schemas.NODE_LEVEL,
            schemas.NODE_DEGREE,
        ])
        .agg({
            node_details_column: "first",
            edge_details_column: "first",
            **({claim_details_column: list} if level_claim_df is not None else {}),
            schemas.NODE_DETAILS: "first",
            schemas.EDGE_DETAILS: lambda x: list(x.dropna()),
        })
        .reset_index()
    )

    # concat all node details, including name, degree, node_details, edge_details, and claim_details
    merged_node_df[schemas.ALL_CONTEXT] = merged_node_df.apply(
        lambda x: {
            node_name_column: x[node_name_column],
            node_degree_column: x[node_degree_column],
            node_details_column: x[node_details_column],
            edge_details_column: x[edge_details_column],
            claim_details_column: x[claim_details_column]
            if level_claim_df is not None
            else [],
        },
        axis=1,
    # Add ALL_CONTEXT column
    # Ensure schemas.CLAIM_DETAILS exists with the correct length
    # Merge claim details if available
    if claim_df is not None:
        merged_node_df = merged_node_df.merge(
            level_claim_df.loc[
                :, [schemas.CLAIM_SUBJECT, schemas.CLAIM_DETAILS]
            ].rename(columns={schemas.CLAIM_SUBJECT: schemas.NODE_NAME}),
            on=schemas.NODE_NAME,
            how="left",
        )

    # Create the ALL_CONTEXT column
    merged_node_df[schemas.ALL_CONTEXT] = (
        merged_node_df.loc[
            :,
            [
                schemas.NODE_NAME,
                schemas.NODE_DEGREE,
                schemas.NODE_DETAILS,
                schemas.EDGE_DETAILS,
            ],
        ]
        .assign(
            **{schemas.CLAIM_DETAILS: merged_node_df[schemas.CLAIM_DETAILS]}
            if claim_df is not None
            else {}
        )
        .to_dict(orient="records")
    )

    # group all node details by community
    community_df = (
        merged_node_df.groupby(node_community_column)
        merged_node_df.groupby(schemas.NODE_COMMUNITY)
        .agg({schemas.ALL_CONTEXT: list})
        .reset_index()
    )
    community_df[schemas.CONTEXT_STRING] = community_df[schemas.ALL_CONTEXT].apply(
        lambda x: sort_context(
            x,
            node_id_column=node_id_column,
            node_name_column=node_name_column,
            node_details_column=node_details_column,
            edge_id_column=edge_id_column,
            edge_details_column=edge_details_column,
            edge_degree_column=edge_degree_column,
            edge_source_column=edge_source_column,
            edge_target_column=edge_target_column,
            claim_id_column=claim_id_column,
            claim_details_column=claim_details_column,
            community_id_column=community_id_column,
        )
    )
    set_context_size(community_df)
    set_context_exceeds_flag(community_df, max_tokens)

    community_df[schemas.COMMUNITY_LEVEL] = level
    community_df[node_community_column] = community_df[node_community_column].astype(
        int
    # Generate community-level context strings using vectorized batch processing
    return parallel_sort_context_batch(
        community_df,
        max_tokens=max_tokens,
    )
    return community_df
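The reworked _prepare_reports_at_level attaches edge details to nodes by grouping edges once per endpoint and merging the aggregates back onto the node frame, rather than concatenating per-edge joins. A rough sketch of that shape with toy frames; the aggregation here collects lists, which may differ from the exact aggregation in the source:

import pandas as pd

nodes = pd.DataFrame({"title": ["a", "b", "c"]})
edges = pd.DataFrame({
    "source": ["a", "a", "b"],
    "target": ["b", "c", "c"],
    "details": [{"w": 1}, {"w": 2}, {"w": 3}],
})

source_edges = (
    edges.groupby("source")["details"].apply(list).reset_index()
    .rename(columns={"source": "title", "details": "details_src"})
)
target_edges = (
    edges.groupby("target")["details"].apply(list).reset_index()
    .rename(columns={"target": "title", "details": "details_tgt"})
)

merged = nodes.merge(source_edges, on="title", how="left").merge(
    target_edges, on="title", how="left"
)
# Prefer the source-side details, fall back to the target-side details.
merged["edge_details"] = merged["details_src"].combine_first(merged["details_tgt"])
print(merged[["title", "edge_details"]])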
@@ -4,6 +4,7 @@
"""A module containing create_graph, _get_node_attributes, _get_edge_attributes and _get_attribute_column_mapping methods definition."""

import logging
from itertools import pairwise

import pandas as pd

@@ -19,50 +20,39 @@ def restore_community_hierarchy(
    level_column: str = schemas.NODE_LEVEL,
) -> pd.DataFrame:
    """Restore the community hierarchy from the node data."""
    # Group by community and level, aggregate names as lists
    community_df = (
        input.groupby([community_column, level_column])
        .agg({name_column: list})
        input.groupby([community_column, level_column])[name_column]
        .apply(set)
        .reset_index()
    )
    community_levels = {}
    for _, row in community_df.iterrows():
        level = row[level_column]
        name = row[name_column]
        community = row[community_column]

        if community_levels.get(level) is None:
            community_levels[level] = {}
        community_levels[level][community] = name
    # Build dictionary with levels as integers
    community_levels = {
        level: group.set_index(community_column)[name_column].to_dict()
        for level, group in community_df.groupby(level_column)
    }

    # get unique levels, sorted in ascending order
    levels = sorted(community_levels.keys())
    levels = sorted(community_levels.keys())  # type: ignore
    community_hierarchy = []

    for idx in range(len(levels) - 1):
        level = levels[idx]
        next_level = levels[idx + 1]
        current_level_communities = community_levels[level]
        next_level_communities = community_levels[next_level]
    # Iterate through adjacent levels
    for current_level, next_level in pairwise(levels):
        current_communities = community_levels[current_level]
        next_communities = community_levels[next_level]

        for current_community in current_level_communities:
            current_entities = current_level_communities[current_community]

            # loop through next level's communities to find all the subcommunities
            entities_found = 0
            for next_level_community in next_level_communities:
                next_entities = next_level_communities[next_level_community]
                if set(next_entities).issubset(set(current_entities)):
        # Find sub-communities
        for curr_comm, curr_entities in current_communities.items():
            for next_comm, next_entities in next_communities.items():
                if next_entities.issubset(curr_entities):
                    community_hierarchy.append({
                        community_column: current_community,
                        schemas.COMMUNITY_LEVEL: level,
                        schemas.SUB_COMMUNITY: next_level_community,
                        community_column: curr_comm,
                        schemas.COMMUNITY_LEVEL: current_level,
                        schemas.SUB_COMMUNITY: next_comm,
                        schemas.SUB_COMMUNITY_SIZE: len(next_entities),
                    })

                    entities_found += len(next_entities)
                    if entities_found == len(current_entities):
                        break

    return pd.DataFrame(
        community_hierarchy,
    )
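restore_community_hierarchy now walks adjacent levels with itertools.pairwise and uses set containment to identify sub-communities. A compact sketch on invented data:

from itertools import pairwise

# level -> {community id -> set of member entity names}
community_levels = {
    0: {"c0": {"a", "b", "c", "d"}},
    1: {"c1": {"a", "b"}, "c2": {"c", "d"}},
}

hierarchy = []
for current_level, next_level in pairwise(sorted(community_levels)):
    for parent, parent_entities in community_levels[current_level].items():
        for child, child_entities in community_levels[next_level].items():
            if child_entities.issubset(parent_entities):
                hierarchy.append({
                    "community": parent,
                    "level": current_level,
                    "sub_community": child,
                    "sub_community_size": len(child_entities),
                })
print(hierarchy)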
@@ -18,9 +18,9 @@ import graphrag.config.defaults as defaults
import graphrag.index.graph.extractors.community_reports.schemas as schemas
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.index.graph.extractors.community_reports import (
    get_levels,
    prep_community_report_context,
)
from graphrag.index.graph.extractors.community_reports.utils import get_levels
from graphrag.index.operations.summarize_communities.typing import (
    CommunityReport,
    CommunityReportsStrategy,
@@ -28,15 +28,7 @@ def antijoin(df: pd.DataFrame, exclude: pd.DataFrame, column: str) -> pd.DataFra
    * exclude: The DataFrame containing rows to remove.
    * column: The join-on column.
    """
    result = df.merge(
        exclude[[column]],
        on=column,
        how="outer",
        indicator=True,
    )
    if "_merge" in result.columns:
        result = result[result["_merge"] == "left_only"].drop("_merge", axis=1)
    return cast(pd.DataFrame, result)
    return df.loc[~df.loc[:, column].isin(exclude.loc[:, column])]


def transform_series(series: pd.Series, fn: Callable[[Any], Any]) -> pd.Series:
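The antijoin helper now keeps rows whose key is absent from the exclude frame via isin, replacing the outer-merge-with-indicator approach. A tiny sketch with toy frames:

import pandas as pd

df = pd.DataFrame({"community": [1, 2, 3, 4]})
exclude = pd.DataFrame({"community": [2, 4]})

# Keep only rows whose key does not appear in `exclude`.
result = df.loc[~df["community"].isin(exclude["community"])]
print(result["community"].tolist())  # [1, 3]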
@@ -72,7 +72,7 @@ class LocalSearchMixedContext(LocalContextBuilder):
            text_units = []
        self.entities = {entity.id: entity for entity in entities}
        self.community_reports = {
            community.id: community for community in community_reports
            community.community_id: community for community in community_reports
        }
        self.text_units = {unit.id: unit for unit in text_units}
        self.relationships = {
@@ -254,7 +254,7 @@ class LocalSearchMixedContext(LocalContextBuilder):
        for community in selected_communities:
            if community.attributes is None:
                community.attributes = {}
            community.attributes["matches"] = community_matches[community.id]
            community.attributes["matches"] = community_matches[community.community_id]
        selected_communities.sort(
            key=lambda x: (x.attributes["matches"], x.rank),  # type: ignore
            reverse=True,  # type: ignore
10 binary files not shown.