mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-13 16:47:20 +08:00
Selective embeddings loading (#2035)
* Invert embedding table loading logic * Semver
This commit is contained in:
parent
77fb7d9d7d
commit
30bdb35cc8
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "generate_text_embeddings only loads tables if embedding field is specified."
|
||||
}
|
||||
@ -26,7 +26,6 @@ from graphrag.index.typing.context import PipelineRunContext
|
||||
from graphrag.index.typing.workflow import WorkflowFunctionOutput
|
||||
from graphrag.utils.storage import (
|
||||
load_table_from_storage,
|
||||
storage_has_table,
|
||||
write_table_to_storage,
|
||||
)
|
||||
|
||||
@ -39,27 +38,35 @@ async def run_workflow(
|
||||
) -> WorkflowFunctionOutput:
|
||||
"""All the steps to transform community reports."""
|
||||
logger.info("Workflow started: generate_text_embeddings")
|
||||
embedded_fields = config.embed_text.names
|
||||
logger.info("Embedding the following fields: %s", embedded_fields)
|
||||
documents = None
|
||||
relationships = None
|
||||
text_units = None
|
||||
entities = None
|
||||
community_reports = None
|
||||
if await storage_has_table("documents", context.output_storage):
|
||||
if document_text_embedding in embedded_fields:
|
||||
documents = await load_table_from_storage("documents", context.output_storage)
|
||||
if await storage_has_table("relationships", context.output_storage):
|
||||
if relationship_description_embedding in embedded_fields:
|
||||
relationships = await load_table_from_storage(
|
||||
"relationships", context.output_storage
|
||||
)
|
||||
if await storage_has_table("text_units", context.output_storage):
|
||||
if text_unit_text_embedding in embedded_fields:
|
||||
text_units = await load_table_from_storage("text_units", context.output_storage)
|
||||
if await storage_has_table("entities", context.output_storage):
|
||||
if (
|
||||
entity_title_embedding in embedded_fields
|
||||
or entity_description_embedding in embedded_fields
|
||||
):
|
||||
entities = await load_table_from_storage("entities", context.output_storage)
|
||||
if await storage_has_table("community_reports", context.output_storage):
|
||||
if (
|
||||
community_title_embedding in embedded_fields
|
||||
or community_summary_embedding in embedded_fields
|
||||
or community_full_content_embedding in embedded_fields
|
||||
):
|
||||
community_reports = await load_table_from_storage(
|
||||
"community_reports", context.output_storage
|
||||
)
|
||||
|
||||
embedded_fields = config.embed_text.names
|
||||
text_embed = get_embedding_settings(config)
|
||||
|
||||
output = await generate_text_embeddings(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user