From fbf11f3a7b639031c9863456f448bcc339a78541 Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Fri, 25 Apr 2025 16:20:56 -0700 Subject: [PATCH] Optional embeddings (#1890) * Make all tables optional for embeddings * Semver --------- Co-authored-by: Alonso Guevara --- .../patch-20250422232634719243.json | 4 +++ .../workflows/generate_text_embeddings.py | 34 ++++++++++++++----- 2 files changed, 30 insertions(+), 8 deletions(-) create mode 100644 .semversioner/next-release/patch-20250422232634719243.json diff --git a/.semversioner/next-release/patch-20250422232634719243.json b/.semversioner/next-release/patch-20250422232634719243.json new file mode 100644 index 00000000..e0c37a2e --- /dev/null +++ b/.semversioner/next-release/patch-20250422232634719243.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Align embeddings table loading with configured fields." +} diff --git a/graphrag/index/workflows/generate_text_embeddings.py b/graphrag/index/workflows/generate_text_embeddings.py index 98bc4e66..b4790f55 100644 --- a/graphrag/index/workflows/generate_text_embeddings.py +++ b/graphrag/index/workflows/generate_text_embeddings.py @@ -25,7 +25,11 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.operations.embed_text import embed_text from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage +from graphrag.utils.storage import ( + load_table_from_storage, + storage_has_table, + write_table_to_storage, +) log = logging.getLogger(__name__) @@ -35,13 +39,23 @@ async def run_workflow( context: PipelineRunContext, ) -> WorkflowFunctionOutput: """All the steps to transform community reports.""" - documents = await load_table_from_storage("documents", context.storage) - relationships = await load_table_from_storage("relationships", context.storage) - text_units = await load_table_from_storage("text_units", context.storage) - entities = await load_table_from_storage("entities", context.storage) - community_reports = await load_table_from_storage( - "community_reports", context.storage - ) + documents = None + relationships = None + text_units = None + entities = None + community_reports = None + if await storage_has_table("documents", context.storage): + documents = await load_table_from_storage("documents", context.storage) + if await storage_has_table("relationships", context.storage): + relationships = await load_table_from_storage("relationships", context.storage) + if await storage_has_table("text_units", context.storage): + text_units = await load_table_from_storage("text_units", context.storage) + if await storage_has_table("entities", context.storage): + entities = await load_table_from_storage("entities", context.storage) + if await storage_has_table("community_reports", context.storage): + community_reports = await load_table_from_storage( + "community_reports", context.storage + ) embedded_fields = get_embedded_fields(config) text_embed = get_embedding_settings(config) @@ -133,6 +147,10 @@ async def generate_text_embeddings( log.info("Creating embeddings") outputs = {} for field in embedded_fields: + if embedding_param_map[field]["data"] is None: + msg = f"Embedding {field} is specified but data table is not in storage." + raise ValueError(msg) + outputs[field] = await _run_and_snapshot_embeddings( name=field, callbacks=callbacks,