diff --git a/.semversioner/next-release/patch-20241014220522452574.json b/.semversioner/next-release/patch-20241014220522452574.json new file mode 100644 index 00000000..ac1f1552 --- /dev/null +++ b/.semversioner/next-release/patch-20241014220522452574.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Collapse intermediate workflow outputs." +} diff --git a/.semversioner/next-release/patch-20241017015645367096.json b/.semversioner/next-release/patch-20241017015645367096.json new file mode 100644 index 00000000..5fd3e58e --- /dev/null +++ b/.semversioner/next-release/patch-20241017015645367096.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Allow empty workflow returns to avoid disk writing." +} diff --git a/docs/scripts/create_cookie_banner.js b/docs/scripts/create_cookie_banner.js new file mode 100644 index 00000000..16f986db --- /dev/null +++ b/docs/scripts/create_cookie_banner.js @@ -0,0 +1,18 @@ +function onConsentChanged(categoryPreferences) { + console.log("onConsentChanged", categoryPreferences); +} + + +cb = document.createElement("div"); +cb.id = "cookie-banner"; +document.body.insertBefore(cb, document.body.children[0]); + +window.WcpConsent && WcpConsent.init("en-US", "cookie-banner", function (err, consent) { + if (!err) { + console.log("consent: ", consent); + window.manageConsent = () => consent.manageConsent(); + siteConsent = consent; + } else { + console.log("Error initializing WcpConsent: "+ err); + } +}, onConsentChanged, WcpConsent.themes.light); \ No newline at end of file diff --git a/graphrag/index/context.py b/graphrag/index/context.py index e74799bd..48c1cd80 100644 --- a/graphrag/index/context.py +++ b/graphrag/index/context.py @@ -8,7 +8,7 @@ from dataclasses import dataclass as dc_dataclass from dataclasses import field from .cache import PipelineCache -from .storage.typing import PipelineStorage +from .storage.pipeline_storage import PipelineStorage @dc_dataclass diff --git a/graphrag/index/create_pipeline_config.py b/graphrag/index/create_pipeline_config.py index 0632ade2..e2778f3b 100644 --- a/graphrag/index/create_pipeline_config.py +++ b/graphrag/index/create_pipeline_config.py @@ -49,9 +49,7 @@ from graphrag.index.config.workflow import ( PipelineWorkflowReference, ) from graphrag.index.workflows.default_workflows import ( - create_base_documents, create_base_entity_graph, - create_base_extracted_entities, create_base_text_units, create_final_communities, create_final_community_reports, @@ -61,7 +59,6 @@ from graphrag.index.workflows.default_workflows import ( create_final_nodes, create_final_relationships, create_final_text_units, - create_summarized_entities, ) log = logging.getLogger(__name__) @@ -173,17 +170,12 @@ def _document_workflows( ) return [ PipelineWorkflowReference( - name=create_base_documents, + name=create_final_documents, config={ "document_attribute_columns": list( {*(settings.input.document_attribute_columns)} - builtin_document_attributes - ) - }, - ), - PipelineWorkflowReference( - name=create_final_documents, - config={ + ), "document_raw_content_embed": _get_embedding_settings( settings.embeddings, "document_raw_content", @@ -267,10 +259,9 @@ def _graph_workflows( ) return [ PipelineWorkflowReference( - name=create_base_extracted_entities, + name=create_base_entity_graph, config={ "graphml_snapshot": settings.snapshots.graphml, - "raw_entity_snapshot": settings.snapshots.raw_entities, "entity_extract": { **settings.entity_extraction.parallelization.model_dump(), "async_mode": 
settings.entity_extraction.async_mode, @@ -279,12 +270,6 @@ def _graph_workflows( ), "entity_types": settings.entity_extraction.entity_types, }, - }, - ), - PipelineWorkflowReference( - name=create_summarized_entities, - config={ - "graphml_snapshot": settings.snapshots.graphml, "summarize_descriptions": { **settings.summarize_descriptions.parallelization.model_dump(), "async_mode": settings.summarize_descriptions.async_mode, @@ -292,12 +277,6 @@ def _graph_workflows( settings.root_dir, ), }, - }, - ), - PipelineWorkflowReference( - name=create_base_entity_graph, - config={ - "graphml_snapshot": settings.snapshots.graphml, "embed_graph_enabled": settings.embed_graph.enabled, "cluster_graph": { "strategy": settings.cluster_graph.resolved_strategy() diff --git a/graphrag/index/flows/create_base_documents.py b/graphrag/index/flows/create_base_documents.py deleted file mode 100644 index 3f1ba29e..00000000 --- a/graphrag/index/flows/create_base_documents.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Transform base documents by joining them with their text_units and adding optional attributes.""" - -import pandas as pd - - -def create_base_documents( - documents: pd.DataFrame, - text_units: pd.DataFrame, - document_attribute_columns: list[str] | None = None, -) -> pd.DataFrame: - """Transform base documents by joining them with their text_units and adding optional attributes.""" - exploded = ( - text_units.explode("document_ids") - .loc[:, ["id", "document_ids", "text"]] - .rename( - columns={ - "document_ids": "chunk_doc_id", - "id": "chunk_id", - "text": "chunk_text", - } - ) - ) - - joined = exploded.merge( - documents, - left_on="chunk_doc_id", - right_on="id", - how="inner", - copy=False, - ) - - docs_with_text_units = joined.groupby("id", sort=False).agg( - text_units=("chunk_id", list) - ) - - rejoined = docs_with_text_units.merge( - documents, - on="id", - how="right", - copy=False, - ).reset_index(drop=True) - - rejoined.rename(columns={"text": "raw_content"}, inplace=True) - rejoined["id"] = rejoined["id"].astype(str) - - # Convert attribute columns to strings and collapse them into a JSON object - if document_attribute_columns: - # Convert all specified columns to string at once - rejoined[document_attribute_columns] = rejoined[ - document_attribute_columns - ].astype(str) - - # Collapse the document_attribute_columns into a single JSON object column - rejoined["attributes"] = rejoined[document_attribute_columns].to_dict( - orient="records" - ) - - # Drop the original attribute columns after collapsing them - rejoined.drop(columns=document_attribute_columns, inplace=True) - - return rejoined diff --git a/graphrag/index/flows/create_base_entity_graph.py b/graphrag/index/flows/create_base_entity_graph.py index 39880a45..68ae2193 100644 --- a/graphrag/index/flows/create_base_entity_graph.py +++ b/graphrag/index/flows/create_base_entity_graph.py @@ -7,26 +7,76 @@ from typing import Any, cast import pandas as pd from datashaper import ( + AsyncType, VerbCallbacks, ) +from graphrag.index.cache import PipelineCache from graphrag.index.operations.cluster_graph import cluster_graph from graphrag.index.operations.embed_graph import embed_graph +from graphrag.index.operations.extract_entities import extract_entities +from graphrag.index.operations.merge_graphs import merge_graphs +from graphrag.index.operations.snapshot import snapshot +from graphrag.index.operations.snapshot_graphml import snapshot_graphml from 
graphrag.index.operations.snapshot_rows import snapshot_rows +from graphrag.index.operations.summarize_descriptions import ( + summarize_descriptions, +) from graphrag.index.storage import PipelineStorage async def create_base_entity_graph( - entities: pd.DataFrame, + text_units: pd.DataFrame, callbacks: VerbCallbacks, + cache: PipelineCache, storage: PipelineStorage, + text_column: str, + id_column: str, clustering_strategy: dict[str, Any], - embedding_strategy: dict[str, Any] | None, + extraction_strategy: dict[str, Any] | None = None, + extraction_num_threads: int = 4, + extraction_async_mode: AsyncType = AsyncType.AsyncIO, + entity_types: list[str] | None = None, + node_merge_config: dict[str, Any] | None = None, + edge_merge_config: dict[str, Any] | None = None, + summarization_strategy: dict[str, Any] | None = None, + summarization_num_threads: int = 4, + embedding_strategy: dict[str, Any] | None = None, graphml_snapshot_enabled: bool = False, + raw_entity_snapshot_enabled: bool = False, ) -> pd.DataFrame: """All the steps to create the base entity graph.""" + # this returns a graph for each text unit, to be merged later + entities, entity_graphs = await extract_entities( + text_units, + callbacks, + cache, + text_column=text_column, + id_column=id_column, + strategy=extraction_strategy, + async_mode=extraction_async_mode, + entity_types=entity_types, + to="entities", + num_threads=extraction_num_threads, + ) + + merged_graph = merge_graphs( + entity_graphs, + callbacks, + node_operations=node_merge_config, + edge_operations=edge_merge_config, + ) + + summarized = await summarize_descriptions( + merged_graph, + callbacks, + cache, + strategy=summarization_strategy, + num_threads=summarization_num_threads, + ) + clustered = cluster_graph( - entities, + summarized, callbacks, column="entity_graph", strategy=clustering_strategy, @@ -34,15 +84,6 @@ async def create_base_entity_graph( level_to="level", ) - if graphml_snapshot_enabled: - await snapshot_rows( - clustered, - column="clustered_graph", - base_name="clustered_graph", - storage=storage, - formats=[{"format": "text", "extension": "graphml"}], - ) - if embedding_strategy: clustered["embeddings"] = await embed_graph( clustered, @@ -51,16 +92,40 @@ async def create_base_entity_graph( strategy=embedding_strategy, ) - # take second snapshot after embedding - # todo: this could be skipped if embedding isn't performed, other wise it is a copy of the regular graph? 
+ if raw_entity_snapshot_enabled: + await snapshot( + entities, + name="raw_extracted_entities", + storage=storage, + formats=["json"], + ) + if graphml_snapshot_enabled: + await snapshot_graphml( + merged_graph, + name="merged_graph", + storage=storage, + ) + await snapshot_graphml( + summarized, + name="summarized_graph", + storage=storage, + ) await snapshot_rows( clustered, - column="entity_graph", - base_name="embedded_graph", + column="clustered_graph", + base_name="clustered_graph", storage=storage, formats=[{"format": "text", "extension": "graphml"}], ) + if embedding_strategy: + await snapshot_rows( + clustered, + column="entity_graph", + base_name="embedded_graph", + storage=storage, + formats=[{"format": "text", "extension": "graphml"}], + ) final_columns = ["level", "clustered_graph"] if embedding_strategy: diff --git a/graphrag/index/flows/create_base_extracted_entities.py b/graphrag/index/flows/create_base_extracted_entities.py deleted file mode 100644 index bfbf4d23..00000000 --- a/graphrag/index/flows/create_base_extracted_entities.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""All the steps to extract and format covariates.""" - -from typing import Any - -import pandas as pd -from datashaper import ( - AsyncType, - VerbCallbacks, -) - -from graphrag.index.cache import PipelineCache -from graphrag.index.operations.extract_entities import extract_entities -from graphrag.index.operations.merge_graphs import merge_graphs -from graphrag.index.operations.snapshot import snapshot -from graphrag.index.operations.snapshot_rows import snapshot_rows -from graphrag.index.storage import PipelineStorage - - -async def create_base_extracted_entities( - text_units: pd.DataFrame, - callbacks: VerbCallbacks, - cache: PipelineCache, - storage: PipelineStorage, - column: str, - id_column: str, - nodes: dict[str, Any], - edges: dict[str, Any], - extraction_strategy: dict[str, Any] | None, - async_mode: AsyncType = AsyncType.AsyncIO, - entity_types: list[str] | None = None, - graphml_snapshot_enabled: bool = False, - raw_entity_snapshot_enabled: bool = False, - num_threads: int = 4, -) -> pd.DataFrame: - """All the steps to extract and format covariates.""" - entity_graph = await extract_entities( - text_units, - callbacks, - cache, - column=column, - id_column=id_column, - strategy=extraction_strategy, - async_mode=async_mode, - entity_types=entity_types, - to="entities", - graph_to="entity_graph", - num_threads=num_threads, - ) - - if raw_entity_snapshot_enabled: - await snapshot( - entity_graph, - name="raw_extracted_entities", - storage=storage, - formats=["json"], - ) - - merged_graph = merge_graphs( - entity_graph, - callbacks, - column="entity_graph", - to="entity_graph", - nodes=nodes, - edges=edges, - ) - - if graphml_snapshot_enabled: - await snapshot_rows( - merged_graph, - base_name="merged_graph", - column="entity_graph", - storage=storage, - formats=[{"format": "text", "extension": "graphml"}], - ) - - return merged_graph diff --git a/graphrag/index/flows/create_final_documents.py b/graphrag/index/flows/create_final_documents.py index 29504000..29f28e52 100644 --- a/graphrag/index/flows/create_final_documents.py +++ b/graphrag/index/flows/create_final_documents.py @@ -14,20 +14,71 @@ from graphrag.index.operations.embed_text import embed_text async def create_final_documents( documents: pd.DataFrame, + text_units: pd.DataFrame, callbacks: VerbCallbacks, cache: PipelineCache, + document_attribute_columns: 
list[str] | None = None, raw_content_text_embed: dict | None = None, ) -> pd.DataFrame: """All the steps to transform final documents.""" - documents.rename(columns={"text_units": "text_unit_ids"}, inplace=True) + exploded = ( + text_units.explode("document_ids") + .loc[:, ["id", "document_ids", "text"]] + .rename( + columns={ + "document_ids": "chunk_doc_id", + "id": "chunk_id", + "text": "chunk_text", + } + ) + ) + + joined = exploded.merge( + documents, + left_on="chunk_doc_id", + right_on="id", + how="inner", + copy=False, + ) + + docs_with_text_units = joined.groupby("id", sort=False).agg( + text_units=("chunk_id", list) + ) + + rejoined = docs_with_text_units.merge( + documents, + on="id", + how="right", + copy=False, + ).reset_index(drop=True) + + rejoined.rename( + columns={"text": "raw_content", "text_units": "text_unit_ids"}, inplace=True + ) + rejoined["id"] = rejoined["id"].astype(str) + + # Convert attribute columns to strings and collapse them into a JSON object + if document_attribute_columns: + # Convert all specified columns to string at once + rejoined[document_attribute_columns] = rejoined[ + document_attribute_columns + ].astype(str) + + # Collapse the document_attribute_columns into a single JSON object column + rejoined["attributes"] = rejoined[document_attribute_columns].to_dict( + orient="records" + ) + + # Drop the original attribute columns after collapsing them + rejoined.drop(columns=document_attribute_columns, inplace=True) if raw_content_text_embed: - documents["raw_content_embedding"] = await embed_text( - documents, + rejoined["raw_content_embedding"] = await embed_text( + rejoined, callbacks, cache, column="raw_content", strategy=raw_content_text_embed["strategy"], ) - return documents + return rejoined diff --git a/graphrag/index/flows/create_summarized_entities.py b/graphrag/index/flows/create_summarized_entities.py deleted file mode 100644 index a9a5d59a..00000000 --- a/graphrag/index/flows/create_summarized_entities.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -"""All the steps to summarize entities.""" - -from typing import Any - -import pandas as pd -from datashaper import ( - VerbCallbacks, -) - -from graphrag.index.cache import PipelineCache -from graphrag.index.operations.snapshot_rows import snapshot_rows -from graphrag.index.operations.summarize_descriptions import ( - summarize_descriptions, -) -from graphrag.index.storage import PipelineStorage - - -async def create_summarized_entities( - entities: pd.DataFrame, - callbacks: VerbCallbacks, - cache: PipelineCache, - storage: PipelineStorage, - summarization_strategy: dict[str, Any] | None = None, - num_threads: int = 4, - graphml_snapshot_enabled: bool = False, -) -> pd.DataFrame: - """All the steps to summarize entities.""" - summarized = await summarize_descriptions( - entities, - callbacks, - cache, - column="entity_graph", - to="entity_graph", - strategy=summarization_strategy, - num_threads=num_threads, - ) - - if graphml_snapshot_enabled: - await snapshot_rows( - summarized, - column="entity_graph", - base_name="summarized_graph", - storage=storage, - formats=[{"format": "text", "extension": "graphml"}], - ) - - return summarized diff --git a/graphrag/index/operations/cluster_graph.py b/graphrag/index/operations/cluster_graph.py index 731c4b5b..b993789d 100644 --- a/graphrag/index/operations/cluster_graph.py +++ b/graphrag/index/operations/cluster_graph.py @@ -14,7 +14,7 @@ from datashaper import VerbCallbacks, progress_iterable from graspologic.partition import hierarchical_leiden from graphrag.index.graph.utils import stable_largest_connected_component -from graphrag.index.utils import gen_uuid, load_graph +from graphrag.index.utils import gen_uuid Communities = list[tuple[int, str, list[str]]] @@ -33,7 +33,7 @@ log = logging.getLogger(__name__) def cluster_graph( - input: pd.DataFrame, + input: nx.Graph, callbacks: VerbCallbacks, strategy: dict[str, Any], column: str, @@ -41,63 +41,64 @@ def cluster_graph( level_to: str | None = None, ) -> pd.DataFrame: """Apply a hierarchical clustering algorithm to a graph.""" - results = input[column].apply(lambda graph: run_layout(strategy, graph)) + output = pd.DataFrame() + # TODO: for back-compat, downstream expects a graphml string + output[column] = ["\n".join(nx.generate_graphml(input))] + communities = run_layout(strategy, input) community_map_to = "communities" - input[community_map_to] = results + output[community_map_to] = [communities] level_to = level_to or f"{to}_level" - input[level_to] = input.apply( + output[level_to] = output.apply( lambda x: list({level for level, _, _ in x[community_map_to]}), axis=1 ) - input[to] = None + output[to] = None - num_total = len(input) + num_total = len(output) # Create a seed for this run (if not provided) seed = strategy.get("seed", Random().randint(0, 0xFFFFFFFF)) # noqa S311 # Go through each of the rows graph_level_pairs_column: list[list[tuple[int, str]]] = [] - for _, row in progress_iterable(input.iterrows(), callbacks.progress, num_total): + for _, row in progress_iterable(output.iterrows(), callbacks.progress, num_total): levels = row[level_to] graph_level_pairs: list[tuple[int, str]] = [] # For each of the levels, get the graph and add it to the list for level in levels: - graph = "\n".join( + graphml = "\n".join( nx.generate_graphml( apply_clustering( - cast(str, row[column]), + input, cast(Communities, row[community_map_to]), level, seed=seed, ) ) ) - graph_level_pairs.append((level, graph)) + graph_level_pairs.append((level, graphml)) 
graph_level_pairs_column.append(graph_level_pairs) - input[to] = graph_level_pairs_column + output[to] = graph_level_pairs_column # explode the list of (level, graph) pairs into separate rows - input = input.explode(to, ignore_index=True) + output = output.explode(to, ignore_index=True) # split the (level, graph) pairs into separate columns # TODO: There is probably a better way to do this - input[[level_to, to]] = pd.DataFrame(input[to].tolist(), index=input.index) + output[[level_to, to]] = pd.DataFrame(output[to].tolist(), index=output.index) # clean up the community map - input.drop(columns=[community_map_to], inplace=True) - return input + output.drop(columns=[community_map_to], inplace=True) + return output -# TODO: This should support str | nx.Graph as a graphml param def apply_clustering( - graphml: str, communities: Communities, level: int = 0, seed: int | None = None + graph: nx.Graph, communities: Communities, level: int = 0, seed: int | None = None ) -> nx.Graph: - """Apply clustering to a graphml string.""" + """Apply clustering to a graph.""" random = Random(seed) # noqa S311 - graph = nx.parse_graphml(graphml) for community_level, community_id, nodes in communities: if level == community_level: for node in nodes: @@ -121,11 +122,8 @@ def apply_clustering( return graph -def run_layout( - strategy: dict[str, Any], graphml_or_graph: str | nx.Graph -) -> Communities: +def run_layout(strategy: dict[str, Any], graph: nx.Graph) -> Communities: """Run layout method definition.""" - graph = load_graph(graphml_or_graph) if len(graph.nodes) == 0: log.warning("Graph has no nodes") return [] diff --git a/graphrag/index/operations/extract_entities/extract_entities.py b/graphrag/index/operations/extract_entities/extract_entities.py index 77f29dd6..96bec73b 100644 --- a/graphrag/index/operations/extract_entities/extract_entities.py +++ b/graphrag/index/operations/extract_entities/extract_entities.py @@ -7,6 +7,7 @@ import logging from enum import Enum from typing import Any +import networkx as nx import pandas as pd from datashaper import ( AsyncType, @@ -41,15 +42,14 @@ async def extract_entities( input: pd.DataFrame, callbacks: VerbCallbacks, cache: PipelineCache, - column: str, + text_column: str, id_column: str, to: str, strategy: dict[str, Any] | None, - graph_to: str | None = None, async_mode: AsyncType = AsyncType.AsyncIO, entity_types=DEFAULT_ENTITY_TYPES, num_threads: int = 4, -) -> pd.DataFrame: +) -> tuple[pd.DataFrame, list[nx.Graph]]: """ Extract entities from a piece of text. 
@@ -59,7 +59,6 @@ async def extract_entities( column: the_document_text_column_to_extract_entities_from id_column: the_column_with_the_unique_id_for_each_row to: the_column_to_output_the_entities_to - graph_to: the_column_to_output_the_graphml_to strategy: , see strategies section below summarize_descriptions: true | false /* Optional: This will summarize the descriptions of the entities and relationships, default: true */ entity_types: @@ -124,7 +123,7 @@ async def extract_entities( async def run_strategy(row): nonlocal num_started - text = row[column] + text = row[text_column] id = row[id_column] result = await strategy_exec( [Document(text=text, id=id)], @@ -134,7 +133,7 @@ async def extract_entities( strategy_config, ) num_started += 1 - return [result.entities, result.graphml_graph] + return [result.entities, result.graph] results = await derive_from_rows( input, @@ -145,20 +144,18 @@ async def extract_entities( ) to_result = [] - graph_to_result = [] + graphs = [] for result in results: if result: to_result.append(result[0]) - graph_to_result.append(result[1]) + graphs.append(result[1]) else: to_result.append(None) - graph_to_result.append(None) + graphs.append(None) input[to] = to_result - if graph_to is not None: - input[graph_to] = graph_to_result - return input.reset_index(drop=True) + return (input.reset_index(drop=True), graphs) def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStrategy: diff --git a/graphrag/index/operations/extract_entities/strategies/graph_intelligence.py b/graphrag/index/operations/extract_entities/strategies/graph_intelligence.py index 1536df34..072d5bed 100644 --- a/graphrag/index/operations/extract_entities/strategies/graph_intelligence.py +++ b/graphrag/index/operations/extract_entities/strategies/graph_intelligence.py @@ -3,7 +3,6 @@ """A module containing run_graph_intelligence, run_extract_entities and _create_text_splitter methods to run graph intelligence.""" -import networkx as nx from datashaper import VerbCallbacks import graphrag.config.defaults as defs @@ -113,8 +112,7 @@ async def run_extract_entities( if item is not None ] - graph_data = "".join(nx.generate_graphml(graph)) - return EntityExtractionResult(entities, graph_data) + return EntityExtractionResult(entities, graph) def _create_text_splitter( diff --git a/graphrag/index/operations/extract_entities/strategies/nltk.py b/graphrag/index/operations/extract_entities/strategies/nltk.py index 9403c5a5..8f9aefa0 100644 --- a/graphrag/index/operations/extract_entities/strategies/nltk.py +++ b/graphrag/index/operations/extract_entities/strategies/nltk.py @@ -57,5 +57,5 @@ async def run( # noqa RUF029 async is required for interface {"type": entity_type, "name": name} for name, entity_type in entity_map.items() ], - graphml_graph="".join(nx.generate_graphml(graph)), + graph=graph, ) diff --git a/graphrag/index/operations/extract_entities/strategies/typing.py b/graphrag/index/operations/extract_entities/strategies/typing.py index 45d3f1b8..e1c548b0 100644 --- a/graphrag/index/operations/extract_entities/strategies/typing.py +++ b/graphrag/index/operations/extract_entities/strategies/typing.py @@ -7,6 +7,7 @@ from collections.abc import Awaitable, Callable from dataclasses import dataclass from typing import Any +import networkx as nx from datashaper import VerbCallbacks from graphrag.index.cache import PipelineCache @@ -29,7 +30,7 @@ class EntityExtractionResult: """Entity extraction result class definition.""" entities: list[ExtractedEntity] - graphml_graph: str | None 
+ graph: nx.Graph | None EntityExtractStrategy = Callable[ diff --git a/graphrag/index/operations/merge_graphs/merge_graphs.py b/graphrag/index/operations/merge_graphs/merge_graphs.py index ca654e6e..80ab20ef 100644 --- a/graphrag/index/operations/merge_graphs/merge_graphs.py +++ b/graphrag/index/operations/merge_graphs/merge_graphs.py @@ -3,14 +3,11 @@ """A module containing merge_graphs, merge_nodes, merge_edges, merge_attributes, apply_merge_operation and _get_detailed_attribute_merge_operation methods definitions.""" -from typing import Any, cast +from typing import Any import networkx as nx -import pandas as pd from datashaper import VerbCallbacks, progress_iterable -from graphrag.index.utils import load_graph - from .typing import ( BasicMergeOperation, DetailedAttributeMergeOperation, @@ -35,15 +32,13 @@ DEFAULT_CONCAT_SEPARATOR = "," def merge_graphs( - input: pd.DataFrame, + graphs: list[nx.Graph], callbacks: VerbCallbacks, - column: str, - to: str, - nodes: dict[str, Any] = DEFAULT_NODE_OPERATIONS, - edges: dict[str, Any] = DEFAULT_EDGE_OPERATIONS, -) -> pd.DataFrame: + node_operations: dict[str, Any] | None, + edge_operations: dict[str, Any] | None, +) -> nx.Graph: """ - Merge multiple graphs together. The graphs are expected to be in graphml format. The verb outputs a new column containing the merged graph. + Merge multiple graphs together. The graphs are expected to be in nx.Graph format. The verb outputs a new column containing the merged graph. > Note: This will merge all rows into a single graph. @@ -51,8 +46,6 @@ def merge_graphs( ```yaml verb: merge_graph args: - column: clustered_graph # The name of the column containing the graph, should be a graphml graph - to: merged_graph # The name of the column to output the merged graph to nodes: # See node operations section below edges: # See edge operations section below ``` @@ -90,8 +83,8 @@ def merge_graphs( - __average__: This operation takes the mean of the attribute with the last value seen. - __multiply__: This operation multiplies the attribute with the last value seen. """ - output = pd.DataFrame() - + nodes = node_operations or DEFAULT_NODE_OPERATIONS + edges = edge_operations or DEFAULT_EDGE_OPERATIONS node_ops = { attrib: _get_detailed_attribute_merge_operation(value) for attrib, value in nodes.items() @@ -102,15 +95,12 @@ def merge_graphs( } mega_graph = nx.Graph() - num_total = len(input) - for graphml in progress_iterable(input[column], callbacks.progress, num_total): - graph = load_graph(cast(str | nx.Graph, graphml)) + num_total = len(graphs) + for graph in progress_iterable(graphs, callbacks.progress, num_total): merge_nodes(mega_graph, graph, node_ops) merge_edges(mega_graph, graph, edge_ops) - output[to] = ["\n".join(nx.generate_graphml(mega_graph))] - - return output + return mega_graph def merge_nodes( diff --git a/graphrag/index/operations/snapshot_graphml.py b/graphrag/index/operations/snapshot_graphml.py new file mode 100644 index 00000000..07a174fa --- /dev/null +++ b/graphrag/index/operations/snapshot_graphml.py @@ -0,0 +1,18 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing snapshot method definition.""" + +import networkx as nx + +from graphrag.index.storage import PipelineStorage + + +async def snapshot_graphml( + input: str | nx.Graph, + name: str, + storage: PipelineStorage, +) -> None: + """Take an entire snapshot of a graph to standard graphml format.""" + graphml = input if isinstance(input, str) else "\n".join(nx.generate_graphml(input)) + await storage.set(name + ".graphml", graphml) diff --git a/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py b/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py index 40af8dfc..07754473 100644 --- a/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py +++ b/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py @@ -5,10 +5,9 @@ import asyncio import logging -from typing import Any, cast +from typing import Any import networkx as nx -import pandas as pd from datashaper import ( ProgressTicker, VerbCallbacks, @@ -16,10 +15,8 @@ from datashaper import ( ) from graphrag.index.cache import PipelineCache -from graphrag.index.utils import load_graph from .typing import ( - DescriptionSummarizeRow, SummarizationStrategy, SummarizeStrategyType, ) @@ -28,14 +25,12 @@ log = logging.getLogger(__name__) async def summarize_descriptions( - input: pd.DataFrame, + input: nx.Graph, callbacks: VerbCallbacks, cache: PipelineCache, - column: str, - to: str, strategy: dict[str, Any] | None = None, - **kwargs, -) -> pd.DataFrame: + num_threads: int = 4, +) -> nx.Graph: """ Summarize entity and relationship descriptions from an entity graph. @@ -43,25 +38,10 @@ async def summarize_descriptions( To turn this feature ON please set the environment variable `GRAPHRAG_SUMMARIZE_DESCRIPTIONS_ENABLED=True`. 
- ### json - - ```json - { - "verb": "", - "args": { - "column": "the_document_text_column_to_extract_descriptions_from", /* Required: This will be a graphml graph in string form which represents the entities and their relationships */ - "to": "the_column_to_output_the_summarized_descriptions_to", /* Required: This will be a graphml graph in string form which represents the entities and their relationships after being summarized */ - "strategy": {...} , see strategies section below - } - } - ``` - ### yaml ```yaml args: - column: the_document_text_column_to_extract_descriptions_from - to: the_column_to_output_the_summarized_descriptions_to strategy: , see strategies section below ``` @@ -99,9 +79,7 @@ async def summarize_descriptions( ) strategy_config = {**strategy} - async def get_resolved_entities(row, semaphore: asyncio.Semaphore): - graph: nx.Graph = load_graph(cast(str | nx.Graph, getattr(row, column))) - + async def get_resolved_entities(graph: nx.Graph, semaphore: asyncio.Semaphore): ticker_length = len(graph.nodes) + len(graph.edges) ticker = progress_ticker(callbacks.progress, ticker_length) @@ -134,9 +112,7 @@ async def summarize_descriptions( elif isinstance(graph_item, tuple) and graph_item in graph.edges(): graph.edges[graph_item]["description"] = result.description - return DescriptionSummarizeRow( - graph="\n".join(nx.generate_graphml(graph)), - ) + return graph async def do_summarize_descriptions( graph_item: str | tuple[str, str], @@ -155,25 +131,9 @@ async def summarize_descriptions( ticker(1) return results - # Graph is always on row 0, so here a derive from rows does not work - # This iteration will only happen once, but avoids hardcoding a iloc[0] - # Since parallelization is at graph level (nodes and edges), we can't use - # the parallelization of the derive_from_rows - semaphore = asyncio.Semaphore(kwargs.get("num_threads", 4)) + semaphore = asyncio.Semaphore(num_threads) - results = [ - await get_resolved_entities(row, semaphore) for row in input.itertuples() - ] - - to_result = [] - - for result in results: - if result: - to_result.append(result.graph) - else: - to_result.append(None) - input[to] = to_result - return input + return await get_resolved_entities(input, semaphore) def load_strategy(strategy_type: SummarizeStrategyType) -> SummarizationStrategy: diff --git a/graphrag/index/run/profiling.py b/graphrag/index/run/profiling.py index 8640f027..d1a54a66 100644 --- a/graphrag/index/run/profiling.py +++ b/graphrag/index/run/profiling.py @@ -11,7 +11,7 @@ from dataclasses import asdict from datashaper import MemoryProfile, Workflow, WorkflowRunResult from graphrag.index.context import PipelineRunStats -from graphrag.index.storage.typing import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage log = logging.getLogger(__name__) diff --git a/graphrag/index/run/utils.py b/graphrag/index/run/utils.py index 3617b925..0b4e9a80 100644 --- a/graphrag/index/run/utils.py +++ b/graphrag/index/run/utils.py @@ -33,7 +33,7 @@ from graphrag.index.config.storage import ( from graphrag.index.context import PipelineRunContext, PipelineRunStats from graphrag.index.input import load_input from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage -from graphrag.index.storage.typing import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage from graphrag.logging import ProgressReporter log = logging.getLogger(__name__) diff --git a/graphrag/index/run/workflow.py b/graphrag/index/run/workflow.py 
index 3c57d112..aa288262 100644 --- a/graphrag/index/run/workflow.py +++ b/graphrag/index/run/workflow.py @@ -19,7 +19,7 @@ from graphrag.callbacks.progress_workflow_callbacks import ProgressWorkflowCallb from graphrag.index.context import PipelineRunContext from graphrag.index.emit.table_emitter import TableEmitter from graphrag.index.run.profiling import _write_workflow_stats -from graphrag.index.storage.typing import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage from graphrag.index.typing import PipelineRunResult from graphrag.logging import ProgressReporter from graphrag.utils.storage import _load_table_from_storage @@ -48,8 +48,12 @@ async def _emit_workflow_output( ) -> pd.DataFrame: """Emit the workflow output.""" output = cast(pd.DataFrame, workflow.output()) - for emitter in emitters: - await emitter.emit(workflow.name, output) + # only write the final output if it has content + # this is expressly designed to allow us to create + # workflows with side effects that don't produce a formal output to save + if not output.empty: + for emitter in emitters: + await emitter.emit(workflow.name, output) return output diff --git a/graphrag/index/storage/__init__.py b/graphrag/index/storage/__init__.py index 7ca943db..d1025d24 100644 --- a/graphrag/index/storage/__init__.py +++ b/graphrag/index/storage/__init__.py @@ -7,7 +7,7 @@ from .blob_pipeline_storage import BlobPipelineStorage, create_blob_storage from .file_pipeline_storage import FilePipelineStorage from .load_storage import load_storage from .memory_pipeline_storage import MemoryPipelineStorage -from .typing import PipelineStorage +from .pipeline_storage import PipelineStorage __all__ = [ "BlobPipelineStorage", diff --git a/graphrag/index/storage/blob_pipeline_storage.py b/graphrag/index/storage/blob_pipeline_storage.py index 456fe7aa..5494c5a6 100644 --- a/graphrag/index/storage/blob_pipeline_storage.py +++ b/graphrag/index/storage/blob_pipeline_storage.py @@ -15,7 +15,7 @@ from datashaper import Progress from graphrag.logging import ProgressReporter -from .typing import PipelineStorage +from .pipeline_storage import PipelineStorage log = logging.getLogger(__name__) diff --git a/graphrag/index/storage/file_pipeline_storage.py b/graphrag/index/storage/file_pipeline_storage.py index 8e51fedd..ccb06d1f 100644 --- a/graphrag/index/storage/file_pipeline_storage.py +++ b/graphrag/index/storage/file_pipeline_storage.py @@ -18,7 +18,7 @@ from datashaper import Progress from graphrag.logging import ProgressReporter -from .typing import PipelineStorage +from .pipeline_storage import PipelineStorage log = logging.getLogger(__name__) diff --git a/graphrag/index/storage/memory_pipeline_storage.py b/graphrag/index/storage/memory_pipeline_storage.py index 9c796526..ed490983 100644 --- a/graphrag/index/storage/memory_pipeline_storage.py +++ b/graphrag/index/storage/memory_pipeline_storage.py @@ -6,7 +6,7 @@ from typing import Any from .file_pipeline_storage import FilePipelineStorage -from .typing import PipelineStorage +from .pipeline_storage import PipelineStorage class MemoryPipelineStorage(FilePipelineStorage): @@ -34,9 +34,7 @@ class MemoryPipelineStorage(FilePipelineStorage): """ return self._storage.get(key) or await super().get(key, as_bytes, encoding) - async def set( - self, key: str, value: str | bytes | None, encoding: str | None = None - ) -> None: + async def set(self, key: str, value: Any, encoding: str | None = None) -> None: """Set the value for the given key. 
Args: diff --git a/graphrag/index/storage/typing.py b/graphrag/index/storage/pipeline_storage.py similarity index 94% rename from graphrag/index/storage/typing.py rename to graphrag/index/storage/pipeline_storage.py index 6eb727d7..08ccb409 100644 --- a/graphrag/index/storage/typing.py +++ b/graphrag/index/storage/pipeline_storage.py @@ -41,9 +41,7 @@ class PipelineStorage(metaclass=ABCMeta): """ @abstractmethod - async def set( - self, key: str, value: str | bytes | None, encoding: str | None = None - ) -> None: + async def set(self, key: str, value: Any, encoding: str | None = None) -> None: """Set the value for the given key. Args: diff --git a/graphrag/index/update/dataframes.py b/graphrag/index/update/dataframes.py index ee9ccf07..e92319a3 100644 --- a/graphrag/index/update/dataframes.py +++ b/graphrag/index/update/dataframes.py @@ -9,7 +9,7 @@ from dataclasses import dataclass import numpy as np import pandas as pd -from graphrag.index.storage.typing import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage from graphrag.utils.storage import _load_table_from_storage mergeable_outputs = [ diff --git a/graphrag/index/workflows/default_workflows.py b/graphrag/index/workflows/default_workflows.py index 70cb2cf3..896a861a 100644 --- a/graphrag/index/workflows/default_workflows.py +++ b/graphrag/index/workflows/default_workflows.py @@ -7,24 +7,12 @@ from .v1.subflows import * # noqa from .typing import WorkflowDefinitions -from .v1.create_base_documents import ( - build_steps as build_create_base_documents_steps, -) -from .v1.create_base_documents import ( - workflow_name as create_base_documents, -) from .v1.create_base_entity_graph import ( build_steps as build_create_base_entity_graph_steps, ) from .v1.create_base_entity_graph import ( workflow_name as create_base_entity_graph, ) -from .v1.create_base_extracted_entities import ( - build_steps as build_create_base_extracted_entities_steps, -) -from .v1.create_base_extracted_entities import ( - workflow_name as create_base_extracted_entities, -) from .v1.create_base_text_units import ( build_steps as build_create_base_text_units_steps, ) @@ -79,16 +67,9 @@ from .v1.create_final_text_units import ( from .v1.create_final_text_units import ( workflow_name as create_final_text_units, ) -from .v1.create_summarized_entities import ( - build_steps as build_create_summarized_entities_steps, -) -from .v1.create_summarized_entities import ( - workflow_name as create_summarized_entities, -) default_workflows: WorkflowDefinitions = { - create_base_extracted_entities: build_create_base_extracted_entities_steps, create_base_entity_graph: build_create_base_entity_graph_steps, create_base_text_units: build_create_base_text_units_steps, create_final_text_units: build_create_final_text_units, @@ -97,8 +78,6 @@ default_workflows: WorkflowDefinitions = { create_final_relationships: build_create_final_relationships_steps, create_final_documents: build_create_final_documents_steps, create_final_covariates: build_create_final_covariates_steps, - create_base_documents: build_create_base_documents_steps, create_final_entities: build_create_final_entities_steps, create_final_communities: build_create_final_communities_steps, - create_summarized_entities: build_create_summarized_entities_steps, } diff --git a/graphrag/index/workflows/v1/create_base_documents.py b/graphrag/index/workflows/v1/create_base_documents.py deleted file mode 100644 index 1186b1d4..00000000 --- a/graphrag/index/workflows/v1/create_base_documents.py +++ 
/dev/null @@ -1,34 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A module containing build_steps method definition.""" - -from datashaper import DEFAULT_INPUT_NAME - -from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep - -workflow_name = "create_base_documents" - - -def build_steps( - config: PipelineWorkflowConfig, -) -> list[PipelineWorkflowStep]: - """ - Create the documents table. - - ## Dependencies - * `workflow:create_final_text_units` - """ - document_attribute_columns = config.get("document_attribute_columns", []) - return [ - { - "verb": "create_base_documents", - "args": { - "document_attribute_columns": document_attribute_columns, - }, - "input": { - "source": DEFAULT_INPUT_NAME, - "text_units": "workflow:create_final_text_units", - }, - }, - ] diff --git a/graphrag/index/workflows/v1/create_base_entity_graph.py b/graphrag/index/workflows/v1/create_base_entity_graph.py index dadaece5..0b77b438 100644 --- a/graphrag/index/workflows/v1/create_base_entity_graph.py +++ b/graphrag/index/workflows/v1/create_base_entity_graph.py @@ -3,6 +3,10 @@ """A module containing build_steps method definition.""" +from datashaper import ( + AsyncType, +) + from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep workflow_name = "create_base_entity_graph" @@ -15,8 +19,53 @@ def build_steps( Create the base table for the entity graph. ## Dependencies - * `workflow:create_base_extracted_entities` + * `workflow:create_base_text_units` """ + entity_extraction_config = config.get("entity_extract", {}) + text_column = entity_extraction_config.get("text_column", "chunk") + id_column = entity_extraction_config.get("id_column", "chunk_id") + async_mode = entity_extraction_config.get("async_mode", AsyncType.AsyncIO) + extraction_strategy = entity_extraction_config.get("strategy") + extraction_num_threads = entity_extraction_config.get("num_threads", 4) + entity_types = entity_extraction_config.get("entity_types") + + graph_merge_operations_config = config.get( + "graph_merge_operations", + { + "nodes": { + "source_id": { + "operation": "concat", + "delimiter": ", ", + "distinct": True, + }, + "description": ({ + "operation": "concat", + "separator": "\n", + "distinct": False, + }), + }, + "edges": { + "source_id": { + "operation": "concat", + "delimiter": ", ", + "distinct": True, + }, + "description": ({ + "operation": "concat", + "separator": "\n", + "distinct": False, + }), + "weight": "sum", + }, + }, + ) + node_merge_config = graph_merge_operations_config.get("nodes") + edge_merge_config = graph_merge_operations_config.get("edges") + + summarize_descriptions_config = config.get("summarize_descriptions", {}) + summarization_strategy = summarize_descriptions_config.get("strategy") + summarization_num_threads = summarize_descriptions_config.get("num_threads", 4) + clustering_config = config.get( "cluster_graph", {"strategy": {"type": "leiden"}}, @@ -40,17 +89,29 @@ def build_steps( embed_graph_enabled = config.get("embed_graph_enabled", False) or False graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False + raw_entity_snapshot_enabled = config.get("raw_entity_snapshot", False) or False return [ { "verb": "create_base_entity_graph", "args": { + "text_column": text_column, + "id_column": id_column, + "extraction_strategy": extraction_strategy, + "extraction_num_threads": extraction_num_threads, + "extraction_async_mode": async_mode, + "entity_types": entity_types, + 
"node_merge_config": node_merge_config, + "edge_merge_config": edge_merge_config, + "summarization_strategy": summarization_strategy, + "summarization_num_threads": summarization_num_threads, "clustering_strategy": clustering_strategy, - "graphml_snapshot_enabled": graphml_snapshot_enabled, "embedding_strategy": embedding_strategy if embed_graph_enabled else None, + "raw_entity_snapshot_enabled": raw_entity_snapshot_enabled, + "graphml_snapshot_enabled": graphml_snapshot_enabled, }, - "input": ({"source": "workflow:create_summarized_entities"}), + "input": ({"source": "workflow:create_base_text_units"}), }, ] diff --git a/graphrag/index/workflows/v1/create_base_extracted_entities.py b/graphrag/index/workflows/v1/create_base_extracted_entities.py deleted file mode 100644 index e18266b3..00000000 --- a/graphrag/index/workflows/v1/create_base_extracted_entities.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A module containing build_steps method definition.""" - -from datashaper import AsyncType - -from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep - -workflow_name = "create_base_extracted_entities" - - -def build_steps( - config: PipelineWorkflowConfig, -) -> list[PipelineWorkflowStep]: - """ - Create the base table for extracted entities. - - ## Dependencies - * `workflow:create_base_text_units` - """ - entity_extraction_config = config.get("entity_extract", {}) - - column = entity_extraction_config.get("text_column", "chunk") - id_column = entity_extraction_config.get("id_column", "chunk_id") - async_mode = entity_extraction_config.get("async_mode", AsyncType.AsyncIO) - extraction_strategy = entity_extraction_config.get("strategy") - num_threads = entity_extraction_config.get("num_threads", 4) - entity_types = entity_extraction_config.get("entity_types") - - graph_merge_operations_config = config.get( - "graph_merge_operations", - { - "nodes": { - "source_id": { - "operation": "concat", - "delimiter": ", ", - "distinct": True, - }, - "description": ({ - "operation": "concat", - "separator": "\n", - "distinct": False, - }), - }, - "edges": { - "source_id": { - "operation": "concat", - "delimiter": ", ", - "distinct": True, - }, - "description": ({ - "operation": "concat", - "separator": "\n", - "distinct": False, - }), - "weight": "sum", - }, - }, - ) - nodes = graph_merge_operations_config.get("nodes") - edges = graph_merge_operations_config.get("edges") - - graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False - raw_entity_snapshot_enabled = config.get("raw_entity_snapshot", False) or False - - return [ - { - "verb": "create_base_extracted_entities", - "args": { - "column": column, - "id_column": id_column, - "async_mode": async_mode, - "extraction_strategy": extraction_strategy, - "num_threads": num_threads, - "entity_types": entity_types, - "nodes": nodes, - "edges": edges, - "raw_entity_snapshot_enabled": raw_entity_snapshot_enabled, - "graphml_snapshot_enabled": graphml_snapshot_enabled, - }, - "input": {"source": "workflow:create_base_text_units"}, - }, - ] diff --git a/graphrag/index/workflows/v1/create_final_covariates.py b/graphrag/index/workflows/v1/create_final_covariates.py index 1fdab708..0f8191ab 100644 --- a/graphrag/index/workflows/v1/create_final_covariates.py +++ b/graphrag/index/workflows/v1/create_final_covariates.py @@ -20,7 +20,6 @@ def build_steps( ## Dependencies * `workflow:create_base_text_units` - * `workflow:create_base_extracted_entities` """ 
claim_extract_config = config.get("claim_extract", {}) extraction_strategy = claim_extract_config.get("strategy") diff --git a/graphrag/index/workflows/v1/create_final_documents.py b/graphrag/index/workflows/v1/create_final_documents.py index 3fa21e31..33425e4e 100644 --- a/graphrag/index/workflows/v1/create_final_documents.py +++ b/graphrag/index/workflows/v1/create_final_documents.py @@ -3,6 +3,8 @@ """A module containing build_steps method definition.""" +from datashaper import DEFAULT_INPUT_NAME + from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep workflow_name = "create_final_documents" @@ -15,21 +17,26 @@ def build_steps( Create the final documents table. ## Dependencies - * `workflow:create_base_documents` + * `workflow:create_final_text_units` """ base_text_embed = config.get("text_embed", {}) document_raw_content_embed_config = config.get( "document_raw_content_embed", base_text_embed ) skip_raw_content_embedding = config.get("skip_raw_content_embedding", False) + document_attribute_columns = config.get("document_attribute_columns", []) return [ { "verb": "create_final_documents", "args": { + "document_attribute_columns": document_attribute_columns, "raw_content_text_embed": document_raw_content_embed_config if not skip_raw_content_embedding else None, }, - "input": {"source": "workflow:create_base_documents"}, + "input": { + "source": DEFAULT_INPUT_NAME, + "text_units": "workflow:create_final_text_units", + }, }, ] diff --git a/graphrag/index/workflows/v1/create_summarized_entities.py b/graphrag/index/workflows/v1/create_summarized_entities.py deleted file mode 100644 index 53821814..00000000 --- a/graphrag/index/workflows/v1/create_summarized_entities.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A module containing build_steps method definition.""" - -from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep - -workflow_name = "create_summarized_entities" - - -def build_steps( - config: PipelineWorkflowConfig, -) -> list[PipelineWorkflowStep]: - """ - Create the base table for extracted entities. 
- - ## Dependencies - * `workflow:create_base_text_units` - """ - summarize_descriptions_config = config.get("summarize_descriptions", {}) - summarization_strategy = summarize_descriptions_config.get("strategy") - num_threads = summarize_descriptions_config.get("num_threads", 4) - - graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False - - return [ - { - "verb": "create_summarized_entities", - "args": { - "summarization_strategy": summarization_strategy, - "num_threads": num_threads, - "graphml_snapshot_enabled": graphml_snapshot_enabled, - }, - "input": {"source": "workflow:create_base_extracted_entities"}, - }, - ] diff --git a/graphrag/index/workflows/v1/subflows/__init__.py b/graphrag/index/workflows/v1/subflows/__init__.py index 857d20c4..37e20367 100644 --- a/graphrag/index/workflows/v1/subflows/__init__.py +++ b/graphrag/index/workflows/v1/subflows/__init__.py @@ -3,9 +3,7 @@ """The Indexing Engine workflows -> subflows package root.""" -from .create_base_documents import create_base_documents from .create_base_entity_graph import create_base_entity_graph -from .create_base_extracted_entities import create_base_extracted_entities from .create_base_text_units import create_base_text_units from .create_final_communities import create_final_communities from .create_final_community_reports import create_final_community_reports @@ -17,12 +15,9 @@ from .create_final_relationships import ( create_final_relationships, ) from .create_final_text_units import create_final_text_units -from .create_summarized_entities import create_summarized_entities __all__ = [ - "create_base_documents", "create_base_entity_graph", - "create_base_extracted_entities", "create_base_text_units", "create_final_communities", "create_final_community_reports", @@ -32,5 +27,4 @@ __all__ = [ "create_final_nodes", "create_final_relationships", "create_final_text_units", - "create_summarized_entities", ] diff --git a/graphrag/index/workflows/v1/subflows/create_base_documents.py b/graphrag/index/workflows/v1/subflows/create_base_documents.py deleted file mode 100644 index c3e52098..00000000 --- a/graphrag/index/workflows/v1/subflows/create_base_documents.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -"""All the steps to transform base documents.""" - -from typing import cast - -import pandas as pd -from datashaper import ( - Table, - VerbInput, - verb, -) -from datashaper.table_store.types import VerbResult, create_verb_result - -from graphrag.index.flows.create_base_documents import ( - create_base_documents as create_base_documents_flow, -) -from graphrag.index.utils.ds_util import get_required_input_table - - -@verb(name="create_base_documents", treats_input_tables_as_immutable=True) -def create_base_documents( - input: VerbInput, - document_attribute_columns: list[str] | None = None, - **_kwargs: dict, -) -> VerbResult: - """All the steps to transform base documents.""" - source = cast(pd.DataFrame, input.get_input()) - text_units = cast(pd.DataFrame, get_required_input_table(input, "text_units").table) - - output = create_base_documents_flow( - source, text_units, document_attribute_columns=document_attribute_columns - ) - - return create_verb_result( - cast( - Table, - output, - ) - ) diff --git a/graphrag/index/workflows/v1/subflows/create_base_entity_graph.py b/graphrag/index/workflows/v1/subflows/create_base_entity_graph.py index 009da03f..2ab85b01 100644 --- a/graphrag/index/workflows/v1/subflows/create_base_entity_graph.py +++ b/graphrag/index/workflows/v1/subflows/create_base_entity_graph.py @@ -7,6 +7,7 @@ from typing import Any, cast import pandas as pd from datashaper import ( + AsyncType, Table, VerbCallbacks, VerbInput, @@ -14,6 +15,7 @@ from datashaper import ( ) from datashaper.table_store.types import VerbResult, create_verb_result +from graphrag.index.cache import PipelineCache from graphrag.index.flows.create_base_entity_graph import ( create_base_entity_graph as create_base_entity_graph_flow, ) @@ -27,10 +29,22 @@ from graphrag.index.storage import PipelineStorage async def create_base_entity_graph( input: VerbInput, callbacks: VerbCallbacks, + cache: PipelineCache, storage: PipelineStorage, + text_column: str, + id_column: str, clustering_strategy: dict[str, Any], - embedding_strategy: dict[str, Any] | None, + extraction_strategy: dict[str, Any] | None, + extraction_num_threads: int = 4, + extraction_async_mode: AsyncType = AsyncType.AsyncIO, + entity_types: list[str] | None = None, + node_merge_config: dict[str, Any] | None = None, + edge_merge_config: dict[str, Any] | None = None, + summarization_strategy: dict[str, Any] | None = None, + summarization_num_threads: int = 4, + embedding_strategy: dict[str, Any] | None = None, graphml_snapshot_enabled: bool = False, + raw_entity_snapshot_enabled: bool = False, **_kwargs: dict, ) -> VerbResult: """All the steps to create the base entity graph.""" @@ -39,10 +53,22 @@ async def create_base_entity_graph( output = await create_base_entity_graph_flow( source, callbacks, + cache, storage, - clustering_strategy, - embedding_strategy, + text_column, + id_column, + clustering_strategy=clustering_strategy, + extraction_strategy=extraction_strategy, + extraction_num_threads=extraction_num_threads, + extraction_async_mode=extraction_async_mode, + entity_types=entity_types, + node_merge_config=node_merge_config, + edge_merge_config=edge_merge_config, + summarization_strategy=summarization_strategy, + summarization_num_threads=summarization_num_threads, + embedding_strategy=embedding_strategy, graphml_snapshot_enabled=graphml_snapshot_enabled, + raw_entity_snapshot_enabled=raw_entity_snapshot_enabled, ) return create_verb_result(cast(Table, output)) diff --git 
a/graphrag/index/workflows/v1/subflows/create_base_extracted_entities.py b/graphrag/index/workflows/v1/subflows/create_base_extracted_entities.py deleted file mode 100644 index 34660e01..00000000 --- a/graphrag/index/workflows/v1/subflows/create_base_extracted_entities.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""All the steps to extract and format base entities.""" - -from typing import Any, cast - -import pandas as pd -from datashaper import ( - AsyncType, - Table, - VerbCallbacks, - VerbInput, - verb, -) -from datashaper.table_store.types import VerbResult, create_verb_result - -from graphrag.index.cache import PipelineCache -from graphrag.index.flows.create_base_extracted_entities import ( - create_base_extracted_entities as create_base_extracted_entities_flow, -) -from graphrag.index.storage import PipelineStorage - - -@verb(name="create_base_extracted_entities", treats_input_tables_as_immutable=True) -async def create_base_extracted_entities( - input: VerbInput, - callbacks: VerbCallbacks, - cache: PipelineCache, - storage: PipelineStorage, - column: str, - id_column: str, - nodes: dict[str, Any], - edges: dict[str, Any], - extraction_strategy: dict[str, Any] | None, - async_mode: AsyncType = AsyncType.AsyncIO, - entity_types: list[str] | None = None, - num_threads: int = 4, - graphml_snapshot_enabled: bool = False, - raw_entity_snapshot_enabled: bool = False, - **_kwargs: dict, -) -> VerbResult: - """All the steps to extract and format base entities.""" - source = cast(pd.DataFrame, input.get_input()) - - output = await create_base_extracted_entities_flow( - source, - callbacks, - cache, - storage, - column, - id_column, - nodes, - edges, - extraction_strategy, - async_mode=async_mode, - entity_types=entity_types, - graphml_snapshot_enabled=graphml_snapshot_enabled, - raw_entity_snapshot_enabled=raw_entity_snapshot_enabled, - num_threads=num_threads, - ) - - return create_verb_result(cast(Table, output)) diff --git a/graphrag/index/workflows/v1/subflows/create_final_documents.py b/graphrag/index/workflows/v1/subflows/create_final_documents.py index bc883552..191bbd9e 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_documents.py +++ b/graphrag/index/workflows/v1/subflows/create_final_documents.py @@ -18,6 +18,7 @@ from graphrag.index.cache import PipelineCache from graphrag.index.flows.create_final_documents import ( create_final_documents as create_final_documents_flow, ) +from graphrag.index.utils.ds_util import get_required_input_table @verb( @@ -28,16 +29,20 @@ async def create_final_documents( input: VerbInput, callbacks: VerbCallbacks, cache: PipelineCache, + document_attribute_columns: list[str] | None = None, raw_content_text_embed: dict | None = None, **_kwargs: dict, ) -> VerbResult: """All the steps to transform final documents.""" source = cast(pd.DataFrame, input.get_input()) + text_units = cast(pd.DataFrame, get_required_input_table(input, "text_units").table) output = await create_final_documents_flow( source, + text_units, callbacks, cache, + document_attribute_columns=document_attribute_columns, raw_content_text_embed=raw_content_text_embed, ) diff --git a/graphrag/index/workflows/v1/subflows/create_summarized_entities.py b/graphrag/index/workflows/v1/subflows/create_summarized_entities.py deleted file mode 100644 index 4ac2dd5b..00000000 --- a/graphrag/index/workflows/v1/subflows/create_summarized_entities.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright (c) 2024 Microsoft 
Corporation. -# Licensed under the MIT License - -"""All the steps to summarize entities.""" - -from typing import Any, cast - -import pandas as pd -from datashaper import ( - Table, - VerbCallbacks, - VerbInput, - verb, -) -from datashaper.table_store.types import VerbResult, create_verb_result - -from graphrag.index.cache import PipelineCache -from graphrag.index.flows.create_summarized_entities import ( - create_summarized_entities as create_summarized_entities_flow, -) -from graphrag.index.storage import PipelineStorage - - -@verb( - name="create_summarized_entities", - treats_input_tables_as_immutable=True, -) -async def create_summarized_entities( - input: VerbInput, - callbacks: VerbCallbacks, - cache: PipelineCache, - storage: PipelineStorage, - summarization_strategy: dict[str, Any] | None = None, - num_threads: int = 4, - graphml_snapshot_enabled: bool = False, - **_kwargs: dict, -) -> VerbResult: - """All the steps to summarize entities.""" - source = cast(pd.DataFrame, input.get_input()) - - output = await create_summarized_entities_flow( - source, - callbacks, - cache, - storage, - summarization_strategy, - num_threads=num_threads, - graphml_snapshot_enabled=graphml_snapshot_enabled, - ) - - return create_verb_result(cast(Table, output)) diff --git a/graphrag/utils/storage.py b/graphrag/utils/storage.py index 087b5a29..26072bda 100644 --- a/graphrag/utils/storage.py +++ b/graphrag/utils/storage.py @@ -14,7 +14,7 @@ from graphrag.index.config.storage import ( PipelineStorageConfigTypes, ) from graphrag.index.storage import load_storage -from graphrag.index.storage.typing import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage log = logging.getLogger(__name__) diff --git a/tests/fixtures/min-csv/config.json b/tests/fixtures/min-csv/config.json index 73a46f0d..53986d99 100644 --- a/tests/fixtures/min-csv/config.json +++ b/tests/fixtures/min-csv/config.json @@ -10,29 +10,13 @@ "subworkflows": 1, "max_runtime": 10 }, - "create_base_extracted_entities": { - "row_range": [ - 1, - 2000 - ], - "subworkflows": 1, - "max_runtime": 300 - }, - "create_summarized_entities": { - "row_range": [ - 1, - 2000 - ], - "subworkflows": 1, - "max_runtime": 300 - }, "create_base_entity_graph": { "row_range": [ 1, 2000 ], "subworkflows": 1, - "max_runtime": 10 + "max_runtime": 300 }, "create_final_entities": { "row_range": [ @@ -108,14 +92,6 @@ "subworkflows": 1, "max_runtime": 100 }, - "create_base_documents": { - "row_range": [ - 1, - 2000 - ], - "subworkflows": 1, - "max_runtime": 10 - }, "create_final_documents": { "row_range": [ 1, diff --git a/tests/fixtures/text/config.json b/tests/fixtures/text/config.json index 473e7d8b..af1eba3f 100644 --- a/tests/fixtures/text/config.json +++ b/tests/fixtures/text/config.json @@ -10,14 +10,6 @@ "subworkflows": 1, "max_runtime": 10 }, - "create_base_extracted_entities": { - "row_range": [ - 1, - 2000 - ], - "subworkflows": 1, - "max_runtime": 300 - }, "create_final_covariates": { "row_range": [ 1, @@ -35,21 +27,13 @@ "subworkflows": 1, "max_runtime": 300 }, - "create_summarized_entities": { - "row_range": [ - 1, - 2000 - ], - "subworkflows": 1, - "max_runtime": 300 - }, "create_base_entity_graph": { "row_range": [ 1, 2000 ], "subworkflows": 1, - "max_runtime": 10 + "max_runtime": 300 }, "create_final_entities": { "row_range": [ @@ -125,14 +109,6 @@ "subworkflows": 1, "max_runtime": 100 }, - "create_base_documents": { - "row_range": [ - 1, - 2000 - ], - "subworkflows": 1, - "max_runtime": 10 - }, "create_final_documents": { 
"row_range": [ 1, diff --git a/tests/integration/_pipeline/megapipeline.yml b/tests/integration/_pipeline/megapipeline.yml index a6004b9b..fc84d6f1 100644 --- a/tests/integration/_pipeline/megapipeline.yml +++ b/tests/integration/_pipeline/megapipeline.yml @@ -19,9 +19,10 @@ workflows: # Just lump everything together chunk_by: [] - - name: create_base_extracted_entities + - name: create_base_entity_graph config: graphml_snapshot: True + embed_graph_enabled: True entity_extract: strategy: type: graph_intelligence @@ -37,9 +38,6 @@ workflows: ("relationship"<|>COMPANY_A<|>COMPANY_B<|>Company_A and Company_B are related because Company_A is 100% owned by Company_B and the two companies also share the same address)<|>2) ## ("relationship"<|>COMPANY_A<|>PERSON_C<|>Company_A and Person_C are related because Person_C is director of Company_A<|>1))' - - - name: create_summarized_entities - config: summarize_descriptions: strategy: type: graph_intelligence @@ -47,11 +45,6 @@ workflows: type: static_response responses: - This is a MOCK response for the LLM. It is summarized! - - - name: create_base_entity_graph - config: - graphml_snapshot: True - embed_graph_enabled: True cluster_graph: strategy: type: leiden @@ -59,8 +52,6 @@ workflows: - name: create_final_nodes - - name: create_base_documents - - name: create_final_communities - name: create_final_text_units config: diff --git a/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py b/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py index 31a83a26..d72f10b3 100644 --- a/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py +++ b/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py @@ -2,8 +2,6 @@ # Licensed under the MIT License import unittest -import networkx as nx - from graphrag.index.operations.extract_entities.strategies.graph_intelligence import ( run_extract_entities, ) @@ -119,8 +117,8 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase): # self.assertItemsEqual isn't available yet, or I am just silly # so we sort the lists and compare them - assert results.graphml_graph is not None, "No graphml graph returned!" - graph = nx.parse_graphml(results.graphml_graph) # type: ignore + graph = results.graph + assert graph is not None, "No graph returned!" # convert to strings for more visual comparison edges_str = sorted([f"{edge[0]} -> {edge[1]}" for edge in graph.edges]) @@ -162,8 +160,8 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase): ), ) - assert results.graphml_graph is not None, "No graphml graph returned!" - graph = nx.parse_graphml(results.graphml_graph) # type: ignore + graph = results.graph # type: ignore + assert graph is not None, "No graph returned!" 
# TODO: The edges might come back in any order, but we're assuming they're coming # back in the order that we passed in the docs, that might not be true @@ -173,9 +171,11 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase): assert ( graph.nodes["TEST_ENTITY_2"].get("source_id") == "1" ) # TEST_ENTITY_2 should be in just 1 - assert sorted( - graph.nodes["TEST_ENTITY_1"].get("source_id").split(",") - ) == sorted(["1", "2"]) # TEST_ENTITY_1 should be 1 and 2 + ids_str = graph.nodes["TEST_ENTITY_1"].get("source_id") or "" + assert sorted(ids_str.split(",")) == sorted([ + "1", + "2", + ]) # TEST_ENTITY_1 should be 1 and 2 async def test_run_extract_entities_multiple_documents_correct_edge_source_ids_mapped( self, @@ -210,8 +210,8 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase): ), ) - assert results.graphml_graph is not None, "No graphml graph returned!" - graph = nx.parse_graphml(results.graphml_graph) # type: ignore + graph = results.graph # type: ignore + assert graph is not None, "No graph returned!" edges = list(graph.edges(data=True)) # should only have 2 edges diff --git a/tests/unit/indexing/workflows/test_emit.py b/tests/unit/indexing/workflows/test_emit.py new file mode 100644 index 00000000..5c16f66c --- /dev/null +++ b/tests/unit/indexing/workflows/test_emit.py @@ -0,0 +1,124 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +from typing import Any, cast + +import pandas as pd +from datashaper import ( + Table, + VerbInput, + VerbResult, + create_verb_result, +) + +from graphrag.index.config import PipelineWorkflowReference +from graphrag.index.run import run_pipeline +from graphrag.index.storage import MemoryPipelineStorage, PipelineStorage + + +async def mock_verb( + input: VerbInput, storage: PipelineStorage, **_kwargs +) -> VerbResult: + source = cast(pd.DataFrame, input.get_input()) + + output = source[["id"]] + + await storage.set("mock_write", source[["id"]]) + + return create_verb_result( + cast( + Table, + output, + ) + ) + + +async def mock_no_return_verb( + input: VerbInput, storage: PipelineStorage, **_kwargs +) -> VerbResult: + source = cast(pd.DataFrame, input.get_input()) + + # write some outputs to storage independent of the return + await storage.set("empty_write", source[["name"]]) + + return create_verb_result( + cast( + Table, + pd.DataFrame(), + ) + ) + + +async def test_normal_result_emits_parquet(): + mock_verbs: Any = {"mock_verb": mock_verb} + mock_workflows: Any = { + "mock_workflow": lambda _x: [ + { + "verb": "mock_verb", + "args": { + "column": "test", + }, + } + ] + } + workflows = [ + PipelineWorkflowReference( + name="mock_workflow", + config=None, + ) + ] + dataset = pd.DataFrame({"id": [1, 2, 3], "name": ["a", "b", "c"]}) + storage = MemoryPipelineStorage() + pipeline_result = [ + gen + async for gen in run_pipeline( + workflows, + dataset, + storage=storage, + additional_workflows=mock_workflows, + additional_verbs=mock_verbs, + ) + ] + + assert len(pipeline_result) == 1 + assert ( + storage.keys() == ["stats.json", "mock_write", "mock_workflow.parquet"] + ), "Mock workflow output should be written to storage by the emitter when there is a non-empty data frame" + + +async def test_empty_result_does_not_emit_parquet(): + mock_verbs: Any = {"mock_no_return_verb": mock_no_return_verb} + mock_workflows: Any = { + "mock_workflow": lambda _x: [ + { + "verb": "mock_no_return_verb", + "args": { + "column": "test", + }, + } + ] + } + workflows = [ + PipelineWorkflowReference( + name="mock_workflow", + 
config=None, + ) + ] + dataset = pd.DataFrame({"id": [1, 2, 3], "name": ["a", "b", "c"]}) + storage = MemoryPipelineStorage() + pipeline_result = [ + gen + async for gen in run_pipeline( + workflows, + dataset, + storage=storage, + additional_workflows=mock_workflows, + additional_verbs=mock_verbs, + ) + ] + + assert len(pipeline_result) == 1 + assert storage.keys() == [ + "stats.json", + "empty_write", + ], "Mock workflow output should not be written to storage by the emitter" diff --git a/tests/verbs/data/create_base_documents.parquet b/tests/verbs/data/create_base_documents.parquet deleted file mode 100644 index db873e66..00000000 Binary files a/tests/verbs/data/create_base_documents.parquet and /dev/null differ diff --git a/tests/verbs/data/create_base_extracted_entities.parquet b/tests/verbs/data/create_base_extracted_entities.parquet deleted file mode 100644 index d7ec39be..00000000 Binary files a/tests/verbs/data/create_base_extracted_entities.parquet and /dev/null differ diff --git a/tests/verbs/data/create_summarized_entities.parquet b/tests/verbs/data/create_summarized_entities.parquet deleted file mode 100644 index 80e377c4..00000000 Binary files a/tests/verbs/data/create_summarized_entities.parquet and /dev/null differ diff --git a/tests/verbs/test_create_base_documents.py b/tests/verbs/test_create_base_documents.py deleted file mode 100644 index 1e182e26..00000000 --- a/tests/verbs/test_create_base_documents.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -from graphrag.index.workflows.v1.create_base_documents import ( - build_steps, - workflow_name, -) - -from .util import ( - compare_outputs, - get_config_for_workflow, - get_workflow_output, - load_expected, - load_input_tables, -) - - -async def test_create_base_documents(): - input_tables = load_input_tables(["workflow:create_final_text_units"]) - expected = load_expected(workflow_name) - - config = get_config_for_workflow(workflow_name) - - steps = build_steps(config) - - actual = await get_workflow_output( - input_tables, - { - "steps": steps, - }, - ) - - compare_outputs(actual, expected) - - -async def test_create_base_documents_with_attribute_columns(): - input_tables = load_input_tables(["workflow:create_final_text_units"]) - expected = load_expected(workflow_name) - - config = get_config_for_workflow(workflow_name) - - config["document_attribute_columns"] = ["title"] - - steps = build_steps(config) - - actual = await get_workflow_output( - input_tables, - { - "steps": steps, - }, - ) - - # we should have dropped "title" and added "attributes" - # our test dataframe does not have attributes, so we'll assert without it - # and separately confirm it is in the output - compare_outputs(actual, expected, columns=["id", "text_units", "raw_content"]) - assert len(actual.columns) == 4 - assert "attributes" in actual.columns diff --git a/tests/verbs/test_create_base_entity_graph.py b/tests/verbs/test_create_base_entity_graph.py index 0294ff45..86868585 100644 --- a/tests/verbs/test_create_base_entity_graph.py +++ b/tests/verbs/test_create_base_entity_graph.py @@ -2,7 +2,9 @@ # Licensed under the MIT License import networkx as nx +import pytest +from graphrag.config.enums import LLMType from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage from graphrag.index.workflows.v1.create_base_entity_graph import ( build_steps, @@ -16,16 +18,48 @@ from .util import ( load_input_tables, ) +MOCK_LLM_ENTITY_RESPONSES = [ + """ + 
("entity"<|>COMPANY_A<|>COMPANY<|>Company_A is a test company) + ## + ("entity"<|>COMPANY_B<|>COMPANY<|>Company_B owns Company_A and also shares an address with Company_A) + ## + ("entity"<|>PERSON_C<|>PERSON<|>Person_C is director of Company_A) + ## + ("relationship"<|>COMPANY_A<|>COMPANY_B<|>Company_A and Company_B are related because Company_A is 100% owned by Company_B and the two companies also share the same address)<|>2) + ## + ("relationship"<|>COMPANY_A<|>PERSON_C<|>Company_A and Person_C are related because Person_C is director of Company_A<|>1)) + """.strip() +] + +MOCK_LLM_ENTITY_CONFIG = { + "type": LLMType.StaticResponse, + "responses": MOCK_LLM_ENTITY_RESPONSES, +} + +MOCK_LLM_SUMMARIZATION_RESPONSES = [ + """ + This is a MOCK response for the LLM. It is summarized! + """.strip() +] + +MOCK_LLM_SUMMARIZATION_CONFIG = { + "type": LLMType.StaticResponse, + "responses": MOCK_LLM_SUMMARIZATION_RESPONSES, +} + async def test_create_base_entity_graph(): input_tables = load_input_tables([ - "workflow:create_summarized_entities", + "workflow:create_base_text_units", ]) expected = load_expected(workflow_name) storage = MemoryPipelineStorage() config = get_config_for_workflow(workflow_name) + config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG + config["summarize_descriptions"]["strategy"]["llm"] = MOCK_LLM_SUMMARIZATION_CONFIG steps = build_steps(config) @@ -37,34 +71,36 @@ async def test_create_base_entity_graph(): storage=storage, ) - # the serialization of the graph may differ so we can't assert the dataframes directly - assert actual.shape == expected.shape, "Graph dataframe shapes differ" - + assert len(actual.columns) == len( + expected.columns + ), "Graph dataframe columns differ" # let's parse a sample of the raw graphml actual_graphml_0 = actual["clustered_graph"][:1][0] actual_graph_0 = nx.parse_graphml(actual_graphml_0) - expected_graphml_0 = expected["clustered_graph"][:1][0] - expected_graph_0 = nx.parse_graphml(expected_graphml_0) + assert actual_graph_0.number_of_nodes() == 3 + assert actual_graph_0.number_of_edges() == 2 - assert ( - actual_graph_0.number_of_nodes() == expected_graph_0.number_of_nodes() - ), "Graphml node count differs" - assert ( - actual_graph_0.number_of_edges() == expected_graph_0.number_of_edges() - ), "Graphml edge count differs" + # TODO: with the combined verb we can't force summarization + # this is because the mock responses always result in a single description, which is returned verbatim rather than summarized + # we need to update the mocking to provide somewhat unique graphs so a true merge happens + # the assertion should grab a node and ensure the description matches the mock description, not the original as we are doing below + nodes = list(actual_graph_0.nodes(data=True)) + assert nodes[0][1]["description"] == "Company_A is a test company" assert len(storage.keys()) == 0, "Storage should be empty" async def test_create_base_entity_graph_with_embeddings(): input_tables = load_input_tables([ - "workflow:create_summarized_entities", + "workflow:create_base_text_units", ]) expected = load_expected(workflow_name) config = get_config_for_workflow(workflow_name) + config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG + config["summarize_descriptions"]["strategy"]["llm"] = MOCK_LLM_SUMMARIZATION_CONFIG config["embed_graph_enabled"] = True steps = build_steps(config) @@ -84,19 +120,22 @@ async def test_create_base_entity_graph_with_embeddings(): async def test_create_base_entity_graph_with_snapshots(): 
input_tables = load_input_tables([ - "workflow:create_summarized_entities", + "workflow:create_base_text_units", ]) - expected = load_expected(workflow_name) storage = MemoryPipelineStorage() config = get_config_for_workflow(workflow_name) + config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG + config["summarize_descriptions"]["strategy"]["llm"] = MOCK_LLM_SUMMARIZATION_CONFIG + config["raw_entity_snapshot"] = True config["graphml_snapshot"] = True + config["embed_graph_enabled"] = True # need this on in order to see the snapshot steps = build_steps(config) - actual = await get_workflow_output( + await get_workflow_output( input_tables, { "steps": steps, @@ -104,15 +143,31 @@ async def test_create_base_entity_graph_with_snapshots(): storage=storage, ) - assert actual.shape == expected.shape, "Graph dataframe shapes differ" - assert storage.keys() == [ - "clustered_graph.0.graphml", - "clustered_graph.1.graphml", - "clustered_graph.2.graphml", - "clustered_graph.3.graphml", - "embedded_graph.0.graphml", - "embedded_graph.1.graphml", - "embedded_graph.2.graphml", - "embedded_graph.3.graphml", + "raw_extracted_entities.json", + "merged_graph.graphml", + "summarized_graph.graphml", + "clustered_graph.graphml", + "embedded_graph.graphml", ], "Graph snapshot keys differ" + + +async def test_create_base_entity_graph_missing_llm_throws(): + input_tables = load_input_tables([ + "workflow:create_base_text_units", + ]) + + config = get_config_for_workflow(workflow_name) + + config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG + del config["summarize_descriptions"]["strategy"]["llm"] + + steps = build_steps(config) + + with pytest.raises(ValueError): # noqa PT011 + await get_workflow_output( + input_tables, + { + "steps": steps, + }, + ) diff --git a/tests/verbs/test_create_base_extracted_entities.py b/tests/verbs/test_create_base_extracted_entities.py deleted file mode 100644 index 57ca6003..00000000 --- a/tests/verbs/test_create_base_extracted_entities.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -import networkx as nx -import pytest -from datashaper.errors import VerbParallelizationError - -from graphrag.config.enums import LLMType -from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage -from graphrag.index.workflows.v1.create_base_extracted_entities import ( - build_steps, - workflow_name, -) - -from .util import ( - get_config_for_workflow, - get_workflow_output, - load_expected, - load_input_tables, -) - -MOCK_LLM_RESPONSES = [ - """ - ("entity"<|>COMPANY_A<|>COMPANY<|>Company_A is a test company) - ## - ("entity"<|>COMPANY_B<|>COMPANY<|>Company_B owns Company_A and also shares an address with Company_A) - ## - ("entity"<|>PERSON_C<|>PERSON<|>Person_C is director of Company_A) - ## - ("relationship"<|>COMPANY_A<|>COMPANY_B<|>Company_A and Company_B are related because Company_A is 100% owned by Company_B and the two companies also share the same address)<|>2) - ## - ("relationship"<|>COMPANY_A<|>PERSON_C<|>Company_A and Person_C are related because Person_C is director of Company_A<|>1)) - """.strip() -] - -MOCK_LLM_CONFIG = { - "type": LLMType.StaticResponse, - "responses": MOCK_LLM_RESPONSES, -} - - -async def test_create_base_extracted_entities(): - input_tables = load_input_tables(["workflow:create_base_text_units"]) - expected = load_expected(workflow_name) - - storage = MemoryPipelineStorage() - - config = get_config_for_workflow(workflow_name) - - config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_CONFIG - - steps = build_steps(config) - - actual = await get_workflow_output( - input_tables, - { - "steps": steps, - }, - storage=storage, - ) - - # let's parse a sample of the raw graphml - actual_graphml_0 = actual["entity_graph"][:1][0] - actual_graph_0 = nx.parse_graphml(actual_graphml_0) - - assert actual_graph_0.number_of_nodes() == 3 - assert actual_graph_0.number_of_edges() == 2 - - assert actual.columns == expected.columns - - assert len(storage.keys()) == 0, "Storage should be empty" - - -async def test_create_base_extracted_entities_with_snapshots(): - input_tables = load_input_tables(["workflow:create_base_text_units"]) - expected = load_expected(workflow_name) - - storage = MemoryPipelineStorage() - - config = get_config_for_workflow(workflow_name) - - config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_CONFIG - config["raw_entity_snapshot"] = True - config["graphml_snapshot"] = True - - steps = build_steps(config) - - actual = await get_workflow_output( - input_tables, - { - "steps": steps, - }, - storage=storage, - ) - - print(storage.keys()) - - assert actual.columns == expected.columns - - assert storage.keys() == ["raw_extracted_entities.json", "merged_graph.graphml"] - - -async def test_create_base_extracted_entities_missing_llm_throws(): - input_tables = load_input_tables(["workflow:create_base_text_units"]) - - config = get_config_for_workflow(workflow_name) - - del config["entity_extract"]["strategy"]["llm"] - - steps = build_steps(config) - - with pytest.raises(VerbParallelizationError): - await get_workflow_output( - input_tables, - { - "steps": steps, - }, - ) diff --git a/tests/verbs/test_create_final_documents.py b/tests/verbs/test_create_final_documents.py index c25b69ad..ac42fb7e 100644 --- a/tests/verbs/test_create_final_documents.py +++ b/tests/verbs/test_create_final_documents.py @@ -17,7 +17,7 @@ from .util import ( async def test_create_final_documents(): input_tables = load_input_tables([ - "workflow:create_base_documents", + "workflow:create_final_text_units", ]) expected 
= load_expected(workflow_name) @@ -39,7 +39,7 @@ async def test_create_final_documents(): async def test_create_final_documents_with_embeddings(): input_tables = load_input_tables([ - "workflow:create_base_documents", + "workflow:create_final_text_units", ]) expected = load_expected(workflow_name) @@ -63,3 +63,28 @@ async def test_create_final_documents_with_embeddings(): assert len(actual.columns) == len(expected.columns) + 1 # the mock impl returns an array of 3 floats for each embedding assert len(actual["raw_content_embedding"][:1][0]) == 3 + + +async def test_create_final_documents_with_attribute_columns(): + input_tables = load_input_tables(["workflow:create_final_text_units"]) + expected = load_expected(workflow_name) + + config = get_config_for_workflow(workflow_name) + + config["document_attribute_columns"] = ["title"] + + steps = build_steps(config) + + actual = await get_workflow_output( + input_tables, + { + "steps": steps, + }, + ) + + # we should have dropped "title" and added "attributes" + # our test dataframe does not have attributes, so we'll assert without it + # and separately confirm it is in the output + compare_outputs(actual, expected, columns=["id", "text_unit_ids", "raw_content"]) + assert len(actual.columns) == 4 + assert "attributes" in actual.columns diff --git a/tests/verbs/test_create_summarized_entities.py b/tests/verbs/test_create_summarized_entities.py deleted file mode 100644 index c36c9b53..00000000 --- a/tests/verbs/test_create_summarized_entities.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -import networkx as nx -import pytest - -from graphrag.config.enums import LLMType -from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage -from graphrag.index.workflows.v1.create_summarized_entities import ( - build_steps, - workflow_name, -) - -from .util import ( - get_config_for_workflow, - get_workflow_output, - load_expected, - load_input_tables, -) - -MOCK_LLM_RESPONSES = [ - """ - This is a MOCK response for the LLM. It is summarized! - """.strip() -] - -MOCK_LLM_CONFIG = { - "type": LLMType.StaticResponse, - "responses": MOCK_LLM_RESPONSES, -} - - -async def test_create_summarized_entities(): - input_tables = load_input_tables([ - "workflow:create_base_extracted_entities", - ]) - expected = load_expected(workflow_name) - - storage = MemoryPipelineStorage() - - config = get_config_for_workflow(workflow_name) - - config["summarize_descriptions"]["strategy"]["llm"] = MOCK_LLM_CONFIG - - steps = build_steps(config) - - actual = await get_workflow_output( - input_tables, - { - "steps": steps, - }, - storage=storage, - ) - - # the serialization of the graph may differ so we can't assert the dataframes directly - assert actual.shape == expected.shape, "Graph dataframe shapes differ" - - # let's parse a sample of the raw graphml - actual_graphml_0 = actual["entity_graph"][:1][0] - actual_graph_0 = nx.parse_graphml(actual_graphml_0) - - expected_graphml_0 = expected["entity_graph"][:1][0] - expected_graph_0 = nx.parse_graphml(expected_graphml_0) - - assert ( - actual_graph_0.number_of_nodes() == expected_graph_0.number_of_nodes() - ), "Graphml node count differs" - assert ( - actual_graph_0.number_of_edges() == expected_graph_0.number_of_edges() - ), "Graphml edge count differs" - - # ensure the mock summary was injected to the nodes - nodes = list(actual_graph_0.nodes(data=True)) - assert ( - nodes[0][1]["description"] - == "This is a MOCK response for the LLM. 
It is summarized!" - ) - - assert len(storage.keys()) == 0, "Storage should be empty" - - -async def test_create_summarized_entities_with_snapshots(): - input_tables = load_input_tables([ - "workflow:create_base_extracted_entities", - ]) - expected = load_expected(workflow_name) - - storage = MemoryPipelineStorage() - - config = get_config_for_workflow(workflow_name) - - config["summarize_descriptions"]["strategy"]["llm"] = MOCK_LLM_CONFIG - config["graphml_snapshot"] = True - - steps = build_steps(config) - - actual = await get_workflow_output( - input_tables, - { - "steps": steps, - }, - storage=storage, - ) - - assert actual.shape == expected.shape, "Graph dataframe shapes differ" - - assert storage.keys() == [ - "summarized_graph.graphml", - ], "Graph snapshot keys differ" - - -async def test_create_summarized_entities_missing_llm_throws(): - input_tables = load_input_tables([ - "workflow:create_base_extracted_entities", - ]) - - config = get_config_for_workflow(workflow_name) - - del config["summarize_descriptions"]["strategy"]["llm"] - - steps = build_steps(config) - - with pytest.raises(ValueError): # noqa PT011 - await get_workflow_output( - input_tables, - { - "steps": steps, - }, - ) diff --git a/tests/verbs/util.py b/tests/verbs/util.py index 2d6d5b6d..9533bc61 100644 --- a/tests/verbs/util.py +++ b/tests/verbs/util.py @@ -13,7 +13,7 @@ from graphrag.index import ( create_pipeline_config, ) from graphrag.index.run.utils import _create_run_context -from graphrag.index.storage.typing import PipelineStorage +from graphrag.index.storage.pipeline_storage import PipelineStorage pd.set_option("display.max_columns", None) @@ -26,7 +26,7 @@ def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]: # all workflows implicitly receive the `input` source, which is formatted as a dataframe after loading from storage # we'll simulate that by just loading one of our output parquets and converting back to equivalent dataframe # so we aren't dealing with storage vagaries (which would become an integration test) - source = pd.read_parquet("tests/verbs/data/create_base_documents.parquet") + source = pd.read_parquet("tests/verbs/data/create_final_documents.parquet") source.rename(columns={"raw_content": "text"}, inplace=True) input_tables["source"] = cast(pd.DataFrame, source[["id", "text", "title"]])