Selective embeddings loading (#2035)

* Invert embedding table loading logic

* Semver
This commit is contained in:
Nathan Evans 2025-08-27 11:12:01 -07:00 committed by GitHub
parent 77fb7d9d7d
commit 30bdb35cc8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 7 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "generate_text_embeddings only loads tables if embedding field is specified."
}

View File

@ -26,7 +26,6 @@ from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import (
load_table_from_storage,
storage_has_table,
write_table_to_storage,
)
@ -39,27 +38,35 @@ async def run_workflow(
) -> WorkflowFunctionOutput:
"""All the steps to transform community reports."""
logger.info("Workflow started: generate_text_embeddings")
embedded_fields = config.embed_text.names
logger.info("Embedding the following fields: %s", embedded_fields)
documents = None
relationships = None
text_units = None
entities = None
community_reports = None
if await storage_has_table("documents", context.output_storage):
if document_text_embedding in embedded_fields:
documents = await load_table_from_storage("documents", context.output_storage)
if await storage_has_table("relationships", context.output_storage):
if relationship_description_embedding in embedded_fields:
relationships = await load_table_from_storage(
"relationships", context.output_storage
)
if await storage_has_table("text_units", context.output_storage):
if text_unit_text_embedding in embedded_fields:
text_units = await load_table_from_storage("text_units", context.output_storage)
if await storage_has_table("entities", context.output_storage):
if (
entity_title_embedding in embedded_fields
or entity_description_embedding in embedded_fields
):
entities = await load_table_from_storage("entities", context.output_storage)
if await storage_has_table("community_reports", context.output_storage):
if (
community_title_embedding in embedded_fields
or community_summary_embedding in embedded_fields
or community_full_content_embedding in embedded_fields
):
community_reports = await load_table_from_storage(
"community_reports", context.output_storage
)
embedded_fields = config.embed_text.names
text_embed = get_embedding_settings(config)
output = await generate_text_embeddings(