diff --git a/.semversioner/next-release/patch-20250422215029679348.json b/.semversioner/next-release/patch-20250422215029679348.json new file mode 100644 index 00000000..0677145b --- /dev/null +++ b/.semversioner/next-release/patch-20250422215029679348.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Add option to snapshot raw extracted graph tables." +} diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py index 627544ec..8d3afe13 100644 --- a/graphrag/config/defaults.py +++ b/graphrag/config/defaults.py @@ -339,6 +339,7 @@ class SnapshotsDefaults: embeddings: bool = False graphml: bool = False + raw_graph: bool = False @dataclass diff --git a/graphrag/config/models/snapshots_config.py b/graphrag/config/models/snapshots_config.py index b097e3d8..5c0109e3 100644 --- a/graphrag/config/models/snapshots_config.py +++ b/graphrag/config/models/snapshots_config.py @@ -19,3 +19,7 @@ class SnapshotsConfig(BaseModel): description="A flag indicating whether to take snapshots of GraphML.", default=graphrag_config_defaults.snapshots.graphml, ) + raw_graph: bool = Field( + description="A flag indicating whether to take snapshots of the raw extracted graph (entities and relationships) before merging.", + default=graphrag_config_defaults.snapshots.raw_graph, + ) diff --git a/graphrag/index/workflows/extract_graph.py b/graphrag/index/workflows/extract_graph.py index 2a85be4b..84f8647e 100644 --- a/graphrag/index/workflows/extract_graph.py +++ b/graphrag/index/workflows/extract_graph.py @@ -43,7 +43,7 @@ async def run_workflow( config.root_dir, summarization_llm_settings ) - entities, relationships = await extract_graph( + entities, relationships, raw_entities, raw_relationships = await extract_graph( text_units=text_units, callbacks=context.callbacks, cache=context.cache, @@ -58,6 +58,12 @@ async def run_workflow( await write_table_to_storage(entities, "entities", context.storage) await write_table_to_storage(relationships, "relationships", 
context.storage) + if config.snapshots.raw_graph: + await write_table_to_storage(raw_entities, "raw_entities", context.storage) + await write_table_to_storage( + raw_relationships, "raw_relationships", context.storage + ) + return WorkflowFunctionOutput( result={ "entities": entities, @@ -76,7 +82,7 @@ async def extract_graph( entity_types: list[str] | None = None, summarization_strategy: dict[str, Any] | None = None, summarization_num_threads: int = 4, -) -> tuple[pd.DataFrame, pd.DataFrame]: +) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]: """All the steps to create the base entity graph.""" # this returns a graph for each text unit, to be merged later extracted_entities, extracted_relationships = await extractor( @@ -103,6 +109,10 @@ async def extract_graph( callbacks.error(error_msg) raise ValueError(error_msg) + # copy these as is before any summarization + raw_entities = extracted_entities.copy() + raw_relationships = extracted_relationships.copy() + entities, relationships = await get_summarized_entities_relationships( extracted_entities=extracted_entities, extracted_relationships=extracted_relationships, @@ -112,7 +122,7 @@ async def extract_graph( summarization_num_threads=summarization_num_threads, ) - return (entities, relationships) + return (entities, relationships, raw_entities, raw_relationships) async def get_summarized_entities_relationships(