Refactor Create Final Community Reports to simplify code (#1456)

* Optimize prep claims

* Optimize community hierarchy restore

* Partial optimization of prepare_community_reports

* More optimization code

* Fix context string generation

* Filter community -1

* Fix cache, add more optimization fixes

* Fix local search community ids

* Cleanup

* Format

* Semver

* Remove perf counter

* Unused import

* Format

* Fix edge addition to reports

* Add edge by edge context creation

* Re-org of the optimization code

* Format

* Ruff

* Some Ruff fixes

* More pyright

* More pyright

* Pyright

* Pyright

* Update tests
Alonso Guevara 2024-12-05 17:13:05 -06:00 committed by GitHub
parent b00142260d
commit d43124e576
21 changed files with 290 additions and 372 deletions
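The recurring change across the diffs below is mechanical: per-row DataFrame.apply lambdas become a single column selection converted with to_dict(orient="records"), and small helper wrappers become inline boolean masks. A minimal, self-contained sketch of that pattern, with illustrative column names rather than the project's schema constants:

import pandas as pd

df = pd.DataFrame({
    "id": [1, 2],
    "name": ["a", "b"],
    "degree": [3, 1],
})

# Before: one Python-level lambda call per row.
slow = df.apply(
    lambda x: {"id": x["id"], "name": x["name"], "degree": x["degree"]}, axis=1
)

# After: one vectorized selection converted to a list of dicts.
df["details"] = df.loc[:, ["id", "name", "degree"]].to_dict(orient="records")

assert list(slow) == df["details"].tolist()
print(df["details"].iloc[0])  # {'id': 1, 'name': 'a', 'degree': 3}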

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Optimize Final Community Reports calculation and stabilize cache"
}

View File

@ -19,6 +19,7 @@ from graphrag.index.graph.extractors.community_reports.schemas import (
CLAIM_STATUS,
CLAIM_SUBJECT,
CLAIM_TYPE,
COMMUNITY_ID,
EDGE_DEGREE,
EDGE_DESCRIPTION,
EDGE_DETAILS,
@ -83,9 +84,7 @@ async def create_final_community_reports(
community_reports["community"] = community_reports["community"].astype(int)
community_reports["human_readable_id"] = community_reports["community"]
community_reports["id"] = community_reports["community"].apply(
lambda _x: str(uuid4())
)
community_reports["id"] = [uuid4().hex for _ in range(len(community_reports))]
# Merge with communities to add size and period
merged = community_reports.merge(
@ -115,45 +114,42 @@ async def create_final_community_reports(
def _prep_nodes(input: pd.DataFrame) -> pd.DataFrame:
input = input.fillna(value={NODE_DESCRIPTION: "No Description"})
# merge values of four columns into a map column
input[NODE_DETAILS] = input.apply(
lambda x: {
NODE_ID: x[NODE_ID],
NODE_NAME: x[NODE_NAME],
NODE_DESCRIPTION: x[NODE_DESCRIPTION],
NODE_DEGREE: x[NODE_DEGREE],
},
axis=1,
"""Prepare nodes by filtering, filling missing descriptions, and creating NODE_DETAILS."""
# Filter rows where community is not -1
input = input.loc[input.loc[:, COMMUNITY_ID] != -1]
# Fill missing values in NODE_DESCRIPTION
input.loc[:, NODE_DESCRIPTION] = input.loc[:, NODE_DESCRIPTION].fillna(
"No Description"
)
# Create NODE_DETAILS column
input[NODE_DETAILS] = input.loc[
:, [NODE_ID, NODE_NAME, NODE_DESCRIPTION, NODE_DEGREE]
].to_dict(orient="records")
return input
def _prep_edges(input: pd.DataFrame) -> pd.DataFrame:
input = input.fillna(value={NODE_DESCRIPTION: "No Description"})
input[EDGE_DETAILS] = input.apply(
lambda x: {
EDGE_ID: x[EDGE_ID],
EDGE_SOURCE: x[EDGE_SOURCE],
EDGE_TARGET: x[EDGE_TARGET],
EDGE_DESCRIPTION: x[EDGE_DESCRIPTION],
EDGE_DEGREE: x[EDGE_DEGREE],
},
axis=1,
)
# Fill missing NODE_DESCRIPTION
input.fillna(value={NODE_DESCRIPTION: "No Description"}, inplace=True)
# Create EDGE_DETAILS column
input[EDGE_DETAILS] = input.loc[
:, [EDGE_ID, EDGE_SOURCE, EDGE_TARGET, EDGE_DESCRIPTION, EDGE_DEGREE]
].to_dict(orient="records")
return input
def _prep_claims(input: pd.DataFrame) -> pd.DataFrame:
input = input.fillna(value={NODE_DESCRIPTION: "No Description"})
input[CLAIM_DETAILS] = input.apply(
lambda x: {
CLAIM_ID: x[CLAIM_ID],
CLAIM_SUBJECT: x[CLAIM_SUBJECT],
CLAIM_TYPE: x[CLAIM_TYPE],
CLAIM_STATUS: x[CLAIM_STATUS],
CLAIM_DESCRIPTION: x[CLAIM_DESCRIPTION],
},
axis=1,
)
# Fill missing NODE_DESCRIPTION
input.fillna(value={NODE_DESCRIPTION: "No Description"}, inplace=True)
# Create CLAIM_DETAILS column
input[CLAIM_DETAILS] = input.loc[
:, [CLAIM_ID, CLAIM_SUBJECT, CLAIM_TYPE, CLAIM_STATUS, CLAIM_DESCRIPTION]
].to_dict(orient="records")
return input
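For reference, the reshaped _prep_nodes above boils down to three vectorized steps; this toy run uses literal column names ("community", "title", and so on) as stand-ins for the NODE_* and COMMUNITY_ID schema constants used in the diff:

import pandas as pd

nodes = pd.DataFrame({
    "community": [-1, 4, 4],
    "human_readable_id": [0, 1, 2],
    "title": ["orphan", "A", "B"],
    "description": [None, None, "desc"],
    "degree": [0, 2, 1],
})

# Same three steps as the new _prep_nodes: drop the -1 community,
# fill missing descriptions, then build the details column in one pass.
nodes = nodes.loc[nodes["community"] != -1].copy()
nodes["description"] = nodes["description"].fillna("No Description")
nodes["node_details"] = nodes.loc[
    :, ["human_readable_id", "title", "description", "degree"]
].to_dict(orient="records")
print(nodes["node_details"].tolist())
# [{'human_readable_id': 1, 'title': 'A', 'description': 'No Description', 'degree': 2},
#  {'human_readable_id': 2, 'title': 'B', 'description': 'desc', 'degree': 1}]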

View File

@ -14,25 +14,11 @@ from graphrag.index.graph.extractors.community_reports.prep_community_report_con
prep_community_report_context,
)
from graphrag.index.graph.extractors.community_reports.sort_context import sort_context
from graphrag.index.graph.extractors.community_reports.utils import (
filter_claims_to_nodes,
filter_edges_to_nodes,
filter_nodes_to_level,
get_levels,
set_context_exceeds_flag,
set_context_size,
)
__all__ = [
"CommunityReportsExtractor",
"build_mixed_context",
"filter_claims_to_nodes",
"filter_edges_to_nodes",
"filter_nodes_to_level",
"get_levels",
"prep_community_report_context",
"schemas",
"set_context_exceeds_flag",
"set_context_size",
"sort_context",
]

View File

@ -13,7 +13,6 @@ from graphrag.index.graph.extractors.community_reports.build_mixed_context impor
build_mixed_context,
)
from graphrag.index.graph.extractors.community_reports.sort_context import sort_context
from graphrag.index.graph.extractors.community_reports.utils import set_context_size
from graphrag.index.utils.dataframes import (
antijoin,
drop_columns,
@ -23,6 +22,7 @@ from graphrag.index.utils.dataframes import (
union,
where_column_equals,
)
from graphrag.query.llm.text_utils import num_tokens
log = logging.getLogger(__name__)
@ -31,7 +31,7 @@ def prep_community_report_context(
report_df: pd.DataFrame | None,
community_hierarchy_df: pd.DataFrame,
local_context_df: pd.DataFrame,
level: int | str,
level: int,
max_tokens: int,
) -> pd.DataFrame:
"""
@ -44,10 +44,18 @@ def prep_community_report_context(
if report_df is None:
report_df = pd.DataFrame()
level = int(level)
level_context_df = _at_level(level, local_context_df)
valid_context_df = _within_context(level_context_df)
invalid_context_df = _exceeding_context(level_context_df)
# Filter by community level
level_context_df = local_context_df.loc[
local_context_df.loc[:, schemas.COMMUNITY_LEVEL] == level
]
# Filter valid and invalid contexts using boolean logic
valid_context_df = level_context_df.loc[
~level_context_df.loc[:, schemas.CONTEXT_EXCEED_FLAG]
]
invalid_context_df = level_context_df.loc[
level_context_df.loc[:, schemas.CONTEXT_EXCEED_FLAG]
]
# there is no report to substitute with, so we just trim the local context of the invalid context records
# this case should only happen at the bottom level of the community hierarchy where there are no sub-communities
@ -55,11 +63,13 @@ def prep_community_report_context(
return valid_context_df
if report_df.empty:
invalid_context_df[schemas.CONTEXT_STRING] = _sort_and_trim_context(
invalid_context_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
invalid_context_df, max_tokens
)
set_context_size(invalid_context_df)
invalid_context_df.loc[:, schemas.CONTEXT_EXCEED_FLAG] = 0
invalid_context_df[schemas.CONTEXT_SIZE] = invalid_context_df.loc[
:, schemas.CONTEXT_STRING
].map(num_tokens)
invalid_context_df[schemas.CONTEXT_EXCEED_FLAG] = 0
return union(valid_context_df, invalid_context_df)
level_context_df = _antijoin_reports(level_context_df, report_df)
@ -74,12 +84,13 @@ def prep_community_report_context(
# handle any remaining invalid records that can't be substituted with sub-community reports
# this should be rare, but if it happens, we will just trim the local context to fit the limit
remaining_df = _antijoin_reports(invalid_context_df, community_df)
remaining_df[schemas.CONTEXT_STRING] = _sort_and_trim_context(
remaining_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
remaining_df, max_tokens
)
result = union(valid_context_df, community_df, remaining_df)
set_context_size(result)
result[schemas.CONTEXT_SIZE] = result.loc[:, schemas.CONTEXT_STRING].map(num_tokens)
result[schemas.CONTEXT_EXCEED_FLAG] = 0
return result
@ -94,16 +105,6 @@ def _at_level(level: int, df: pd.DataFrame) -> pd.DataFrame:
return where_column_equals(df, schemas.COMMUNITY_LEVEL, level)
def _exceeding_context(df: pd.DataFrame) -> pd.DataFrame:
"""Return records where the context exceeds the limit."""
return where_column_equals(df, schemas.CONTEXT_EXCEED_FLAG, 1)
def _within_context(df: pd.DataFrame) -> pd.DataFrame:
"""Return records where the context is within the limit."""
return where_column_equals(df, schemas.CONTEXT_EXCEED_FLAG, 0)
def _antijoin_reports(df: pd.DataFrame, reports: pd.DataFrame) -> pd.DataFrame:
"""Return records in df that are not in reports."""
return antijoin(df, reports, schemas.NODE_COMMUNITY)
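The removed helpers (set_context_size, _within_context, _exceeding_context) reduce to a map over the context strings plus a boolean mask, as inlined above. A toy equivalent, with a whitespace splitter standing in for graphrag's num_tokens:

import pandas as pd

contexts = pd.DataFrame({
    "context_string": ["a b c", "a b c d e f g h"],
})
max_tokens = 5

# Recompute sizes, then split the frame the same way the inlined code does.
contexts["context_size"] = contexts["context_string"].map(lambda s: len(s.split()))
contexts["context_exceed_limit"] = contexts["context_size"] > max_tokens
valid = contexts.loc[~contexts["context_exceed_limit"]]
invalid = contexts.loc[contexts["context_exceed_limit"]]
print(len(valid), len(invalid))  # 1 1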

View File

@ -12,7 +12,6 @@ def sort_context(
local_context: list[dict],
sub_community_reports: list[dict] | None = None,
max_tokens: int | None = None,
node_id_column: str = schemas.NODE_ID,
node_name_column: str = schemas.NODE_NAME,
node_details_column: str = schemas.NODE_DETAILS,
edge_id_column: str = schemas.EDGE_ID,
@ -20,14 +19,9 @@ def sort_context(
edge_degree_column: str = schemas.EDGE_DEGREE,
edge_source_column: str = schemas.EDGE_SOURCE,
edge_target_column: str = schemas.EDGE_TARGET,
claim_id_column: str = schemas.CLAIM_ID,
claim_details_column: str = schemas.CLAIM_DETAILS,
community_id_column: str = schemas.COMMUNITY_ID,
) -> str:
"""Sort context by degree in descending order.
If max tokens is provided, we will return the context string that fits within the token limit.
"""
"""Sort context by degree in descending order, optimizing for performance."""
def _get_context_string(
entities: list[dict],
@ -38,119 +32,123 @@ def sort_context(
"""Concatenate structured data into a context string."""
contexts = []
if sub_community_reports:
sub_community_reports = [
report
for report in sub_community_reports
if community_id_column in report
and report[community_id_column]
and str(report[community_id_column]).strip() != ""
]
report_df = pd.DataFrame(sub_community_reports).drop_duplicates()
report_df = pd.DataFrame(sub_community_reports)
if not report_df.empty:
if report_df[community_id_column].dtype == float:
report_df[community_id_column] = report_df[
community_id_column
].astype(int)
report_string = (
contexts.append(
f"----Reports-----\n{report_df.to_csv(index=False, sep=',')}"
)
contexts.append(report_string)
entities = [
entity
for entity in entities
if node_id_column in entity
and entity[node_id_column]
and str(entity[node_id_column]).strip() != ""
]
entity_df = pd.DataFrame(entities).drop_duplicates()
if not entity_df.empty:
if entity_df[node_id_column].dtype == float:
entity_df[node_id_column] = entity_df[node_id_column].astype(int)
entity_string = (
f"-----Entities-----\n{entity_df.to_csv(index=False, sep=',')}"
)
contexts.append(entity_string)
if claims and len(claims) > 0:
claims = [
claim
for claim in claims
if claim_id_column in claim
and claim[claim_id_column]
and str(claim[claim_id_column]).strip() != ""
]
claim_df = pd.DataFrame(claims).drop_duplicates()
if not claim_df.empty:
if claim_df[claim_id_column].dtype == float:
claim_df[claim_id_column] = claim_df[claim_id_column].astype(int)
claim_string = (
f"-----Claims-----\n{claim_df.to_csv(index=False, sep=',')}"
)
contexts.append(claim_string)
edges = [
edge
for edge in edges
if edge_id_column in edge
and edge[edge_id_column]
and str(edge[edge_id_column]).strip() != ""
]
edge_df = pd.DataFrame(edges).drop_duplicates()
if not edge_df.empty:
if edge_df[edge_id_column].dtype == float:
edge_df[edge_id_column] = edge_df[edge_id_column].astype(int)
edge_string = (
f"-----Relationships-----\n{edge_df.to_csv(index=False, sep=',')}"
)
contexts.append(edge_string)
for label, data in [
("Entities", entities),
("Claims", claims),
("Relationships", edges),
]:
if data:
data_df = pd.DataFrame(data)
if not data_df.empty:
contexts.append(
f"-----{label}-----\n{data_df.to_csv(index=False, sep=',')}"
)
return "\n\n".join(contexts)
# sort node details by degree in descending order
edges = []
node_details = {}
claim_details = {}
# Preprocess local context
edges = [
{**e, schemas.EDGE_ID: int(e[schemas.EDGE_ID])}
for record in local_context
for e in record.get(edge_details_column, [])
if isinstance(e, dict)
]
for record in local_context:
node_name = record[node_name_column]
record_edges = record.get(edge_details_column, [])
record_edges = [e for e in record_edges if not pd.isna(e)]
record_node_details = record[node_details_column]
record_claims = record.get(claim_details_column, [])
record_claims = [c for c in record_claims if not pd.isna(c)]
node_details = {
record[node_name_column]: {
**record[node_details_column],
schemas.NODE_ID: int(record[node_details_column][schemas.NODE_ID]),
}
for record in local_context
}
edges.extend(record_edges)
node_details[node_name] = record_node_details
claim_details[node_name] = record_claims
claim_details = {
record[node_name_column]: [
{**c, schemas.CLAIM_ID: int(c[schemas.CLAIM_ID])}
for c in record.get(claim_details_column, [])
if isinstance(c, dict) and c.get(schemas.CLAIM_ID) is not None
]
for record in local_context
if isinstance(record.get(claim_details_column), list)
}
edges = [edge for edge in edges if isinstance(edge, dict)]
edges = sorted(edges, key=lambda x: x[edge_degree_column], reverse=True)
# Sort edges by degree (desc) and ID (asc)
edges.sort(key=lambda x: (-x.get(edge_degree_column, 0), x.get(edge_id_column, "")))
sorted_edges = []
sorted_nodes = []
sorted_claims = []
# Deduplicate and build context incrementally
edge_ids, nodes_ids, claims_ids = set(), set(), set()
sorted_edges, sorted_nodes, sorted_claims = [], [], []
context_string = ""
for edge in edges:
source_details = node_details.get(edge[edge_source_column], {})
target_details = node_details.get(edge[edge_target_column], {})
sorted_nodes.extend([source_details, target_details])
sorted_edges.append(edge)
source_claims = claim_details.get(edge[edge_source_column], [])
target_claims = claim_details.get(edge[edge_target_column], [])
sorted_claims.extend(source_claims if source_claims else [])
sorted_claims.extend(target_claims if source_claims else [])
if max_tokens:
new_context_string = _get_context_string(
sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
)
if num_tokens(new_context_string) > max_tokens:
break
context_string = new_context_string
if context_string == "":
return _get_context_string(
for edge in edges:
source, target = edge[edge_source_column], edge[edge_target_column]
# Add source and target node details
for node in [node_details.get(source), node_details.get(target)]:
if node and node[schemas.NODE_ID] not in nodes_ids:
nodes_ids.add(node[schemas.NODE_ID])
sorted_nodes.append(node)
# Add claims related to source and target
for claims in [claim_details.get(source), claim_details.get(target)]:
if claims:
for claim in claims:
if claim[schemas.CLAIM_ID] not in claims_ids:
claims_ids.add(claim[schemas.CLAIM_ID])
sorted_claims.append(claim)
# Add the edge
if edge[schemas.EDGE_ID] not in edge_ids:
edge_ids.add(edge[schemas.EDGE_ID])
sorted_edges.append(edge)
# Generate new context string
new_context_string = _get_context_string(
sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
)
if max_tokens and num_tokens(new_context_string) > max_tokens:
break
context_string = new_context_string
return context_string
# Return the final context string
return context_string or _get_context_string(
sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
)
def parallel_sort_context_batch(community_df, max_tokens, parallel=False):
"""Calculate context using parallelization if enabled."""
if parallel:
# Use ThreadPoolExecutor for parallel execution
from concurrent.futures import ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=None) as executor:
context_strings = list(
executor.map(
lambda x: sort_context(x, max_tokens=max_tokens),
community_df[schemas.ALL_CONTEXT],
)
)
community_df[schemas.CONTEXT_STRING] = context_strings
else:
# Assign context strings directly to the DataFrame
community_df[schemas.CONTEXT_STRING] = community_df[schemas.ALL_CONTEXT].apply(
lambda context_list: sort_context(context_list, max_tokens=max_tokens)
)
# Calculate other columns
community_df[schemas.CONTEXT_SIZE] = community_df[schemas.CONTEXT_STRING].apply(
num_tokens
)
community_df[schemas.CONTEXT_EXCEED_FLAG] = (
community_df[schemas.CONTEXT_SIZE] > max_tokens
)
return community_df
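The rewritten sort_context builds the context greedily: edges are visited in descending degree order, each edge pulls in its endpoint nodes (and claims) exactly once, and the last rendering that still fits the token budget wins. A stripped-down model of that control flow, with claims omitted and whitespace tokens standing in for num_tokens:

def build_context(edges: list[dict], nodes: dict[str, dict], max_tokens: int) -> str:
    edges = sorted(edges, key=lambda e: -e["degree"])
    seen_nodes: set[str] = set()
    picked_nodes, picked_edges = [], []
    best = rendered = ""
    for edge in edges:
        # Pull in each endpoint's details once.
        for name in (edge["source"], edge["target"]):
            if name in nodes and name not in seen_nodes:
                seen_nodes.add(name)
                picked_nodes.append(nodes[name])
        picked_edges.append(edge)
        rendered = (
            f"-----Entities-----\n{picked_nodes}\n\n"
            f"-----Relationships-----\n{picked_edges}"
        )
        if len(rendered.split()) > max_tokens:
            break  # keep the previous rendering, which still fit
        best = rendered
    # Mirror the fallback above: if nothing fit, return the first oversized rendering.
    return best or rendered

edges = [
    {"source": "a", "target": "b", "degree": 3},
    {"source": "b", "target": "c", "degree": 1},
]
nodes = {"a": {"degree": 3}, "b": {"degree": 4}, "c": {"degree": 1}}
print(build_context(edges, nodes, max_tokens=50))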

View File

@ -3,53 +3,13 @@
"""A module containing community report generation utilities."""
from typing import cast
import pandas as pd
import graphrag.index.graph.extractors.community_reports.schemas as schemas
from graphrag.query.llm.text_utils import num_tokens
def set_context_size(df: pd.DataFrame) -> None:
"""Measure the number of tokens in the context."""
df.loc[:, schemas.CONTEXT_SIZE] = df.loc[:, schemas.CONTEXT_STRING].apply(
lambda x: num_tokens(x)
)
def set_context_exceeds_flag(df: pd.DataFrame, max_tokens: int) -> None:
"""Set a flag to indicate if the context exceeds the limit."""
df.loc[:, schemas.CONTEXT_EXCEED_FLAG] = df.loc[:, schemas.CONTEXT_SIZE].apply(
lambda x: x > max_tokens
)
def get_levels(df: pd.DataFrame, level_column: str = schemas.NODE_LEVEL) -> list[int]:
"""Get the levels of the communities."""
result = sorted(df[level_column].fillna(-1).unique().tolist(), reverse=True)
return [r for r in result if r != -1]
def filter_nodes_to_level(node_df: pd.DataFrame, level: int) -> pd.DataFrame:
"""Filter nodes to level."""
return cast(pd.DataFrame, node_df[node_df[schemas.NODE_LEVEL] == level])
def filter_edges_to_nodes(edge_df: pd.DataFrame, nodes: list[str]) -> pd.DataFrame:
"""Filter edges to nodes."""
return cast(
pd.DataFrame,
edge_df[
edge_df[schemas.EDGE_SOURCE].isin(nodes)
& edge_df[schemas.EDGE_TARGET].isin(nodes)
],
)
def filter_claims_to_nodes(claims_df: pd.DataFrame, nodes: list[str]) -> pd.DataFrame:
"""Filter edges to nodes."""
return cast(
pd.DataFrame,
claims_df[claims_df[schemas.CLAIM_SUBJECT].isin(nodes)],
)
levels = df[level_column].dropna().unique()
levels = [int(lvl) for lvl in levels if lvl != -1]
return sorted(levels, reverse=True)
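The rewritten get_levels now returns plain ints, dropping both NaN and the -1 sentinel before sorting descending. A quick check of that behaviour, with "level" standing in for schemas.NODE_LEVEL:

import pandas as pd

nodes = pd.DataFrame({"level": [0, 0, 1, 1, -1, None]})
levels = nodes["level"].dropna().unique()
levels = [int(lvl) for lvl in levels if lvl != -1]
print(sorted(levels, reverse=True))  # [1, 0]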

View File

@ -4,7 +4,6 @@
"""A module containing create_community_reports and load_strategy methods definition."""
import logging
from typing import cast
import pandas as pd
from datashaper import (
@ -13,15 +12,10 @@ from datashaper import (
)
import graphrag.index.graph.extractors.community_reports.schemas as schemas
from graphrag.index.graph.extractors.community_reports import (
filter_claims_to_nodes,
filter_edges_to_nodes,
filter_nodes_to_level,
get_levels,
set_context_exceeds_flag,
set_context_size,
sort_context,
from graphrag.index.graph.extractors.community_reports.sort_context import (
parallel_sort_context_batch,
)
from graphrag.index.graph.extractors.community_reports.utils import get_levels
log = logging.getLogger(__name__)
@ -35,12 +29,15 @@ def prepare_community_reports(
):
"""Prep communities for report generation."""
levels = get_levels(nodes, schemas.NODE_LEVEL)
dfs = []
for level in progress_iterable(levels, callbacks.progress, len(levels)):
communities_at_level_df = _prepare_reports_at_level(
nodes, edges, claims, level, max_tokens
)
communities_at_level_df.loc[:, schemas.COMMUNITY_LEVEL] = level
dfs.append(communities_at_level_df)
# build initial local context for all communities
@ -53,127 +50,121 @@ def _prepare_reports_at_level(
claim_df: pd.DataFrame | None,
level: int,
max_tokens: int = 16_000,
community_id_column: str = schemas.COMMUNITY_ID,
node_id_column: str = schemas.NODE_ID,
node_name_column: str = schemas.NODE_NAME,
node_details_column: str = schemas.NODE_DETAILS,
node_level_column: str = schemas.NODE_LEVEL,
node_degree_column: str = schemas.NODE_DEGREE,
node_community_column: str = schemas.NODE_COMMUNITY,
edge_id_column: str = schemas.EDGE_ID,
edge_source_column: str = schemas.EDGE_SOURCE,
edge_target_column: str = schemas.EDGE_TARGET,
edge_degree_column: str = schemas.EDGE_DEGREE,
edge_details_column: str = schemas.EDGE_DETAILS,
claim_id_column: str = schemas.CLAIM_ID,
claim_subject_column: str = schemas.CLAIM_SUBJECT,
claim_details_column: str = schemas.CLAIM_DETAILS,
):
def get_edge_details(node_df: pd.DataFrame, edge_df: pd.DataFrame, name_col: str):
return node_df.merge(
cast(
pd.DataFrame,
edge_df[[name_col, schemas.EDGE_DETAILS]],
).rename(columns={name_col: schemas.NODE_NAME}),
on=schemas.NODE_NAME,
how="left",
)
level_node_df = filter_nodes_to_level(node_df, level)
) -> pd.DataFrame:
"""Prepare reports at a given level."""
# Filter and prepare node details
level_node_df = node_df[node_df[schemas.NODE_LEVEL] == level]
log.info("Number of nodes at level=%s => %s", level, len(level_node_df))
nodes = level_node_df[node_name_column].tolist()
nodes_set = set(level_node_df[schemas.NODE_NAME])
# Filter edges & claims to those containing the target nodes
level_edge_df = filter_edges_to_nodes(edge_df, nodes)
level_claim_df = (
filter_claims_to_nodes(claim_df, nodes) if claim_df is not None else None
)
# concat all edge details per node
merged_node_df = pd.concat(
# Filter and prepare edge details
level_edge_df = edge_df[
edge_df.loc[:, schemas.EDGE_SOURCE].isin(nodes_set)
& edge_df.loc[:, schemas.EDGE_TARGET].isin(nodes_set)
]
level_edge_df.loc[:, schemas.EDGE_DETAILS] = level_edge_df.loc[
:,
[
get_edge_details(level_node_df, level_edge_df, edge_source_column),
get_edge_details(level_node_df, level_edge_df, edge_target_column),
schemas.EDGE_ID,
schemas.EDGE_SOURCE,
schemas.EDGE_TARGET,
schemas.EDGE_DESCRIPTION,
schemas.EDGE_DEGREE,
],
axis=0,
)
merged_node_df = (
merged_node_df.groupby([
node_name_column,
node_community_column,
node_degree_column,
node_level_column,
])
.agg({node_details_column: "first", edge_details_column: list})
].to_dict(orient="records")
level_claim_df = pd.DataFrame()
if claim_df is not None:
level_claim_df = claim_df[
claim_df.loc[:, schemas.CLAIM_SUBJECT].isin(nodes_set)
]
# Merge node and edge details
# Group edge details by node and aggregate into lists
source_edges = (
level_edge_df.groupby(schemas.EDGE_SOURCE)
.agg({schemas.EDGE_DETAILS: "first"})
.reset_index()
.rename(columns={schemas.EDGE_SOURCE: schemas.NODE_NAME})
)
# concat claim details per node
if level_claim_df is not None:
merged_node_df = merged_node_df.merge(
cast(
pd.DataFrame,
level_claim_df[[claim_subject_column, claim_details_column]],
).rename(columns={claim_subject_column: node_name_column}),
on=node_name_column,
how="left",
)
target_edges = (
level_edge_df.groupby(schemas.EDGE_TARGET)
.agg({schemas.EDGE_DETAILS: "first"})
.reset_index()
.rename(columns={schemas.EDGE_TARGET: schemas.NODE_NAME})
)
# Merge aggregated edges into the node DataFrame
merged_node_df = level_node_df.merge(
source_edges, on=schemas.NODE_NAME, how="left"
).merge(target_edges, on=schemas.NODE_NAME, how="left")
# Combine source and target edge details into a single column
merged_node_df.loc[:, schemas.EDGE_DETAILS] = merged_node_df.loc[
:, f"{schemas.EDGE_DETAILS}_x"
].combine_first(merged_node_df.loc[:, f"{schemas.EDGE_DETAILS}_y"])
# Drop intermediate columns
merged_node_df.drop(
columns=[f"{schemas.EDGE_DETAILS}_x", f"{schemas.EDGE_DETAILS}_y"], inplace=True
)
# Aggregate node and edge details
merged_node_df = (
merged_node_df.groupby([
node_name_column,
node_community_column,
node_level_column,
node_degree_column,
schemas.NODE_NAME,
schemas.NODE_COMMUNITY,
schemas.NODE_LEVEL,
schemas.NODE_DEGREE,
])
.agg({
node_details_column: "first",
edge_details_column: "first",
**({claim_details_column: list} if level_claim_df is not None else {}),
schemas.NODE_DETAILS: "first",
schemas.EDGE_DETAILS: lambda x: list(x.dropna()),
})
.reset_index()
)
# concat all node details, including name, degree, node_details, edge_details, and claim_details
merged_node_df[schemas.ALL_CONTEXT] = merged_node_df.apply(
lambda x: {
node_name_column: x[node_name_column],
node_degree_column: x[node_degree_column],
node_details_column: x[node_details_column],
edge_details_column: x[edge_details_column],
claim_details_column: x[claim_details_column]
if level_claim_df is not None
else [],
},
axis=1,
# Add ALL_CONTEXT column
# Ensure schemas.CLAIM_DETAILS exists with the correct length
# Merge claim details if available
if claim_df is not None:
merged_node_df = merged_node_df.merge(
level_claim_df.loc[
:, [schemas.CLAIM_SUBJECT, schemas.CLAIM_DETAILS]
].rename(columns={schemas.CLAIM_SUBJECT: schemas.NODE_NAME}),
on=schemas.NODE_NAME,
how="left",
)
# Create the ALL_CONTEXT column
merged_node_df[schemas.ALL_CONTEXT] = (
merged_node_df.loc[
:,
[
schemas.NODE_NAME,
schemas.NODE_DEGREE,
schemas.NODE_DETAILS,
schemas.EDGE_DETAILS,
],
]
.assign(
**{schemas.CLAIM_DETAILS: merged_node_df[schemas.CLAIM_DETAILS]}
if claim_df is not None
else {}
)
.to_dict(orient="records")
)
# group all node details by community
community_df = (
merged_node_df.groupby(node_community_column)
merged_node_df.groupby(schemas.NODE_COMMUNITY)
.agg({schemas.ALL_CONTEXT: list})
.reset_index()
)
community_df[schemas.CONTEXT_STRING] = community_df[schemas.ALL_CONTEXT].apply(
lambda x: sort_context(
x,
node_id_column=node_id_column,
node_name_column=node_name_column,
node_details_column=node_details_column,
edge_id_column=edge_id_column,
edge_details_column=edge_details_column,
edge_degree_column=edge_degree_column,
edge_source_column=edge_source_column,
edge_target_column=edge_target_column,
claim_id_column=claim_id_column,
claim_details_column=claim_details_column,
community_id_column=community_id_column,
)
)
set_context_size(community_df)
set_context_exceeds_flag(community_df, max_tokens)
community_df[schemas.COMMUNITY_LEVEL] = level
community_df[node_community_column] = community_df[node_community_column].astype(
int
# Generate community-level context strings using vectorized batch processing
return parallel_sort_context_batch(
community_df,
max_tokens=max_tokens,
)
return community_df
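The source/target merge above replaces the old per-direction get_edge_details joins; the essential move is two grouped merges followed by combine_first to collapse the _x/_y suffix columns. A toy version with literal column names in place of the schema constants:

import pandas as pd

nodes = pd.DataFrame({"name": ["a", "b", "c"]})
edges = pd.DataFrame({
    "source": ["a", "b"],
    "target": ["b", "c"],
    "edge_details": [{"id": 0}, {"id": 1}],
})

source_side = (
    edges.groupby("source").agg({"edge_details": "first"}).reset_index()
    .rename(columns={"source": "name"})
)
target_side = (
    edges.groupby("target").agg({"edge_details": "first"}).reset_index()
    .rename(columns={"target": "name"})
)
merged = nodes.merge(source_side, on="name", how="left").merge(
    target_side, on="name", how="left"
)
# Prefer the source-side details, fall back to the target-side ones.
merged["edge_details"] = merged["edge_details_x"].combine_first(merged["edge_details_y"])
merged = merged.drop(columns=["edge_details_x", "edge_details_y"])
print(merged.to_dict(orient="records"))
# [{'name': 'a', 'edge_details': {'id': 0}},
#  {'name': 'b', 'edge_details': {'id': 1}},
#  {'name': 'c', 'edge_details': {'id': 1}}]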

View File

@ -4,6 +4,7 @@
"""A module containing create_graph, _get_node_attributes, _get_edge_attributes and _get_attribute_column_mapping methods definition."""
import logging
from itertools import pairwise
import pandas as pd
@ -19,50 +20,39 @@ def restore_community_hierarchy(
level_column: str = schemas.NODE_LEVEL,
) -> pd.DataFrame:
"""Restore the community hierarchy from the node data."""
# Group by community and level, aggregate names as lists
community_df = (
input.groupby([community_column, level_column])
.agg({name_column: list})
input.groupby([community_column, level_column])[name_column]
.apply(set)
.reset_index()
)
community_levels = {}
for _, row in community_df.iterrows():
level = row[level_column]
name = row[name_column]
community = row[community_column]
if community_levels.get(level) is None:
community_levels[level] = {}
community_levels[level][community] = name
# Build dictionary with levels as integers
community_levels = {
level: group.set_index(community_column)[name_column].to_dict()
for level, group in community_df.groupby(level_column)
}
# get unique levels, sorted in ascending order
levels = sorted(community_levels.keys())
levels = sorted(community_levels.keys()) # type: ignore
community_hierarchy = []
for idx in range(len(levels) - 1):
level = levels[idx]
next_level = levels[idx + 1]
current_level_communities = community_levels[level]
next_level_communities = community_levels[next_level]
# Iterate through adjacent levels
for current_level, next_level in pairwise(levels):
current_communities = community_levels[current_level]
next_communities = community_levels[next_level]
for current_community in current_level_communities:
current_entities = current_level_communities[current_community]
# loop through next level's communities to find all the subcommunities
entities_found = 0
for next_level_community in next_level_communities:
next_entities = next_level_communities[next_level_community]
if set(next_entities).issubset(set(current_entities)):
# Find sub-communities
for curr_comm, curr_entities in current_communities.items():
for next_comm, next_entities in next_communities.items():
if next_entities.issubset(curr_entities):
community_hierarchy.append({
community_column: current_community,
schemas.COMMUNITY_LEVEL: level,
schemas.SUB_COMMUNITY: next_level_community,
community_column: curr_comm,
schemas.COMMUNITY_LEVEL: current_level,
schemas.SUB_COMMUNITY: next_comm,
schemas.SUB_COMMUNITY_SIZE: len(next_entities),
})
entities_found += len(next_entities)
if entities_found == len(current_entities):
break
return pd.DataFrame(
community_hierarchy,
)
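A toy run of the pairwise hierarchy rebuild above: for each adjacent pair of levels, a next-level community becomes a sub-community of the current-level community whose entity set contains it. The dictionary shape mirrors the community_levels structure built from the grouped DataFrame:

from itertools import pairwise

community_levels = {
    0: {"c0": {"a", "b", "c", "d"}},
    1: {"c1": {"a", "b"}, "c2": {"c", "d"}},
}
hierarchy = []
for current_level, next_level in pairwise(sorted(community_levels)):
    for parent, parent_entities in community_levels[current_level].items():
        for child, child_entities in community_levels[next_level].items():
            if child_entities.issubset(parent_entities):
                hierarchy.append({
                    "community": parent,
                    "sub_community": child,
                    "size": len(child_entities),
                })
print(hierarchy)
# [{'community': 'c0', 'sub_community': 'c1', 'size': 2},
#  {'community': 'c0', 'sub_community': 'c2', 'size': 2}]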

View File

@ -18,9 +18,9 @@ import graphrag.config.defaults as defaults
import graphrag.index.graph.extractors.community_reports.schemas as schemas
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.index.graph.extractors.community_reports import (
get_levels,
prep_community_report_context,
)
from graphrag.index.graph.extractors.community_reports.utils import get_levels
from graphrag.index.operations.summarize_communities.typing import (
CommunityReport,
CommunityReportsStrategy,

View File

@ -28,15 +28,7 @@ def antijoin(df: pd.DataFrame, exclude: pd.DataFrame, column: str) -> pd.DataFra
* exclude: The DataFrame containing rows to remove.
* column: The join-on column.
"""
result = df.merge(
exclude[[column]],
on=column,
how="outer",
indicator=True,
)
if "_merge" in result.columns:
result = result[result["_merge"] == "left_only"].drop("_merge", axis=1)
return cast(pd.DataFrame, result)
return df.loc[~df.loc[:, column].isin(exclude.loc[:, column])]
def transform_series(series: pd.Series, fn: Callable[[Any], Any]) -> pd.Series:
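The simplified antijoin is just membership exclusion on the join column; a quick demonstration with throwaway frames (the column name "community" is illustrative):

import pandas as pd

df = pd.DataFrame({"community": [1, 2, 3], "x": ["a", "b", "c"]})
exclude = pd.DataFrame({"community": [2]})

result = df.loc[~df.loc[:, "community"].isin(exclude.loc[:, "community"])]
print(result["community"].tolist())  # [1, 3]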

View File

@ -72,7 +72,7 @@ class LocalSearchMixedContext(LocalContextBuilder):
text_units = []
self.entities = {entity.id: entity for entity in entities}
self.community_reports = {
community.id: community for community in community_reports
community.community_id: community for community in community_reports
}
self.text_units = {unit.id: unit for unit in text_units}
self.relationships = {
@ -254,7 +254,7 @@ class LocalSearchMixedContext(LocalContextBuilder):
for community in selected_communities:
if community.attributes is None:
community.attributes = {}
community.attributes["matches"] = community_matches[community.id]
community.attributes["matches"] = community_matches[community.community_id]
selected_communities.sort(
key=lambda x: (x.attributes["matches"], x.rank), # type: ignore
reverse=True, # type: ignore
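A minimal sketch of the keying change above, assuming report objects that carry both a row-level id and the community_id used elsewhere in the index; keying the lookup by community_id lets matches built from node and community tables resolve:

from dataclasses import dataclass

@dataclass
class Report:
    id: str
    community_id: str
    title: str

reports = [Report(id="uuid-1", community_id="4", title="Community 4")]

# Lookup keyed by community_id instead of the row-level id.
by_community = {r.community_id: r for r in reports}
community_matches = {"4": 7}
print(by_community["4"].title, community_matches["4"])  # Community 4 7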