mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 09:07:20 +08:00
Merge remote-tracking branch 'origin/main' into migration-scripts
This commit is contained in:
commit
6352ca400e
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "Move embedding verbs to operations."
|
||||
}
|
||||
@ -36,7 +36,9 @@ class EmbedGraphConfig(BaseModel):
|
||||
|
||||
def resolved_strategy(self) -> dict:
|
||||
"""Get the resolved node2vec strategy."""
|
||||
from graphrag.index.verbs.graph.embed import EmbedGraphStrategyType
|
||||
from graphrag.index.operations.embed_graph.embed_graph import (
|
||||
EmbedGraphStrategyType,
|
||||
)
|
||||
|
||||
return self.strategy or {
|
||||
"type": EmbedGraphStrategyType.node2vec,
|
||||
|
||||
@ -35,7 +35,9 @@ class TextEmbeddingConfig(LLMConfig):
|
||||
|
||||
def resolved_strategy(self) -> dict:
|
||||
"""Get the resolved text embedding strategy."""
|
||||
from graphrag.index.verbs.text.embed import TextEmbedStrategyType
|
||||
from graphrag.index.operations.embed_text.embed_text import (
|
||||
TextEmbedStrategyType,
|
||||
)
|
||||
|
||||
return self.strategy or {
|
||||
"type": TextEmbedStrategyType.openai,
|
||||
|
||||
@ -10,9 +10,9 @@ from datashaper import (
|
||||
VerbCallbacks,
|
||||
)
|
||||
|
||||
from graphrag.index.operations.embed_graph.embed_graph import embed_graph
|
||||
from graphrag.index.storage import PipelineStorage
|
||||
from graphrag.index.verbs.graph.clustering.cluster_graph import cluster_graph_df
|
||||
from graphrag.index.verbs.graph.embed.embed_graph import embed_graph_df
|
||||
from graphrag.index.verbs.snapshot_rows import snapshot_rows_df
|
||||
|
||||
|
||||
@ -20,14 +20,11 @@ async def create_base_entity_graph(
|
||||
entities: pd.DataFrame,
|
||||
callbacks: VerbCallbacks,
|
||||
storage: PipelineStorage,
|
||||
clustering_config: dict[str, Any],
|
||||
embedding_config: dict[str, Any],
|
||||
clustering_strategy: dict[str, Any],
|
||||
embedding_strategy: dict[str, Any] | None,
|
||||
graphml_snapshot_enabled: bool = False,
|
||||
embed_graph_enabled: bool = False,
|
||||
) -> pd.DataFrame:
|
||||
"""All the steps to create the base entity graph."""
|
||||
clustering_strategy = clustering_config.get("strategy", {"type": "leiden"})
|
||||
|
||||
clustered = cluster_graph_df(
|
||||
entities,
|
||||
callbacks,
|
||||
@ -46,14 +43,12 @@ async def create_base_entity_graph(
|
||||
formats=[{"format": "text", "extension": "graphml"}],
|
||||
)
|
||||
|
||||
embedding_strategy = embedding_config.get("strategy")
|
||||
if embed_graph_enabled and embedding_strategy:
|
||||
clustered = await embed_graph_df(
|
||||
if embedding_strategy:
|
||||
clustered["embeddings"] = await embed_graph(
|
||||
clustered,
|
||||
callbacks,
|
||||
column="clustered_graph",
|
||||
strategy=embedding_strategy,
|
||||
to="embeddings",
|
||||
)
|
||||
|
||||
# take second snapshot after embedding
|
||||
@ -68,7 +63,7 @@ async def create_base_entity_graph(
|
||||
)
|
||||
|
||||
final_columns = ["level", "clustered_graph"]
|
||||
if embed_graph_enabled:
|
||||
if embedding_strategy:
|
||||
final_columns.append("embeddings")
|
||||
|
||||
return cast(pd.DataFrame, clustered[final_columns])
|
||||
|
||||
@ -31,6 +31,7 @@ from graphrag.index.graph.extractors.community_reports.schemas import (
|
||||
NODE_ID,
|
||||
NODE_NAME,
|
||||
)
|
||||
from graphrag.index.operations.embed_text.embed_text import embed_text
|
||||
from graphrag.index.verbs.graph.report.create_community_reports import (
|
||||
create_community_reports_df,
|
||||
)
|
||||
@ -40,7 +41,6 @@ from graphrag.index.verbs.graph.report.prepare_community_reports import (
|
||||
from graphrag.index.verbs.graph.report.restore_community_hierarchy import (
|
||||
restore_community_hierarchy_df,
|
||||
)
|
||||
from graphrag.index.verbs.text.embed.text_embed import text_embed_df
|
||||
|
||||
|
||||
async def create_final_community_reports(
|
||||
@ -87,7 +87,7 @@ async def create_final_community_reports(
|
||||
|
||||
# Embed full content if not skipped
|
||||
if full_content_text_embed:
|
||||
community_reports["full_content_embedding"] = await text_embed_df(
|
||||
community_reports["full_content_embedding"] = await embed_text(
|
||||
community_reports,
|
||||
callbacks,
|
||||
cache,
|
||||
@ -98,7 +98,7 @@ async def create_final_community_reports(
|
||||
|
||||
# Embed summary if not skipped
|
||||
if summary_text_embed:
|
||||
community_reports["summary_embedding"] = await text_embed_df(
|
||||
community_reports["summary_embedding"] = await embed_text(
|
||||
community_reports,
|
||||
callbacks,
|
||||
cache,
|
||||
@ -109,7 +109,7 @@ async def create_final_community_reports(
|
||||
|
||||
# Embed title if not skipped
|
||||
if title_text_embed:
|
||||
community_reports["title_embedding"] = await text_embed_df(
|
||||
community_reports["title_embedding"] = await embed_text(
|
||||
community_reports,
|
||||
callbacks,
|
||||
cache,
|
||||
|
||||
@ -9,25 +9,25 @@ from datashaper import (
|
||||
)
|
||||
|
||||
from graphrag.index.cache import PipelineCache
|
||||
from graphrag.index.verbs.text.embed.text_embed import text_embed_df
|
||||
from graphrag.index.operations.embed_text.embed_text import embed_text
|
||||
|
||||
|
||||
async def create_final_documents(
|
||||
documents: pd.DataFrame,
|
||||
callbacks: VerbCallbacks,
|
||||
cache: PipelineCache,
|
||||
text_embed: dict | None = None,
|
||||
raw_content_text_embed: dict | None = None,
|
||||
) -> pd.DataFrame:
|
||||
"""All the steps to transform final documents."""
|
||||
documents.rename(columns={"text_units": "text_unit_ids"}, inplace=True)
|
||||
|
||||
if text_embed:
|
||||
documents["raw_content_embedding"] = await text_embed_df(
|
||||
if raw_content_text_embed:
|
||||
documents["raw_content_embedding"] = await embed_text(
|
||||
documents,
|
||||
callbacks,
|
||||
cache,
|
||||
column="raw_content",
|
||||
strategy=text_embed["strategy"],
|
||||
strategy=raw_content_text_embed["strategy"],
|
||||
)
|
||||
|
||||
return documents
|
||||
|
||||
@ -9,8 +9,8 @@ from datashaper import (
|
||||
)
|
||||
|
||||
from graphrag.index.cache import PipelineCache
|
||||
from graphrag.index.operations.embed_text.embed_text import embed_text
|
||||
from graphrag.index.verbs.graph.unpack import unpack_graph_df
|
||||
from graphrag.index.verbs.text.embed.text_embed import text_embed_df
|
||||
from graphrag.index.verbs.text.split import text_split_df
|
||||
|
||||
|
||||
@ -50,7 +50,7 @@ async def create_final_entities(
|
||||
|
||||
# Embed name if not skipped
|
||||
if name_text_embed:
|
||||
nodes["name_embedding"] = await text_embed_df(
|
||||
nodes["name_embedding"] = await embed_text(
|
||||
nodes,
|
||||
callbacks,
|
||||
cache,
|
||||
@ -63,7 +63,7 @@ async def create_final_entities(
|
||||
if description_text_embed:
|
||||
# Concatenate 'name' and 'description' and embed
|
||||
nodes["name_description"] = nodes["name"] + ":" + nodes["description"]
|
||||
nodes["description_embedding"] = await text_embed_df(
|
||||
nodes["description_embedding"] = await embed_text(
|
||||
nodes,
|
||||
callbacks,
|
||||
cache,
|
||||
|
||||
@ -11,11 +11,11 @@ from datashaper import (
|
||||
)
|
||||
|
||||
from graphrag.index.cache import PipelineCache
|
||||
from graphrag.index.operations.embed_text.embed_text import embed_text
|
||||
from graphrag.index.verbs.graph.compute_edge_combined_degree import (
|
||||
compute_edge_combined_degree_df,
|
||||
)
|
||||
from graphrag.index.verbs.graph.unpack import unpack_graph_df
|
||||
from graphrag.index.verbs.text.embed.text_embed import text_embed_df
|
||||
|
||||
|
||||
async def create_final_relationships(
|
||||
@ -23,7 +23,7 @@ async def create_final_relationships(
|
||||
nodes: pd.DataFrame,
|
||||
callbacks: VerbCallbacks,
|
||||
cache: PipelineCache,
|
||||
text_embed: dict | None = None,
|
||||
description_text_embed: dict | None = None,
|
||||
) -> pd.DataFrame:
|
||||
"""All the steps to transform final relationships."""
|
||||
graph_edges = unpack_graph_df(entity_graph, callbacks, "clustered_graph", "edges")
|
||||
@ -34,13 +34,13 @@ async def create_final_relationships(
|
||||
pd.DataFrame, graph_edges[graph_edges["level"] == 0].reset_index(drop=True)
|
||||
)
|
||||
|
||||
if text_embed:
|
||||
filtered["description_embedding"] = await text_embed_df(
|
||||
if description_text_embed:
|
||||
filtered["description_embedding"] = await embed_text(
|
||||
filtered,
|
||||
callbacks,
|
||||
cache,
|
||||
column="description",
|
||||
strategy=text_embed["strategy"],
|
||||
strategy=description_text_embed["strategy"],
|
||||
embedding_name="relationship_description",
|
||||
)
|
||||
|
||||
|
||||
@ -11,7 +11,7 @@ from datashaper import (
|
||||
)
|
||||
|
||||
from graphrag.index.cache import PipelineCache
|
||||
from graphrag.index.verbs.text.embed.text_embed import text_embed_df
|
||||
from graphrag.index.operations.embed_text.embed_text import embed_text
|
||||
|
||||
|
||||
async def create_final_text_units(
|
||||
@ -21,7 +21,7 @@ async def create_final_text_units(
|
||||
final_covariates: pd.DataFrame | None,
|
||||
callbacks: VerbCallbacks,
|
||||
cache: PipelineCache,
|
||||
text_embed: dict | None = None,
|
||||
text_text_embed: dict | None = None,
|
||||
) -> pd.DataFrame:
|
||||
"""All the steps to transform the text units."""
|
||||
selected = text_units.loc[:, ["id", "chunk", "document_ids", "n_tokens"]].rename(
|
||||
@ -42,16 +42,16 @@ async def create_final_text_units(
|
||||
aggregated = final_joined.groupby("id", sort=False).agg("first").reset_index()
|
||||
|
||||
is_using_vector_store = False
|
||||
if text_embed:
|
||||
aggregated["text_embedding"] = await text_embed_df(
|
||||
if text_text_embed:
|
||||
aggregated["text_embedding"] = await embed_text(
|
||||
aggregated,
|
||||
callbacks,
|
||||
cache,
|
||||
column="text",
|
||||
strategy=text_embed["strategy"],
|
||||
strategy=text_text_embed["strategy"],
|
||||
)
|
||||
is_using_vector_store = (
|
||||
text_embed.get("strategy", {}).get("vector_store", None) is not None
|
||||
text_text_embed.get("strategy", {}).get("vector_store", None) is not None
|
||||
)
|
||||
|
||||
return cast(
|
||||
@ -62,7 +62,7 @@ async def create_final_text_units(
|
||||
"text",
|
||||
*(
|
||||
[]
|
||||
if (not text_embed or is_using_vector_store)
|
||||
if (not text_text_embed or is_using_vector_store)
|
||||
else ["text_embedding"]
|
||||
),
|
||||
"n_tokens",
|
||||
|
||||
@ -8,7 +8,7 @@ from typing import Any, cast
|
||||
|
||||
import networkx as nx
|
||||
import pandas as pd
|
||||
from datashaper import TableContainer, VerbCallbacks, VerbInput, derive_from_rows, verb
|
||||
from datashaper import VerbCallbacks, derive_from_rows
|
||||
|
||||
from graphrag.index.utils import load_graph
|
||||
|
||||
@ -25,71 +25,18 @@ class EmbedGraphStrategyType(str, Enum):
|
||||
return f'"{self.value}"'
|
||||
|
||||
|
||||
@verb(name="embed_graph")
|
||||
async def embed_graph(
|
||||
input: VerbInput,
|
||||
callbacks: VerbCallbacks,
|
||||
strategy: dict[str, Any],
|
||||
column: str,
|
||||
to: str,
|
||||
**kwargs,
|
||||
) -> TableContainer:
|
||||
"""
|
||||
Embed a graph into a vector space. The graph is expected to be in graphml format. The verb outputs a new column containing a mapping between node_id and vector.
|
||||
|
||||
## Usage
|
||||
```yaml
|
||||
verb: embed_graph
|
||||
args:
|
||||
column: clustered_graph # The name of the column containing the graph, should be a graphml graph
|
||||
to: embeddings # The name of the column to output the embeddings to
|
||||
strategy: <strategy config> # See strategies section below
|
||||
```
|
||||
|
||||
## Strategies
|
||||
The embed_graph verb uses a strategy to embed the graph. The strategy is an object which defines the strategy to use. The following strategies are available:
|
||||
|
||||
### node2vec
|
||||
This strategy uses the node2vec algorithm to embed a graph. The strategy config is as follows:
|
||||
|
||||
```yaml
|
||||
strategy:
|
||||
type: node2vec
|
||||
dimensions: 1536 # Optional, The number of dimensions to use for the embedding, default: 1536
|
||||
num_walks: 10 # Optional, The number of walks to use for the embedding, default: 10
|
||||
walk_length: 40 # Optional, The walk length to use for the embedding, default: 40
|
||||
window_size: 2 # Optional, The window size to use for the embedding, default: 2
|
||||
iterations: 3 # Optional, The number of iterations to use for the embedding, default: 3
|
||||
random_seed: 86 # Optional, The random seed to use for the embedding, default: 86
|
||||
```
|
||||
"""
|
||||
input_df = cast(pd.DataFrame, input.get_input())
|
||||
|
||||
output_df = await embed_graph_df(
|
||||
input_df,
|
||||
callbacks,
|
||||
strategy,
|
||||
column,
|
||||
to,
|
||||
**kwargs,
|
||||
)
|
||||
return TableContainer(table=output_df)
|
||||
|
||||
|
||||
async def embed_graph_df(
|
||||
input: pd.DataFrame,
|
||||
callbacks: VerbCallbacks,
|
||||
strategy: dict[str, Any],
|
||||
column: str,
|
||||
to: str,
|
||||
**kwargs,
|
||||
) -> pd.DataFrame:
|
||||
num_threads: int = 4,
|
||||
):
|
||||
"""
|
||||
Embed a graph into a vector space. The graph is expected to be in graphml format. The verb outputs a new column containing a mapping between node_id and vector.
|
||||
Embed a graph into a vector space. The graph is expected to be in graphml format. The operation outputs a new column containing a mapping between node_id and vector.
|
||||
|
||||
## Usage
|
||||
```yaml
|
||||
verb: embed_graph
|
||||
args:
|
||||
column: clustered_graph # The name of the column containing the graph, should be a graphml graph
|
||||
to: embeddings # The name of the column to output the embeddings to
|
||||
@ -97,7 +44,7 @@ async def embed_graph_df(
|
||||
```
|
||||
|
||||
## Strategies
|
||||
The embed_graph verb uses a strategy to embed the graph. The strategy is an object which defines the strategy to use. The following strategies are available:
|
||||
The embed_graph operation uses a strategy to embed the graph. The strategy is an object which defines the strategy to use. The following strategies are available:
|
||||
|
||||
### node2vec
|
||||
This strategy uses the node2vec algorithm to embed a graph. The strategy config is as follows:
|
||||
@ -123,10 +70,10 @@ async def embed_graph_df(
|
||||
input,
|
||||
run_strategy,
|
||||
callbacks=callbacks,
|
||||
num_threads=kwargs.get("num_threads", None),
|
||||
num_threads=num_threads,
|
||||
)
|
||||
input[to] = list(results)
|
||||
return input
|
||||
|
||||
return list(results)
|
||||
|
||||
|
||||
def run_embeddings(
|
||||
@ -9,7 +9,7 @@ import networkx as nx
|
||||
|
||||
from graphrag.index.graph.embedding import embed_nod2vec
|
||||
from graphrag.index.graph.utils import stable_largest_connected_component
|
||||
from graphrag.index.verbs.graph.embed.typing import NodeEmbeddings
|
||||
from graphrag.index.operations.embed_graph.typing import NodeEmbeddings
|
||||
|
||||
|
||||
def run(graph: nx.Graph, args: dict[str, Any]) -> NodeEmbeddings:
|
||||
@ -3,6 +3,6 @@
|
||||
|
||||
"""The Indexing Engine text embed package root."""
|
||||
|
||||
from .text_embed import TextEmbedStrategyType, text_embed
|
||||
from .embed_text import TextEmbedStrategyType, embed_text
|
||||
|
||||
__all__ = ["TextEmbedStrategyType", "text_embed"]
|
||||
__all__ = ["TextEmbedStrategyType", "embed_text"]
|
||||
@ -1,15 +1,15 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing text_embed, load_strategy and create_row_from_embedding_data methods definition."""
|
||||
"""A module containing embed_text, load_strategy and create_row_from_embedding_data methods definition."""
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datashaper import TableContainer, VerbCallbacks, VerbInput, verb
|
||||
from datashaper import VerbCallbacks
|
||||
|
||||
from graphrag.index.cache import PipelineCache
|
||||
from graphrag.vector_stores import (
|
||||
@ -38,21 +38,19 @@ class TextEmbedStrategyType(str, Enum):
|
||||
return f'"{self.value}"'
|
||||
|
||||
|
||||
@verb(name="text_embed")
|
||||
async def text_embed(
|
||||
input: VerbInput,
|
||||
async def embed_text(
|
||||
input: pd.DataFrame,
|
||||
callbacks: VerbCallbacks,
|
||||
cache: PipelineCache,
|
||||
column: str,
|
||||
strategy: dict,
|
||||
**kwargs,
|
||||
) -> TableContainer:
|
||||
embedding_name: str = "default",
|
||||
):
|
||||
"""
|
||||
Embed a piece of text into a vector space. The verb outputs a new column containing a mapping between doc_id and vector.
|
||||
Embed a piece of text into a vector space. The operation outputs a new column containing a mapping between doc_id and vector.
|
||||
|
||||
## Usage
|
||||
```yaml
|
||||
verb: text_embed
|
||||
args:
|
||||
column: text # The name of the column containing the text to embed, this can either be a column with text, or a column with a list[tuple[doc_id, str]]
|
||||
to: embedding # The name of the column to output the embedding to
|
||||
@ -60,7 +58,7 @@ async def text_embed(
|
||||
```
|
||||
|
||||
## Strategies
|
||||
The text embed verb uses a strategy to embed the text. The strategy is an object which defines the strategy to use. The following strategies are available:
|
||||
The text embed operation uses a strategy to embed the text. The strategy is an object which defines the strategy to use. The following strategies are available:
|
||||
|
||||
### openai
|
||||
This strategy uses openai to embed a piece of text. In particular it uses a LLM to embed a piece of text. The strategy config is as follows:
|
||||
@ -79,28 +77,9 @@ async def text_embed(
|
||||
<...>
|
||||
```
|
||||
"""
|
||||
input_df = cast(pd.DataFrame, input.get_input())
|
||||
to = kwargs.get("to", f"{column}_embedding")
|
||||
input_df[to] = await text_embed_df(
|
||||
input_df, callbacks, cache, column, strategy, **kwargs
|
||||
)
|
||||
return TableContainer(table=input_df)
|
||||
|
||||
|
||||
# TODO: this ultimately just creates a new column, so our embed function could just generate a series instead of updating the dataframe
|
||||
async def text_embed_df(
|
||||
input: pd.DataFrame,
|
||||
callbacks: VerbCallbacks,
|
||||
cache: PipelineCache,
|
||||
column: str,
|
||||
strategy: dict,
|
||||
**kwargs,
|
||||
):
|
||||
"""Embed a piece of text into a vector space."""
|
||||
vector_store_config = strategy.get("vector_store")
|
||||
|
||||
if vector_store_config:
|
||||
embedding_name = kwargs.get("embedding_name", "default")
|
||||
collection_name = _get_collection_name(vector_store_config, embedding_name)
|
||||
vector_store: BaseVectorStore = _create_vector_store(
|
||||
vector_store_config, collection_name
|
||||
@ -10,7 +10,6 @@ from .graph import (
|
||||
cluster_graph,
|
||||
create_community_reports,
|
||||
create_graph,
|
||||
embed_graph,
|
||||
layout_graph,
|
||||
merge_graphs,
|
||||
unpack_graph,
|
||||
@ -19,7 +18,7 @@ from .overrides import aggregate, concat
|
||||
from .snapshot import snapshot
|
||||
from .snapshot_rows import snapshot_rows
|
||||
from .spread_json import spread_json
|
||||
from .text import chunk, text_embed, text_split, text_translate
|
||||
from .text import chunk, text_split, text_translate
|
||||
from .unzip import unzip
|
||||
from .zip import zip_verb
|
||||
|
||||
@ -30,7 +29,6 @@ __all__ = [
|
||||
"concat",
|
||||
"create_community_reports",
|
||||
"create_graph",
|
||||
"embed_graph",
|
||||
"entity_extract",
|
||||
"extract_covariates",
|
||||
"genid",
|
||||
@ -40,7 +38,6 @@ __all__ = [
|
||||
"snapshot_rows",
|
||||
"spread_json",
|
||||
"summarize_descriptions",
|
||||
"text_embed",
|
||||
"text_split",
|
||||
"text_translate",
|
||||
"unpack_graph",
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
from .clustering import cluster_graph
|
||||
from .compute_edge_combined_degree import compute_edge_combined_degree
|
||||
from .create import DEFAULT_EDGE_ATTRIBUTES, DEFAULT_NODE_ATTRIBUTES, create_graph
|
||||
from .embed import embed_graph
|
||||
from .layout import layout_graph
|
||||
from .merge import merge_graphs
|
||||
from .report import (
|
||||
@ -25,7 +24,6 @@ __all__ = [
|
||||
"compute_edge_combined_degree",
|
||||
"create_community_reports",
|
||||
"create_graph",
|
||||
"embed_graph",
|
||||
"layout_graph",
|
||||
"merge_graphs",
|
||||
"prepare_community_reports",
|
||||
|
||||
@ -11,8 +11,8 @@ import pandas as pd
|
||||
from datashaper import TableContainer, VerbCallbacks, VerbInput, progress_callback, verb
|
||||
|
||||
from graphrag.index.graph.visualization import GraphLayout
|
||||
from graphrag.index.operations.embed_graph.typing import NodeEmbeddings
|
||||
from graphrag.index.utils import load_graph
|
||||
from graphrag.index.verbs.graph.embed.typing import NodeEmbeddings
|
||||
|
||||
|
||||
class LayoutGraphStrategyType(str, Enum):
|
||||
|
||||
@ -15,8 +15,8 @@ from graphrag.index.graph.visualization import (
|
||||
NodePosition,
|
||||
compute_umap_positions,
|
||||
)
|
||||
from graphrag.index.operations.embed_graph.typing import NodeEmbeddings
|
||||
from graphrag.index.typing import ErrorHandlerFn
|
||||
from graphrag.index.verbs.graph.embed.typing import NodeEmbeddings
|
||||
|
||||
# TODO: This could be handled more elegantly, like what columns to use
|
||||
# for "size" or "cluster"
|
||||
|
||||
@ -4,7 +4,6 @@
|
||||
"""The Indexing Engine text package root."""
|
||||
|
||||
from .chunk.text_chunk import chunk
|
||||
from .embed import text_embed
|
||||
from .replace import replace
|
||||
from .split import text_split
|
||||
from .translate import text_translate
|
||||
@ -12,7 +11,6 @@ from .translate import text_translate
|
||||
__all__ = [
|
||||
"chunk",
|
||||
"replace",
|
||||
"text_embed",
|
||||
"text_split",
|
||||
"text_translate",
|
||||
]
|
||||
|
||||
@ -21,6 +21,8 @@ def build_steps(
|
||||
"cluster_graph",
|
||||
{"strategy": {"type": "leiden"}},
|
||||
)
|
||||
clustering_strategy = clustering_config["strategy"]
|
||||
|
||||
embed_graph_config = config.get(
|
||||
"embed_graph",
|
||||
{
|
||||
@ -34,18 +36,20 @@ def build_steps(
|
||||
}
|
||||
},
|
||||
)
|
||||
embedding_strategy = embed_graph_config["strategy"]
|
||||
embed_graph_enabled = config.get("embed_graph_enabled", False) or False
|
||||
|
||||
graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False
|
||||
embed_graph_enabled = config.get("embed_graph_enabled", False) or False
|
||||
|
||||
return [
|
||||
{
|
||||
"verb": "create_base_entity_graph",
|
||||
"args": {
|
||||
"clustering_config": clustering_config,
|
||||
"clustering_strategy": clustering_strategy,
|
||||
"graphml_snapshot_enabled": graphml_snapshot_enabled,
|
||||
"embed_graph_enabled": embed_graph_enabled,
|
||||
"embedding_config": embed_graph_config,
|
||||
"embedding_strategy": embedding_strategy
|
||||
if embed_graph_enabled
|
||||
else None,
|
||||
},
|
||||
"input": ({"source": "workflow:create_summarized_entities"}),
|
||||
},
|
||||
|
||||
@ -26,7 +26,7 @@ def build_steps(
|
||||
{
|
||||
"verb": "create_final_documents",
|
||||
"args": {
|
||||
"text_embed": document_raw_content_embed_config
|
||||
"raw_content_text_embed": document_raw_content_embed_config
|
||||
if not skip_raw_content_embedding
|
||||
else None,
|
||||
},
|
||||
|
||||
@ -27,7 +27,7 @@ def build_steps(
|
||||
{
|
||||
"verb": "create_final_relationships",
|
||||
"args": {
|
||||
"text_embed": relationship_description_embed_config
|
||||
"description_text_embed": relationship_description_embed_config
|
||||
if not skip_description_embedding
|
||||
else None,
|
||||
},
|
||||
|
||||
@ -37,7 +37,7 @@ def build_steps(
|
||||
{
|
||||
"verb": "create_final_text_units",
|
||||
"args": {
|
||||
"text_embed": text_unit_text_embed_config
|
||||
"text_text_embed": text_unit_text_embed_config
|
||||
if not skip_text_unit_embedding
|
||||
else None,
|
||||
},
|
||||
|
||||
@ -28,10 +28,9 @@ async def create_base_entity_graph(
|
||||
input: VerbInput,
|
||||
callbacks: VerbCallbacks,
|
||||
storage: PipelineStorage,
|
||||
clustering_config: dict[str, Any],
|
||||
embedding_config: dict[str, Any],
|
||||
clustering_strategy: dict[str, Any],
|
||||
embedding_strategy: dict[str, Any] | None,
|
||||
graphml_snapshot_enabled: bool = False,
|
||||
embed_graph_enabled: bool = False,
|
||||
**_kwargs: dict,
|
||||
) -> VerbResult:
|
||||
"""All the steps to create the base entity graph."""
|
||||
@ -41,10 +40,9 @@ async def create_base_entity_graph(
|
||||
source,
|
||||
callbacks,
|
||||
storage,
|
||||
clustering_config,
|
||||
embedding_config,
|
||||
clustering_strategy,
|
||||
embedding_strategy,
|
||||
graphml_snapshot_enabled,
|
||||
embed_graph_enabled,
|
||||
)
|
||||
|
||||
return create_verb_result(cast(Table, output))
|
||||
|
||||
@ -28,7 +28,7 @@ async def create_final_documents(
|
||||
input: VerbInput,
|
||||
callbacks: VerbCallbacks,
|
||||
cache: PipelineCache,
|
||||
text_embed: dict | None = None,
|
||||
raw_content_text_embed: dict | None = None,
|
||||
**_kwargs: dict,
|
||||
) -> VerbResult:
|
||||
"""All the steps to transform final documents."""
|
||||
@ -38,7 +38,7 @@ async def create_final_documents(
|
||||
source,
|
||||
callbacks,
|
||||
cache,
|
||||
text_embed,
|
||||
raw_content_text_embed,
|
||||
)
|
||||
|
||||
return create_verb_result(cast(Table, output))
|
||||
|
||||
@ -29,7 +29,7 @@ async def create_final_relationships(
|
||||
input: VerbInput,
|
||||
callbacks: VerbCallbacks,
|
||||
cache: PipelineCache,
|
||||
text_embed: dict | None = None,
|
||||
description_text_embed: dict | None = None,
|
||||
**_kwargs: dict,
|
||||
) -> VerbResult:
|
||||
"""All the steps to transform final relationships."""
|
||||
@ -41,7 +41,7 @@ async def create_final_relationships(
|
||||
nodes,
|
||||
callbacks,
|
||||
cache,
|
||||
text_embed,
|
||||
description_text_embed,
|
||||
)
|
||||
|
||||
return create_verb_result(cast(Table, output))
|
||||
|
||||
@ -27,7 +27,7 @@ async def create_final_text_units(
|
||||
input: VerbInput,
|
||||
callbacks: VerbCallbacks,
|
||||
cache: PipelineCache,
|
||||
text_embed: dict | None = None,
|
||||
text_text_embed: dict | None = None,
|
||||
**_kwargs: dict,
|
||||
) -> VerbResult:
|
||||
"""All the steps to transform the text units."""
|
||||
@ -50,7 +50,7 @@ async def create_final_text_units(
|
||||
final_covariates,
|
||||
callbacks,
|
||||
cache,
|
||||
text_embed,
|
||||
text_text_embed,
|
||||
)
|
||||
|
||||
return create_verb_result(cast(Table, output))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user