mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-13 16:47:20 +08:00
Optional embeddings (#1890)
* Make all tables optional for embeddings * Semver --------- Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
This commit is contained in:
parent
56e0fad218
commit
fbf11f3a7b
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "Align embeddings table loading with configured fields."
|
||||
}
|
||||
@ -25,7 +25,11 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.operations.embed_text import embed_text
|
||||
from graphrag.index.typing.context import PipelineRunContext
|
||||
from graphrag.index.typing.workflow import WorkflowFunctionOutput
|
||||
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
|
||||
from graphrag.utils.storage import (
|
||||
load_table_from_storage,
|
||||
storage_has_table,
|
||||
write_table_to_storage,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -35,13 +39,23 @@ async def run_workflow(
|
||||
context: PipelineRunContext,
|
||||
) -> WorkflowFunctionOutput:
|
||||
"""All the steps to transform community reports."""
|
||||
documents = await load_table_from_storage("documents", context.storage)
|
||||
relationships = await load_table_from_storage("relationships", context.storage)
|
||||
text_units = await load_table_from_storage("text_units", context.storage)
|
||||
entities = await load_table_from_storage("entities", context.storage)
|
||||
community_reports = await load_table_from_storage(
|
||||
"community_reports", context.storage
|
||||
)
|
||||
documents = None
|
||||
relationships = None
|
||||
text_units = None
|
||||
entities = None
|
||||
community_reports = None
|
||||
if await storage_has_table("documents", context.storage):
|
||||
documents = await load_table_from_storage("documents", context.storage)
|
||||
if await storage_has_table("relationships", context.storage):
|
||||
relationships = await load_table_from_storage("relationships", context.storage)
|
||||
if await storage_has_table("text_units", context.storage):
|
||||
text_units = await load_table_from_storage("text_units", context.storage)
|
||||
if await storage_has_table("entities", context.storage):
|
||||
entities = await load_table_from_storage("entities", context.storage)
|
||||
if await storage_has_table("community_reports", context.storage):
|
||||
community_reports = await load_table_from_storage(
|
||||
"community_reports", context.storage
|
||||
)
|
||||
|
||||
embedded_fields = get_embedded_fields(config)
|
||||
text_embed = get_embedding_settings(config)
|
||||
@ -133,6 +147,10 @@ async def generate_text_embeddings(
|
||||
log.info("Creating embeddings")
|
||||
outputs = {}
|
||||
for field in embedded_fields:
|
||||
if embedding_param_map[field]["data"] is None:
|
||||
msg = f"Embedding {field} is specified but data table is not in storage."
|
||||
raise ValueError(msg)
|
||||
|
||||
outputs[field] = await _run_and_snapshot_embeddings(
|
||||
name=field,
|
||||
callbacks=callbacks,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user