Optional embeddings (#1890)

* Make all tables optional for embeddings

* Semver

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
This commit is contained in:
Nathan Evans 2025-04-25 16:20:56 -07:00 committed by GitHub
parent 56e0fad218
commit fbf11f3a7b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 30 additions and 8 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Align embeddings table loading with configured fields."
}

View File

@ -25,7 +25,11 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.operations.embed_text import embed_text
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
from graphrag.utils.storage import (
load_table_from_storage,
storage_has_table,
write_table_to_storage,
)
log = logging.getLogger(__name__)
@ -35,13 +39,23 @@ async def run_workflow(
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to transform community reports."""
documents = await load_table_from_storage("documents", context.storage)
relationships = await load_table_from_storage("relationships", context.storage)
text_units = await load_table_from_storage("text_units", context.storage)
entities = await load_table_from_storage("entities", context.storage)
community_reports = await load_table_from_storage(
"community_reports", context.storage
)
documents = None
relationships = None
text_units = None
entities = None
community_reports = None
if await storage_has_table("documents", context.storage):
documents = await load_table_from_storage("documents", context.storage)
if await storage_has_table("relationships", context.storage):
relationships = await load_table_from_storage("relationships", context.storage)
if await storage_has_table("text_units", context.storage):
text_units = await load_table_from_storage("text_units", context.storage)
if await storage_has_table("entities", context.storage):
entities = await load_table_from_storage("entities", context.storage)
if await storage_has_table("community_reports", context.storage):
community_reports = await load_table_from_storage(
"community_reports", context.storage
)
embedded_fields = get_embedded_fields(config)
text_embed = get_embedding_settings(config)
@ -133,6 +147,10 @@ async def generate_text_embeddings(
log.info("Creating embeddings")
outputs = {}
for field in embedded_fields:
if embedding_param_map[field]["data"] is None:
msg = f"Embedding {field} is specified but data table is not in storage."
raise ValueError(msg)
outputs[field] = await _run_and_snapshot_embeddings(
name=field,
callbacks=callbacks,