From 30bdb35cc8145d184e5689639827d0c9432b5622 Mon Sep 17 00:00:00 2001 From: Nathan Evans Date: Wed, 27 Aug 2025 11:12:01 -0700 Subject: [PATCH] Selective embeddings loading (#2035) * Invert embedding table loading logic * Semver --- .../patch-20250827003854793879.json | 4 ++++ .../workflows/generate_text_embeddings.py | 21 ++++++++++++------- 2 files changed, 18 insertions(+), 7 deletions(-) create mode 100644 .semversioner/next-release/patch-20250827003854793879.json diff --git a/.semversioner/next-release/patch-20250827003854793879.json b/.semversioner/next-release/patch-20250827003854793879.json new file mode 100644 index 00000000..e699566b --- /dev/null +++ b/.semversioner/next-release/patch-20250827003854793879.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "generate_text_embeddings only loads tables if embedding field is specified." +} diff --git a/graphrag/index/workflows/generate_text_embeddings.py b/graphrag/index/workflows/generate_text_embeddings.py index e56c0ed8..ff676623 100644 --- a/graphrag/index/workflows/generate_text_embeddings.py +++ b/graphrag/index/workflows/generate_text_embeddings.py @@ -26,7 +26,6 @@ from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.utils.storage import ( load_table_from_storage, - storage_has_table, write_table_to_storage, ) @@ -39,27 +38,35 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to transform community reports.""" logger.info("Workflow started: generate_text_embeddings") + embedded_fields = config.embed_text.names + logger.info("Embedding the following fields: %s", embedded_fields) documents = None relationships = None text_units = None entities = None community_reports = None - if await storage_has_table("documents", context.output_storage): + if document_text_embedding in embedded_fields: documents = await load_table_from_storage("documents", context.output_storage) - if await storage_has_table("relationships", context.output_storage): + if relationship_description_embedding in embedded_fields: relationships = await load_table_from_storage( "relationships", context.output_storage ) - if await storage_has_table("text_units", context.output_storage): + if text_unit_text_embedding in embedded_fields: text_units = await load_table_from_storage("text_units", context.output_storage) - if await storage_has_table("entities", context.output_storage): + if ( + entity_title_embedding in embedded_fields + or entity_description_embedding in embedded_fields + ): entities = await load_table_from_storage("entities", context.output_storage) - if await storage_has_table("community_reports", context.output_storage): + if ( + community_title_embedding in embedded_fields + or community_summary_embedding in embedded_fields + or community_full_content_embedding in embedded_fields + ): community_reports = await load_table_from_storage( "community_reports", context.output_storage ) - embedded_fields = config.embed_text.names text_embed = get_embedding_settings(config) output = await generate_text_embeddings(