Merge remote-tracking branch 'origin/main' into migration-scripts

This commit is contained in:
gaudyb 2024-10-10 12:05:54 -06:00
commit 4d15434c7e
151 changed files with 896 additions and 2405 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Moving verbs around."
}

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Small cleanup in community context history building"
}

View File

@ -89,7 +89,6 @@ nbconvert
binarize
prechunked
openai
genid
umap
concat
unhot

View File

@ -178,7 +178,7 @@ This section controls the cache mechanism used by the pipeline. This is used to
| `GRAPHRAG_CACHE_STORAGE_ACCOUNT_BLOB_URL` | The Azure Storage blob endpoint to use when in `blob` mode and using managed identity. Will have the format `https://<storage_account_name>.blob.core.windows.net` | `str` | optional | None |
| `GRAPHRAG_CACHE_CONNECTION_STRING` | The Azure Storage connection string to use when in `blob` mode. | `str` | optional | None |
| `GRAPHRAG_CACHE_CONTAINER_NAME` | The Azure Storage container name to use when in `blob` mode. | `str` | optional | None |
| `GRAPHRAG_CACHE_BASE_DIR` | The base path to the reporting outputs. | `str` | optional | None |
| `GRAPHRAG_CACHE_BASE_DIR` | The base path to the cache files. | `str` | optional | None |
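For illustration only (not part of this change), the blob-mode cache settings above could be combined in a `.env` file roughly like this; the variable that switches the cache into `blob` mode appears earlier in the table and is omitted here, and all values are placeholders:
```
# .env -- placeholder values, for illustration only
GRAPHRAG_CACHE_STORAGE_ACCOUNT_BLOB_URL=https://mystorageaccount.blob.core.windows.net
GRAPHRAG_CACHE_CONTAINER_NAME=graphrag-cache
GRAPHRAG_CACHE_BASE_DIR=cache
```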
## Reporting

View File

@ -6,7 +6,7 @@ layout: page
date: 2023-01-03
---
The default configuration mode may be configured by using a `config.json` or `config.yml` file in the data project root. If a `.env` file is present along with this config file, then it will be loaded, and the environment variables defined therein will be available for token replacements in your configuration document using `${ENV_VAR}` syntax.
The default configuration mode may be configured by using a `settings.json` or `settings.yml` file in the data project root. If a `.env` file is present along with this config file, then it will be loaded, and the environment variables defined therein will be available for token replacements in your configuration document using `${ENV_VAR}` syntax.
For example:
@ -14,7 +14,7 @@ For example:
# .env
API_KEY=some_api_key
# config.json
# settings.json
{
"llm": {
"api_key": "${API_KEY}"

View File

@ -48,7 +48,7 @@ mkdir -p ./ragtest/input
Now let's get a copy of A Christmas Carol by Charles Dickens from a trusted source:
```sh
curl https://www.gutenberg.org/cache/epub/24022/pg24022.txt > ./ragtest/input/book.txt
curl https://www.gutenberg.org/cache/epub/24022/pg24022.txt -o ./ragtest/input/book.txt
```
Next we'll inject some required config variables:

View File

@ -29,7 +29,7 @@ class ChunkingConfig(BaseModel):
def resolved_strategy(self, encoding_model: str) -> dict:
"""Get the resolved chunking strategy."""
from graphrag.index.verbs.text.chunk import ChunkStrategyType
from graphrag.index.operations.chunk_text import ChunkStrategyType
return self.strategy or {
"type": ChunkStrategyType.tokens,

View File

@ -38,7 +38,7 @@ class ClaimExtractionConfig(LLMConfig):
def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
"""Get the resolved claim extraction strategy."""
from graphrag.index.verbs.covariates.extract_covariates import (
from graphrag.index.operations.extract_covariates import (
ExtractClaimsStrategyType,
)

View File

@ -20,7 +20,7 @@ class ClusterGraphConfig(BaseModel):
def resolved_strategy(self) -> dict:
"""Get the resolved cluster strategy."""
from graphrag.index.verbs.graph.clustering import GraphCommunityStrategyType
from graphrag.index.operations.cluster_graph import GraphCommunityStrategyType
return self.strategy or {
"type": GraphCommunityStrategyType.leiden,

View File

@ -32,7 +32,9 @@ class CommunityReportsConfig(LLMConfig):
def resolved_strategy(self, root_dir) -> dict:
"""Get the resolved community report extraction strategy."""
from graphrag.index.verbs.graph.report import CreateCommunityReportsStrategyType
from graphrag.index.operations.summarize_communities import (
CreateCommunityReportsStrategyType,
)
return self.strategy or {
"type": CreateCommunityReportsStrategyType.graph_intelligence,

View File

@ -36,7 +36,7 @@ class EmbedGraphConfig(BaseModel):
def resolved_strategy(self) -> dict:
"""Get the resolved node2vec strategy."""
from graphrag.index.operations.embed_graph.embed_graph import (
from graphrag.index.operations.embed_graph import (
EmbedGraphStrategyType,
)

View File

@ -35,7 +35,9 @@ class EntityExtractionConfig(LLMConfig):
def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
"""Get the resolved entity extraction strategy."""
from graphrag.index.verbs.entities.extraction import ExtractEntityStrategyType
from graphrag.index.operations.extract_entities import (
ExtractEntityStrategyType,
)
return self.strategy or {
"type": ExtractEntityStrategyType.graph_intelligence,

View File

@ -28,7 +28,9 @@ class SummarizeDescriptionsConfig(LLMConfig):
def resolved_strategy(self, root_dir: str) -> dict:
"""Get the resolved description summarization strategy."""
from graphrag.index.verbs.entities.summarize import SummarizeStrategyType
from graphrag.index.operations.summarize_descriptions import (
SummarizeStrategyType,
)
return self.strategy or {
"type": SummarizeStrategyType.graph_intelligence,

View File

@ -35,7 +35,7 @@ class TextEmbeddingConfig(LLMConfig):
def resolved_strategy(self) -> dict:
"""Get the resolved text embedding strategy."""
from graphrag.index.operations.embed_text.embed_text import (
from graphrag.index.operations.embed_text import (
TextEmbedStrategyType,
)

View File

@ -10,10 +10,10 @@ from datashaper import (
VerbCallbacks,
)
from graphrag.index.operations.embed_graph.embed_graph import embed_graph
from graphrag.index.operations.cluster_graph import cluster_graph
from graphrag.index.operations.embed_graph import embed_graph
from graphrag.index.operations.snapshot_rows import snapshot_rows
from graphrag.index.storage import PipelineStorage
from graphrag.index.verbs.graph.clustering.cluster_graph import cluster_graph_df
from graphrag.index.verbs.snapshot_rows import snapshot_rows_df
async def create_base_entity_graph(
@ -25,7 +25,7 @@ async def create_base_entity_graph(
graphml_snapshot_enabled: bool = False,
) -> pd.DataFrame:
"""All the steps to create the base entity graph."""
clustered = cluster_graph_df(
clustered = cluster_graph(
entities,
callbacks,
column="entity_graph",
@ -35,7 +35,7 @@ async def create_base_entity_graph(
)
if graphml_snapshot_enabled:
await snapshot_rows_df(
await snapshot_rows(
clustered,
column="clustered_graph",
base_name="clustered_graph",
@ -54,7 +54,7 @@ async def create_base_entity_graph(
# take second snapshot after embedding
# todo: this could be skipped if embedding isn't performed, otherwise it is a copy of the regular graph?
if graphml_snapshot_enabled:
await snapshot_rows_df(
await snapshot_rows(
clustered,
column="entity_graph",
base_name="embedded_graph",

View File

@ -12,23 +12,23 @@ from datashaper import (
)
from graphrag.index.cache import PipelineCache
from graphrag.index.operations.extract_entities import extract_entities
from graphrag.index.operations.merge_graphs import merge_graphs
from graphrag.index.operations.snapshot import snapshot
from graphrag.index.operations.snapshot_rows import snapshot_rows
from graphrag.index.storage import PipelineStorage
from graphrag.index.verbs.entities.extraction.entity_extract import entity_extract_df
from graphrag.index.verbs.graph.merge.merge_graphs import merge_graphs_df
from graphrag.index.verbs.snapshot import snapshot_df
from graphrag.index.verbs.snapshot_rows import snapshot_rows_df
async def create_base_extracted_entities(
text_units: pd.DataFrame,
cache: PipelineCache,
callbacks: VerbCallbacks,
cache: PipelineCache,
storage: PipelineStorage,
column: str,
id_column: str,
nodes: dict[str, Any],
edges: dict[str, Any],
strategy: dict[str, Any] | None,
extraction_strategy: dict[str, Any] | None,
async_mode: AsyncType = AsyncType.AsyncIO,
entity_types: list[str] | None = None,
graphml_snapshot_enabled: bool = False,
@ -36,13 +36,13 @@ async def create_base_extracted_entities(
num_threads: int = 4,
) -> pd.DataFrame:
"""All the steps to extract and format covariates."""
entity_graph = await entity_extract_df(
entity_graph = await extract_entities(
text_units,
cache,
callbacks,
cache,
column=column,
id_column=id_column,
strategy=strategy,
strategy=extraction_strategy,
async_mode=async_mode,
entity_types=entity_types,
to="entities",
@ -51,14 +51,14 @@ async def create_base_extracted_entities(
)
if raw_entity_snapshot_enabled:
await snapshot_df(
await snapshot(
entity_graph,
name="raw_extracted_entities",
storage=storage,
formats=["json"],
)
merged_graph = merge_graphs_df(
merged_graph = merge_graphs(
entity_graph,
callbacks,
column="entity_graph",
@ -68,7 +68,7 @@ async def create_base_extracted_entities(
)
if graphml_snapshot_enabled:
await snapshot_rows_df(
await snapshot_rows(
merged_graph,
base_name="merged_graph",
column="entity_graph",

View File

@ -3,14 +3,19 @@
"""All the steps to transform base text_units."""
from dataclasses import dataclass
from typing import Any, cast
import pandas as pd
from datashaper import VerbCallbacks
from datashaper import (
FieldAggregateOperation,
Progress,
VerbCallbacks,
aggregate_operation_mapping,
)
from graphrag.index.verbs.genid import genid_df
from graphrag.index.verbs.overrides.aggregate import aggregate_df
from graphrag.index.verbs.text.chunk.text_chunk import chunk_df
from graphrag.index.operations.chunk_text import chunk_text
from graphrag.index.utils import gen_md5_hash
def create_base_text_units(
@ -19,7 +24,7 @@ def create_base_text_units(
chunk_column_name: str,
n_tokens_column_name: str,
chunk_by_columns: list[str],
strategy: dict[str, Any] | None = None,
chunk_strategy: dict[str, Any] | None = None,
) -> pd.DataFrame:
"""All the steps to transform base text_units."""
sort = documents.sort_values(by=["id"], ascending=[True])
@ -28,7 +33,9 @@ def create_base_text_units(
zip(*[sort[col] for col in ["id", "text"]], strict=True)
)
aggregated = aggregate_df(
callbacks.progress(Progress(percent=0))
aggregated = _aggregate_df(
sort,
groupby=[*chunk_by_columns] if len(chunk_by_columns) > 0 else None,
aggregations=[
@ -40,12 +47,14 @@ def create_base_text_units(
],
)
chunked = chunk_df(
callbacks.progress(Progress(percent=1))
chunked = chunk_text(
aggregated,
column="texts",
to="chunks",
callbacks=callbacks,
strategy=strategy,
strategy=chunk_strategy,
)
chunked = cast(pd.DataFrame, chunked[[*chunk_by_columns, "chunks"]])
@ -56,11 +65,9 @@ def create_base_text_units(
},
inplace=True,
)
chunked = genid_df(
chunked, to="chunk_id", method="md5_hash", hash=[chunk_column_name]
chunked["chunk_id"] = chunked.apply(
lambda row: gen_md5_hash(row, [chunk_column_name]), axis=1
)
chunked[["document_ids", chunk_column_name, n_tokens_column_name]] = pd.DataFrame(
chunked[chunk_column_name].tolist(), index=chunked.index
)
@ -69,3 +76,57 @@ def create_base_text_units(
return cast(
pd.DataFrame, chunked[chunked[chunk_column_name].notna()].reset_index(drop=True)
)
# TODO: would be nice to inline this completely in the main method with pandas
def _aggregate_df(
input: pd.DataFrame,
aggregations: list[dict[str, Any]],
groupby: list[str] | None = None,
) -> pd.DataFrame:
"""Aggregate method definition."""
aggregations_to_apply = _load_aggregations(aggregations)
df_aggregations = {
agg.column: _get_pandas_agg_operation(agg)
for agg in aggregations_to_apply.values()
}
if groupby is None:
output_grouped = input.groupby(lambda _x: True)
else:
output_grouped = input.groupby(groupby, sort=False)
output = cast(pd.DataFrame, output_grouped.agg(df_aggregations))
output.rename(
columns={agg.column: agg.to for agg in aggregations_to_apply.values()},
inplace=True,
)
output.columns = [agg.to for agg in aggregations_to_apply.values()]
return output.reset_index()
@dataclass
class Aggregation:
"""Aggregation class method definition."""
column: str | None
operation: str
to: str
# Only useful for the concat operation
separator: str | None = None
def _get_pandas_agg_operation(agg: Aggregation) -> Any:
if agg.operation == "string_concat":
return (agg.separator or ",").join
return aggregate_operation_mapping[FieldAggregateOperation(agg.operation)]
def _load_aggregations(
aggregations: list[dict[str, Any]],
) -> dict[str, Aggregation]:
return {
aggregation["column"]: Aggregation(
aggregation["column"], aggregation["operation"], aggregation["to"]
)
for aggregation in aggregations
}

View File

@ -8,7 +8,7 @@ from datashaper import (
VerbCallbacks,
)
from graphrag.index.verbs.graph.unpack import unpack_graph_df
from graphrag.index.operations.unpack_graph import unpack_graph
def create_final_communities(
@ -16,8 +16,8 @@ def create_final_communities(
callbacks: VerbCallbacks,
) -> pd.DataFrame:
"""All the steps to transform final communities."""
graph_nodes = unpack_graph_df(entity_graph, callbacks, "clustered_graph", "nodes")
graph_edges = unpack_graph_df(entity_graph, callbacks, "clustered_graph", "edges")
graph_nodes = unpack_graph(entity_graph, callbacks, "clustered_graph", "nodes")
graph_edges = unpack_graph(entity_graph, callbacks, "clustered_graph", "edges")
# Merge graph_nodes with graph_edges for both source and target matches
source_clusters = graph_nodes.merge(

View File

@ -31,15 +31,11 @@ from graphrag.index.graph.extractors.community_reports.schemas import (
NODE_ID,
NODE_NAME,
)
from graphrag.index.operations.embed_text.embed_text import embed_text
from graphrag.index.verbs.graph.report.create_community_reports import (
create_community_reports_df,
)
from graphrag.index.verbs.graph.report.prepare_community_reports import (
prepare_community_reports_df,
)
from graphrag.index.verbs.graph.report.restore_community_hierarchy import (
restore_community_hierarchy_df,
from graphrag.index.operations.embed_text import embed_text
from graphrag.index.operations.summarize_communities import (
prepare_community_reports,
restore_community_hierarchy,
summarize_communities,
)
@ -49,7 +45,7 @@ async def create_final_community_reports(
claims_input: pd.DataFrame | None,
callbacks: VerbCallbacks,
cache: PipelineCache,
strategy: dict,
summarization_strategy: dict,
async_mode: AsyncType = AsyncType.AsyncIO,
num_threads: int = 4,
full_content_text_embed: dict | None = None,
@ -64,19 +60,23 @@ async def create_final_community_reports(
if claims_input is not None:
claims = _prep_claims(claims_input)
community_hierarchy = restore_community_hierarchy_df(nodes)
community_hierarchy = restore_community_hierarchy(nodes)
local_contexts = prepare_community_reports_df(
nodes, edges, claims, callbacks, strategy.get("max_input_length", 16_000)
local_contexts = prepare_community_reports(
nodes,
edges,
claims,
callbacks,
summarization_strategy.get("max_input_length", 16_000),
)
community_reports = await create_community_reports_df(
community_reports = await summarize_communities(
local_contexts,
nodes,
community_hierarchy,
callbacks,
cache,
strategy,
summarization_strategy,
async_mode=async_mode,
num_threads=num_threads,
)

View File

@ -13,30 +13,30 @@ from datashaper import (
)
from graphrag.index.cache import PipelineCache
from graphrag.index.verbs.covariates.extract_covariates.extract_covariates import (
extract_covariates_df,
from graphrag.index.operations.extract_covariates import (
extract_covariates,
)
async def create_final_covariates(
text_units: pd.DataFrame,
cache: PipelineCache,
callbacks: VerbCallbacks,
cache: PipelineCache,
column: str,
covariate_type: str,
strategy: dict[str, Any] | None,
extraction_strategy: dict[str, Any] | None,
async_mode: AsyncType = AsyncType.AsyncIO,
entity_types: list[str] | None = None,
num_threads: int = 4,
) -> pd.DataFrame:
"""All the steps to extract and format covariates."""
covariates = await extract_covariates_df(
covariates = await extract_covariates(
text_units,
cache,
callbacks,
cache,
column,
covariate_type,
strategy,
extraction_strategy,
async_mode,
entity_types,
num_threads,

View File

@ -9,7 +9,7 @@ from datashaper import (
)
from graphrag.index.cache import PipelineCache
from graphrag.index.operations.embed_text.embed_text import embed_text
from graphrag.index.operations.embed_text import embed_text
async def create_final_documents(

View File

@ -9,22 +9,22 @@ from datashaper import (
)
from graphrag.index.cache import PipelineCache
from graphrag.index.operations.embed_text.embed_text import embed_text
from graphrag.index.verbs.graph.unpack import unpack_graph_df
from graphrag.index.verbs.text.split import text_split_df
from graphrag.index.operations.embed_text import embed_text
from graphrag.index.operations.split_text import split_text
from graphrag.index.operations.unpack_graph import unpack_graph
async def create_final_entities(
entity_graph: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
name_text_embed: dict,
description_text_embed: dict,
name_text_embed: dict | None = None,
description_text_embed: dict | None = None,
) -> pd.DataFrame:
"""All the steps to transform final entities."""
# Process nodes
nodes = (
unpack_graph_df(entity_graph, callbacks, "clustered_graph", "nodes")
unpack_graph(entity_graph, callbacks, "clustered_graph", "nodes")
.rename(columns={"label": "name"})
.loc[
:,
@ -44,7 +44,7 @@ async def create_final_entities(
nodes = nodes.loc[nodes["name"].notna()]
# Split 'source_id' column into 'text_unit_ids'
nodes = text_split_df(
nodes = split_text(
nodes, column="source_id", separator=",", to="text_unit_ids"
).drop(columns=["source_id"])

View File

@ -10,27 +10,27 @@ from datashaper import (
VerbCallbacks,
)
from graphrag.index.operations.layout_graph import layout_graph
from graphrag.index.operations.snapshot import snapshot
from graphrag.index.operations.unpack_graph import unpack_graph
from graphrag.index.storage import PipelineStorage
from graphrag.index.verbs.graph.layout.layout_graph import layout_graph_df
from graphrag.index.verbs.graph.unpack import unpack_graph_df
from graphrag.index.verbs.snapshot import snapshot_df
async def create_final_nodes(
entity_graph: pd.DataFrame,
callbacks: VerbCallbacks,
storage: PipelineStorage,
strategy: dict[str, Any],
layout_strategy: dict[str, Any],
level_for_node_positions: int,
snapshot_top_level_nodes: bool = False,
) -> pd.DataFrame:
"""All the steps to transform final nodes."""
laid_out_entity_graph = cast(
pd.DataFrame,
layout_graph_df(
layout_graph(
entity_graph,
callbacks,
strategy,
layout_strategy,
embeddings_column="embeddings",
graph_column="clustered_graph",
to="node_positions",
@ -40,7 +40,7 @@ async def create_final_nodes(
nodes = cast(
pd.DataFrame,
unpack_graph_df(
unpack_graph(
laid_out_entity_graph, callbacks, column="positioned_graph", type="nodes"
),
)
@ -51,7 +51,7 @@ async def create_final_nodes(
nodes = cast(pd.DataFrame, nodes[["id", "x", "y"]])
if snapshot_top_level_nodes:
await snapshot_df(
await snapshot(
nodes,
name="top_level_nodes",
storage=storage,

View File

@ -11,11 +11,11 @@ from datashaper import (
)
from graphrag.index.cache import PipelineCache
from graphrag.index.operations.embed_text.embed_text import embed_text
from graphrag.index.verbs.graph.compute_edge_combined_degree import (
compute_edge_combined_degree_df,
from graphrag.index.operations.compute_edge_combined_degree import (
compute_edge_combined_degree,
)
from graphrag.index.verbs.graph.unpack import unpack_graph_df
from graphrag.index.operations.embed_text import embed_text
from graphrag.index.operations.unpack_graph import unpack_graph
async def create_final_relationships(
@ -26,7 +26,7 @@ async def create_final_relationships(
description_text_embed: dict | None = None,
) -> pd.DataFrame:
"""All the steps to transform final relationships."""
graph_edges = unpack_graph_df(entity_graph, callbacks, "clustered_graph", "edges")
graph_edges = unpack_graph(entity_graph, callbacks, "clustered_graph", "edges")
graph_edges.rename(columns={"source_id": "text_unit_ids"}, inplace=True)
@ -49,7 +49,7 @@ async def create_final_relationships(
filtered_nodes = nodes[nodes["level"] == 0].reset_index(drop=True)
filtered_nodes = cast(pd.DataFrame, filtered_nodes[["title", "degree"]])
edge_combined_degree = compute_edge_combined_degree_df(
edge_combined_degree = compute_edge_combined_degree(
pruned_edges,
filtered_nodes,
to="rank",

View File

@ -11,7 +11,7 @@ from datashaper import (
)
from graphrag.index.cache import PipelineCache
from graphrag.index.operations.embed_text.embed_text import embed_text
from graphrag.index.operations.embed_text import embed_text
async def create_final_text_units(

View File

@ -11,35 +11,35 @@ from datashaper import (
)
from graphrag.index.cache import PipelineCache
from graphrag.index.storage import PipelineStorage
from graphrag.index.verbs.entities.summarize.description_summarize import (
summarize_descriptions_df,
from graphrag.index.operations.snapshot_rows import snapshot_rows
from graphrag.index.operations.summarize_descriptions import (
summarize_descriptions,
)
from graphrag.index.verbs.snapshot_rows import snapshot_rows_df
from graphrag.index.storage import PipelineStorage
async def create_summarized_entities(
entities: pd.DataFrame,
cache: PipelineCache,
callbacks: VerbCallbacks,
cache: PipelineCache,
storage: PipelineStorage,
strategy: dict[str, Any] | None = None,
summarization_strategy: dict[str, Any] | None = None,
num_threads: int = 4,
graphml_snapshot_enabled: bool = False,
) -> pd.DataFrame:
"""All the steps to summarize entities."""
summarized = await summarize_descriptions_df(
summarized = await summarize_descriptions(
entities,
cache,
callbacks,
cache,
column="entity_graph",
to="entity_graph",
strategy=strategy,
strategy=summarization_strategy,
num_threads=num_threads,
)
if graphml_snapshot_enabled:
await snapshot_rows_df(
await snapshot_rows(
summarized,
column="entity_graph",
base_name="summarized_graph",

View File

@ -0,0 +1,8 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The Indexing Engine text chunk package root."""
from .chunk_text import ChunkStrategy, ChunkStrategyType, chunk_text
__all__ = ["ChunkStrategy", "ChunkStrategyType", "chunk_text"]

View File

@ -3,59 +3,30 @@
"""A module containing _get_num_total, chunk, run_strategy and load_strategy methods definitions."""
from enum import Enum
from typing import Any, cast
import pandas as pd
from datashaper import (
ProgressTicker,
TableContainer,
VerbCallbacks,
VerbInput,
progress_ticker,
verb,
)
from .strategies.typing import ChunkStrategy as ChunkStrategy
from .typing import ChunkInput
from .typing import ChunkInput, ChunkStrategy, ChunkStrategyType
def _get_num_total(output: pd.DataFrame, column: str) -> int:
num_total = 0
for row in output[column]:
if isinstance(row, str):
num_total += 1
else:
num_total += len(row)
return num_total
class ChunkStrategyType(str, Enum):
"""ChunkStrategy class definition."""
tokens = "tokens"
sentence = "sentence"
def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'
@verb(name="chunk")
def chunk(
input: VerbInput,
def chunk_text(
input: pd.DataFrame,
column: str,
to: str,
callbacks: VerbCallbacks,
strategy: dict[str, Any] | None = None,
**_kwargs,
) -> TableContainer:
) -> pd.DataFrame:
"""
Chunk a piece of text into smaller pieces.
## Usage
```yaml
verb: text_chunk
args:
column: <column name> # The name of the column containing the text to chunk; this can be either a column with text or a column with a list[tuple[doc_id, str]]
to: <column name> # The name of the column to output the chunks to
@ -85,21 +56,6 @@ def chunk(
type: sentence
```
"""
input_table = cast(pd.DataFrame, input.get_input())
output = chunk_df(input_table, column, to, callbacks, strategy)
return TableContainer(table=output)
def chunk_df(
input: pd.DataFrame,
column: str,
to: str,
callbacks: VerbCallbacks,
strategy: dict[str, Any] | None = None,
) -> pd.DataFrame:
"""Chunk a piece of text into smaller pieces."""
output = input
if strategy is None:
strategy = {}
@ -161,17 +117,27 @@ def load_strategy(strategy: ChunkStrategyType) -> ChunkStrategy:
"""Load strategy method definition."""
match strategy:
case ChunkStrategyType.tokens:
from .strategies.tokens import run as run_tokens
from .strategies import run_tokens
return run_tokens
case ChunkStrategyType.sentence:
# NLTK
from graphrag.index.bootstrap import bootstrap
from .strategies.sentence import run as run_sentence
from .strategies import run_sentences
bootstrap()
return run_sentence
return run_sentences
case _:
msg = f"Unknown strategy: {strategy}"
raise ValueError(msg)
def _get_num_total(output: pd.DataFrame, column: str) -> int:
num_total = 0
for row in output[column]:
if isinstance(row, str):
num_total += 1
else:
num_total += len(row)
return num_total

View File

@ -1,23 +1,25 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing run and split_text_on_tokens methods definition."""
"""A module containing chunk strategies."""
from collections.abc import Iterable
from typing import Any
import nltk
import tiktoken
from datashaper import ProgressTicker
import graphrag.config.defaults as defs
from graphrag.index.text_splitting import Tokenizer
from graphrag.index.verbs.text.chunk.typing import TextChunk
from .typing import TextChunk
def run(
def run_tokens(
input: list[str], args: dict[str, Any], tick: ProgressTicker
) -> Iterable[TextChunk]:
"""Chunks text into multiple parts. A pipeline verb."""
"""Chunks text into chunks based on encoding tokens."""
tokens_per_chunk = args.get("chunk_size", defs.CHUNK_SIZE)
chunk_overlap = args.get("chunk_overlap", defs.CHUNK_OVERLAP)
encoding_name = args.get("encoding_name", defs.ENCODING_MODEL)
@ -31,7 +33,7 @@ def run(
def decode(tokens: list[int]) -> str:
return enc.decode(tokens)
return split_text_on_tokens(
return _split_text_on_tokens(
input,
Tokenizer(
chunk_overlap=chunk_overlap,
@ -45,7 +47,7 @@ def run(
# Adapted from - https://github.com/langchain-ai/langchain/blob/77b359edf5df0d37ef0d539f678cf64f5557cb54/libs/langchain/langchain/text_splitter.py#L471
# So we could have better control over the chunking process
def split_text_on_tokens(
def _split_text_on_tokens(
texts: list[str], enc: Tokenizer, tick: ProgressTicker
) -> list[TextChunk]:
"""Split incoming text and return chunks."""
@ -79,3 +81,17 @@ def split_text_on_tokens(
chunk_ids = input_ids[start_idx:cur_idx]
return result
def run_sentences(
input: list[str], _args: dict[str, Any], tick: ProgressTicker
) -> Iterable[TextChunk]:
"""Chunks text into multiple parts by sentence."""
for doc_idx, text in enumerate(input):
sentences = nltk.sent_tokenize(text)
for sentence in sentences:
yield TextChunk(
text_chunk=sentence,
source_doc_indices=[doc_idx],
)
tick(1)

View File

@ -3,7 +3,12 @@
"""A module containing 'TextChunk' model."""
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from enum import Enum
from typing import Any
from datashaper import ProgressTicker
@dataclass
@ -17,3 +22,18 @@ class TextChunk:
ChunkInput = str | list[str] | list[tuple[str, str]]
"""Input to a chunking strategy. Can be a string, a list of strings, or a list of tuples of (id, text)."""
ChunkStrategy = Callable[
[list[str], dict[str, Any], ProgressTicker], Iterable[TextChunk]
]
class ChunkStrategyType(str, Enum):
"""ChunkStrategy class definition."""
tokens = "tokens"
sentence = "sentence"
def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'

View File

@ -10,65 +10,29 @@ from typing import Any, cast
import networkx as nx
import pandas as pd
from datashaper import TableContainer, VerbCallbacks, VerbInput, progress_iterable, verb
from datashaper import VerbCallbacks, progress_iterable
from graspologic.partition import hierarchical_leiden
from graphrag.index.graph.utils import stable_largest_connected_component
from graphrag.index.utils import gen_uuid, load_graph
from .typing import Communities
Communities = list[tuple[int, str, list[str]]]
class GraphCommunityStrategyType(str, Enum):
"""GraphCommunityStrategyType class definition."""
leiden = "leiden"
def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'
log = logging.getLogger(__name__)
@verb(name="cluster_graph")
def cluster_graph(
input: VerbInput,
callbacks: VerbCallbacks,
strategy: dict[str, Any],
column: str,
to: str,
level_to: str | None = None,
**_kwargs,
) -> TableContainer:
"""
Apply a hierarchical clustering algorithm to a graph. The graph is expected to be in graphml format. The verb outputs a new column containing the clustered graph, and a new column containing the level of the graph.
## Usage
```yaml
verb: cluster_graph
args:
column: entity_graph # The name of the column containing the graph, should be a graphml graph
to: clustered_graph # The name of the column to output the clustered graph to
level_to: level # The name of the column to output the level to
strategy: <strategy config> # See strategies section below
```
## Strategies
The cluster graph verb uses a strategy to cluster the graph. The strategy is a json object which defines the strategy to use. The following strategies are available:
### leiden
This strategy uses the leiden algorithm to cluster a graph. The strategy config is as follows:
```yaml
strategy:
type: leiden
max_cluster_size: 10 # Optional, The max cluster size to use, default: 10
use_lcc: true # Optional, if the largest connected component should be used with the leiden algorithm, default: true
seed: 0xDEADBEEF # Optional, the seed to use for the leiden algorithm, default: 0xDEADBEEF
levels: [0, 1] # Optional, the levels to output, default: all the levels detected
```
"""
output_df = cluster_graph_df(
cast(pd.DataFrame, input.get_input()),
callbacks,
strategy,
column,
to,
level_to=level_to,
)
return TableContainer(table=output_df)
def cluster_graph_df(
input: pd.DataFrame,
callbacks: VerbCallbacks,
strategy: dict[str, Any],
@ -157,16 +121,6 @@ def apply_clustering(
return graph
class GraphCommunityStrategyType(str, Enum):
"""GraphCommunityStrategyType class definition."""
leiden = "leiden"
def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'
def run_layout(
strategy: dict[str, Any], graphml_or_graph: str | nx.Graph
) -> Communities:
@ -180,8 +134,6 @@ def run_layout(
strategy_type = strategy.get("type", GraphCommunityStrategyType.leiden)
match strategy_type:
case GraphCommunityStrategyType.leiden:
from .strategies.leiden import run as run_leiden
clusters = run_leiden(graph, strategy)
case _:
msg = f"Unknown clustering strategy {strategy_type}"
@ -192,3 +144,60 @@ def run_layout(
for cluster_id, nodes in clusters[level].items():
results.append((level, cluster_id, nodes))
return results
def run_leiden(
graph: nx.Graph, args: dict[str, Any]
) -> dict[int, dict[str, list[str]]]:
"""Run method definition."""
max_cluster_size = args.get("max_cluster_size", 10)
use_lcc = args.get("use_lcc", True)
if args.get("verbose", False):
log.info(
"Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc
)
node_id_to_community_map = _compute_leiden_communities(
graph=graph,
max_cluster_size=max_cluster_size,
use_lcc=use_lcc,
seed=args.get("seed", 0xDEADBEEF),
)
levels = args.get("levels")
# If they don't pass in levels, use them all
if levels is None:
levels = sorted(node_id_to_community_map.keys())
results_by_level: dict[int, dict[str, list[str]]] = {}
for level in levels:
result = {}
results_by_level[level] = result
for node_id, raw_community_id in node_id_to_community_map[level].items():
community_id = str(raw_community_id)
if community_id not in result:
result[community_id] = []
result[community_id].append(node_id)
return results_by_level
# Taken from graph_intelligence & adapted
def _compute_leiden_communities(
graph: nx.Graph | nx.DiGraph,
max_cluster_size: int,
use_lcc: bool,
seed=0xDEADBEEF,
) -> dict[int, dict[str, int]]:
"""Return Leiden root communities."""
if use_lcc:
graph = stable_largest_connected_component(graph)
community_mapping = hierarchical_leiden(
graph, max_cluster_size=max_cluster_size, random_seed=seed
)
results: dict[int, dict[str, int]] = {}
for partition in community_mapping:
results[partition.level] = results.get(partition.level, {})
results[partition.level][partition.node] = partition.cluster
return results

View File

@ -0,0 +1,44 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing compute_edge_combined_degree methods definition."""
import pandas as pd
def compute_edge_combined_degree(
edge_df: pd.DataFrame,
node_degree_df: pd.DataFrame,
to: str,
node_name_column: str,
node_degree_column: str,
edge_source_column: str,
edge_target_column: str,
) -> pd.DataFrame:
"""Compute the combined degree for each edge in a graph."""
if to in edge_df.columns:
return edge_df
def join_to_degree(df: pd.DataFrame, column: str) -> pd.DataFrame:
degree_column = _degree_colname(column)
result = df.merge(
node_degree_df.rename(
columns={node_name_column: column, node_degree_column: degree_column}
),
on=column,
how="left",
)
result[degree_column] = result[degree_column].fillna(0)
return result
output_df = join_to_degree(edge_df, edge_source_column)
output_df = join_to_degree(output_df, edge_target_column)
output_df[to] = (
output_df[_degree_colname(edge_source_column)]
+ output_df[_degree_colname(edge_target_column)]
)
return output_df
def _degree_colname(column: str) -> str:
return f"{column}_degree"

View File

@ -4,5 +4,6 @@
"""The Indexing Engine graph embed package root."""
from .embed_graph import EmbedGraphStrategyType, embed_graph
from .typing import NodeEmbeddings
__all__ = ["EmbedGraphStrategyType", "embed_graph"]
__all__ = ["EmbedGraphStrategyType", "NodeEmbeddings", "embed_graph"]

View File

@ -10,6 +10,8 @@ import networkx as nx
import pandas as pd
from datashaper import VerbCallbacks, derive_from_rows
from graphrag.index.graph.embedding import embed_nod2vec
from graphrag.index.graph.utils import stable_largest_connected_component
from graphrag.index.utils import load_graph
from .typing import NodeEmbeddings
@ -85,9 +87,29 @@ def run_embeddings(
graph = load_graph(graphml_or_graph)
match strategy:
case EmbedGraphStrategyType.node2vec:
from .strategies.node_2_vec import run as run_node_2_vec
return run_node_2_vec(graph, args)
case _:
msg = f"Unknown strategy {strategy}"
raise ValueError(msg)
def run_node_2_vec(graph: nx.Graph, args: dict[str, Any]) -> NodeEmbeddings:
"""Run method definition."""
if args.get("use_lcc", True):
graph = stable_largest_connected_component(graph)
# create graph embedding using node2vec
embeddings = embed_nod2vec(
graph=graph,
dimensions=args.get("dimensions", 1536),
num_walks=args.get("num_walks", 10),
walk_length=args.get("walk_length", 40),
window_size=args.get("window_size", 2),
iterations=args.get("iterations", 3),
random_seed=args.get("random_seed", 86),
)
pairs = zip(embeddings.nodes, embeddings.embeddings.tolist(), strict=True)
sorted_pairs = sorted(pairs, key=lambda x: x[0])
return dict(sorted_pairs)

View File

@ -1,4 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Text Embedding strategies."""

View File

@ -1,34 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing run method definition."""
from typing import Any
import networkx as nx
from graphrag.index.graph.embedding import embed_nod2vec
from graphrag.index.graph.utils import stable_largest_connected_component
from graphrag.index.operations.embed_graph.typing import NodeEmbeddings
def run(graph: nx.Graph, args: dict[str, Any]) -> NodeEmbeddings:
"""Run method definition."""
if args.get("use_lcc", True):
graph = stable_largest_connected_component(graph)
# create graph embedding using node2vec
embeddings = embed_nod2vec(
graph=graph,
dimensions=args.get("dimensions", 1536),
num_walks=args.get("num_walks", 10),
walk_length=args.get("walk_length", 40),
window_size=args.get("window_size", 2),
iterations=args.get("iterations", 3),
random_seed=args.get("random_seed", 86),
)
pairs = zip(embeddings.nodes, embeddings.embeddings.tolist(), strict=True)
sorted_pairs = sorted(pairs, key=lambda x: x[0])
return dict(sorted_pairs)

View File

@ -5,70 +5,29 @@
import logging
from dataclasses import asdict
from enum import Enum
from typing import Any, cast
from typing import Any
import pandas as pd
from datashaper import (
AsyncType,
TableContainer,
VerbCallbacks,
VerbInput,
derive_from_rows,
verb,
)
from graphrag.index.cache import PipelineCache
from graphrag.index.verbs.covariates.typing import Covariate, CovariateExtractStrategy
from .typing import Covariate, CovariateExtractStrategy, ExtractClaimsStrategyType
log = logging.getLogger(__name__)
class ExtractClaimsStrategyType(str, Enum):
"""ExtractClaimsStrategyType class definition."""
graph_intelligence = "graph_intelligence"
def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'
DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]
@verb(name="extract_covariates")
async def extract_covariates(
input: VerbInput,
cache: PipelineCache,
callbacks: VerbCallbacks,
column: str,
covariate_type: str,
strategy: dict[str, Any] | None,
async_mode: AsyncType = AsyncType.AsyncIO,
entity_types: list[str] | None = None,
**kwargs,
) -> TableContainer:
"""Extract claims from a piece of text."""
source = cast(pd.DataFrame, input.get_input())
output = await extract_covariates_df(
source,
cache,
callbacks,
column,
covariate_type,
strategy,
async_mode,
entity_types,
**kwargs,
)
return TableContainer(table=output)
async def extract_covariates_df(
input: pd.DataFrame,
cache: PipelineCache,
callbacks: VerbCallbacks,
cache: PipelineCache,
column: str,
covariate_type: str,
strategy: dict[str, Any] | None,
@ -113,9 +72,9 @@ def load_strategy(strategy_type: ExtractClaimsStrategyType) -> CovariateExtractS
"""Load strategy method definition."""
match strategy_type:
case ExtractClaimsStrategyType.graph_intelligence:
from .strategies.graph_intelligence import run as run_gi
from .strategies import run_graph_intelligence
return run_gi
return run_graph_intelligence
case _:
msg = f"Unknown strategy: {strategy_type}"
raise ValueError(msg)

View File

@ -9,35 +9,31 @@ from typing import Any
from datashaper import VerbCallbacks
import graphrag.config.defaults as defs
from graphrag.config.enums import LLMType
from graphrag.index.cache import PipelineCache
from graphrag.index.graph.extractors.claims import ClaimExtractor
from graphrag.index.llm import load_llm
from graphrag.index.verbs.covariates.typing import (
from graphrag.llm import CompletionLLM
from .typing import (
Covariate,
CovariateExtractionResult,
)
from graphrag.llm import CompletionLLM
from .defaults import MOCK_LLM_RESPONSES
async def run(
async def run_graph_intelligence(
input: str | Iterable[str],
entity_types: list[str],
resolved_entities_map: dict[str, str],
reporter: VerbCallbacks,
pipeline_cache: PipelineCache,
callbacks: VerbCallbacks,
cache: PipelineCache,
strategy_config: dict[str, Any],
) -> CovariateExtractionResult:
"""Run the Claim extraction chain."""
llm_config = strategy_config.get(
"llm", {"type": LLMType.StaticResponse, "responses": MOCK_LLM_RESPONSES}
)
llm_type = llm_config.get("type", LLMType.StaticResponse)
llm = load_llm("claim_extraction", llm_type, reporter, pipeline_cache, llm_config)
llm_config = strategy_config.get("llm", {})
llm_type = llm_config.get("type")
llm = load_llm("claim_extraction", llm_type, callbacks, cache, llm_config)
return await _execute(
llm, input, entity_types, resolved_entities_map, reporter, strategy_config
llm, input, entity_types, resolved_entities_map, callbacks, strategy_config
)
@ -46,7 +42,7 @@ async def _execute(
texts: Iterable[str],
entity_types: list[str],
resolved_entities_map: dict[str, str],
reporter: VerbCallbacks,
callbacks: VerbCallbacks,
strategy_config: dict[str, Any],
) -> CovariateExtractionResult:
extraction_prompt = strategy_config.get("extraction_prompt")
@ -62,7 +58,7 @@ async def _execute(
max_gleanings=max_gleanings,
encoding_model=encoding_model,
on_error=lambda e, s, d: (
reporter.error("Claim Extraction Error", e, s, d) if reporter else None
callbacks.error("Claim Extraction Error", e, s, d) if callbacks else None
),
)

View File

@ -5,6 +5,7 @@
from collections.abc import Awaitable, Callable, Iterable
from dataclasses import dataclass
from enum import Enum
from typing import Any
from datashaper import VerbCallbacks
@ -48,3 +49,13 @@ CovariateExtractStrategy = Callable[
],
Awaitable[CovariateExtractionResult],
]
class ExtractClaimsStrategyType(str, Enum):
"""ExtractClaimsStrategyType class definition."""
graph_intelligence = "graph_intelligence"
def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'

View File

@ -3,6 +3,6 @@
"""The Indexing Engine entities extraction package root."""
from .entity_extract import ExtractEntityStrategyType, entity_extract
from .extract_entities import ExtractEntityStrategyType, extract_entities
__all__ = ["ExtractEntityStrategyType", "entity_extract"]
__all__ = ["ExtractEntityStrategyType", "extract_entities"]

View File

@ -5,16 +5,13 @@
import logging
from enum import Enum
from typing import Any, cast
from typing import Any
import pandas as pd
from datashaper import (
AsyncType,
TableContainer,
VerbCallbacks,
VerbInput,
derive_from_rows,
verb,
)
from graphrag.index.bootstrap import bootstrap
@ -40,43 +37,10 @@ class ExtractEntityStrategyType(str, Enum):
DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]
@verb(name="entity_extract")
async def entity_extract(
input: VerbInput,
cache: PipelineCache,
callbacks: VerbCallbacks,
column: str,
id_column: str,
to: str,
strategy: dict[str, Any] | None,
graph_to: str | None = None,
async_mode: AsyncType = AsyncType.AsyncIO,
entity_types=DEFAULT_ENTITY_TYPES,
**kwargs,
) -> TableContainer:
"""Extract entities from a piece of text."""
source = cast(pd.DataFrame, input.get_input())
output = await entity_extract_df(
source,
cache,
callbacks,
column,
id_column,
to,
strategy,
graph_to,
async_mode,
entity_types,
**kwargs,
)
return TableContainer(table=output)
async def entity_extract_df(
async def extract_entities(
input: pd.DataFrame,
cache: PipelineCache,
callbacks: VerbCallbacks,
cache: PipelineCache,
column: str,
id_column: str,
to: str,
@ -90,24 +54,7 @@ async def entity_extract_df(
Extract entities from a piece of text.
## Usage
### json
```json
{
"verb": "entity_extract",
"args": {
"column": "the_document_text_column_to_extract_entities_from", /* In general this will be your document text column */
"id_column": "the_column_with_the_unique_id_for_each_row", /* In general this will be your document id */
"to": "the_column_to_output_the_entities_to", /* This will be a list[dict[str, Any]] a list of entities, with a name, and additional attributes */
"graph_to": "the_column_to_output_the_graphml_to", /* Optional: This will be a graphml graph in string form which represents the entities and their relationships */
"strategy": {...} <strategy_config>, see strategies section below
"entity_types": ["list", "of", "entity", "types", "to", "extract"] /* Optional: This will limit the entity types extracted, default: ["organization", "person", "geo", "event"] */
"summarize_descriptions" : true | false /* Optional: This will summarize the descriptions of the entities and relationships, default: true */
}
}
```
### yaml
```yaml
verb: entity_extract
args:
column: the_document_text_column_to_extract_entities_from
id_column: the_column_with_the_unique_id_for_each_row
@ -218,9 +165,9 @@ def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStr
"""Load strategy method definition."""
match strategy_type:
case ExtractEntityStrategyType.graph_intelligence:
from .strategies.graph_intelligence import run_gi
from .strategies.graph_intelligence import run_graph_intelligence
return run_gi
return run_graph_intelligence
case ExtractEntityStrategyType.nltk:
bootstrap()

View File

@ -1,51 +1,49 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing run_gi, run_extract_entities and _create_text_splitter methods to run graph intelligence."""
"""A module containing run_graph_intelligence, run_extract_entities and _create_text_splitter methods to run graph intelligence."""
import networkx as nx
from datashaper import VerbCallbacks
import graphrag.config.defaults as defs
from graphrag.config.enums import LLMType
from graphrag.index.cache import PipelineCache
from graphrag.index.graph.extractors.graph import GraphExtractor
from graphrag.index.graph.extractors import GraphExtractor
from graphrag.index.llm import load_llm
from graphrag.index.text_splitting import (
NoopTextSplitter,
TextSplitter,
TokenTextSplitter,
)
from graphrag.index.verbs.entities.extraction.strategies.typing import (
from graphrag.llm import CompletionLLM
from .typing import (
Document,
EntityExtractionResult,
EntityTypes,
StrategyConfig,
)
from graphrag.llm import CompletionLLM
from .defaults import DEFAULT_LLM_CONFIG
async def run_gi(
async def run_graph_intelligence(
docs: list[Document],
entity_types: EntityTypes,
reporter: VerbCallbacks,
pipeline_cache: PipelineCache,
callbacks: VerbCallbacks,
cache: PipelineCache,
args: StrategyConfig,
) -> EntityExtractionResult:
"""Run the graph intelligence entity extraction strategy."""
llm_config = args.get("llm", DEFAULT_LLM_CONFIG)
llm_type = llm_config.get("type", LLMType.StaticResponse)
llm = load_llm("entity_extraction", llm_type, reporter, pipeline_cache, llm_config)
return await run_extract_entities(llm, docs, entity_types, reporter, args)
llm_config = args.get("llm", {})
llm_type = llm_config.get("type")
llm = load_llm("entity_extraction", llm_type, callbacks, cache, llm_config)
return await run_extract_entities(llm, docs, entity_types, callbacks, args)
async def run_extract_entities(
llm: CompletionLLM,
docs: list[Document],
entity_types: EntityTypes,
reporter: VerbCallbacks | None,
callbacks: VerbCallbacks | None,
args: StrategyConfig,
) -> EntityExtractionResult:
"""Run the entity extraction chain."""
@ -76,7 +74,7 @@ async def run_extract_entities(
encoding_model=encoding_model,
max_gleanings=max_gleanings,
on_error=lambda e, s, d: (
reporter.error("Entity Extraction Error", e, s, d) if reporter else None
callbacks.error("Entity Extraction Error", e, s, d) if callbacks else None
),
)
text_list = [doc.text.strip() for doc in docs]

View File

@ -19,8 +19,8 @@ words.ensure_loaded()
async def run( # noqa RUF029 async is required for interface
docs: list[Document],
entity_types: EntityTypes,
reporter: VerbCallbacks, # noqa ARG001
pipeline_cache: PipelineCache, # noqa ARG001
callbacks: VerbCallbacks, # noqa ARG001
cache: PipelineCache, # noqa ARG001
args: StrategyConfig, # noqa ARG001
) -> EntityExtractionResult:
"""Run method definition."""

View File

@ -8,10 +8,10 @@ from typing import Any, cast
import networkx as nx
import pandas as pd
from datashaper import TableContainer, VerbCallbacks, VerbInput, progress_callback, verb
from datashaper import VerbCallbacks, progress_callback
from graphrag.index.graph.visualization import GraphLayout
from graphrag.index.operations.embed_graph.typing import NodeEmbeddings
from graphrag.index.operations.embed_graph import NodeEmbeddings
from graphrag.index.utils import load_graph
@ -26,23 +26,20 @@ class LayoutGraphStrategyType(str, Enum):
return f'"{self.value}"'
@verb(name="layout_graph")
def layout_graph(
input: VerbInput,
input_df: pd.DataFrame,
callbacks: VerbCallbacks,
strategy: dict[str, Any],
embeddings_column: str,
graph_column: str,
to: str,
graph_to: str | None = None,
**_kwargs: dict,
) -> TableContainer:
):
"""
Apply a layout algorithm to a graph. The graph is expected to be in graphml format. The verb outputs a new column containing the laid out graph.
## Usage
```yaml
verb: layout_graph
args:
graph_column: clustered_graph # The name of the column containing the graph, should be a graphml graph
embeddings_column: embeddings # The name of the column containing the embeddings
@ -63,24 +60,6 @@ def layout_graph(
min_dist: 0.75 # Optional, The min distance to use for the umap algorithm, default: 0.75
```
"""
input_df = cast(pd.DataFrame, input.get_input())
output_df = layout_graph_df(
input_df, callbacks, strategy, embeddings_column, graph_column, to, graph_to
)
return TableContainer(table=output_df)
def layout_graph_df(
input_df: pd.DataFrame,
callbacks: VerbCallbacks,
strategy: dict[str, Any],
embeddings_column: str,
graph_column: str,
to: str,
graph_to: str | None = None,
):
"""Apply a layout algorithm to a graph."""
output_df = input_df
num_items = len(output_df)
strategy_type = strategy.get("type", LayoutGraphStrategyType.umap)
@ -118,7 +97,7 @@ def _run_layout(
graphml_or_graph: str | nx.Graph,
embeddings: NodeEmbeddings,
args: dict[str, Any],
reporter: VerbCallbacks,
callbacks: VerbCallbacks,
) -> GraphLayout:
graph = load_graph(graphml_or_graph)
match strategy:
@ -129,7 +108,7 @@ def _run_layout(
graph,
embeddings,
args,
lambda e, stack, d: reporter.error("Error in Umap", e, stack, d),
lambda e, stack, d: callbacks.error("Error in Umap", e, stack, d),
)
case LayoutGraphStrategyType.zero:
from .methods.zero import run as run_zero
@ -137,7 +116,7 @@ def _run_layout(
return run_zero(
graph,
args,
lambda e, stack, d: reporter.error("Error in Zero", e, stack, d),
lambda e, stack, d: callbacks.error("Error in Zero", e, stack, d),
)
case _:
msg = f"Unknown strategy {strategy}"

View File

@ -15,7 +15,7 @@ from graphrag.index.graph.visualization import (
NodePosition,
compute_umap_positions,
)
from graphrag.index.operations.embed_graph.typing import NodeEmbeddings
from graphrag.index.operations.embed_graph import NodeEmbeddings
from graphrag.index.typing import ErrorHandlerFn
# TODO: This could be handled more elegantly, like what columns to use

View File

@ -1,8 +1,10 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The Indexing Engine graph merge package root."""
"""merge_graphs operation."""
from .merge_graphs import merge_graphs
__all__ = ["merge_graphs"]
__all__ = [
"merge_graphs",
]

View File

@ -7,15 +7,10 @@ from typing import Any, cast
import networkx as nx
import pandas as pd
from datashaper import TableContainer, VerbCallbacks, VerbInput, progress_iterable, verb
from datashaper import VerbCallbacks, progress_iterable
from graphrag.index.utils import load_graph
from .defaults import (
DEFAULT_CONCAT_SEPARATOR,
DEFAULT_EDGE_OPERATIONS,
DEFAULT_NODE_OPERATIONS,
)
from .typing import (
BasicMergeOperation,
DetailedAttributeMergeOperation,
@ -23,25 +18,23 @@ from .typing import (
StringOperation,
)
DEFAULT_NODE_OPERATIONS = {
"*": {
"operation": BasicMergeOperation.Replace,
}
}
DEFAULT_EDGE_OPERATIONS = {
"*": {
"operation": BasicMergeOperation.Replace,
},
"weight": "sum",
}
DEFAULT_CONCAT_SEPARATOR = ","
@verb(name="merge_graphs")
def merge_graphs(
input: VerbInput,
callbacks: VerbCallbacks,
column: str,
to: str,
nodes: dict[str, Any] = DEFAULT_NODE_OPERATIONS,
edges: dict[str, Any] = DEFAULT_EDGE_OPERATIONS,
**_kwargs,
) -> TableContainer:
"""Merge multiple graphs together. The graphs are expected to be in graphml format. The verb outputs a new column containing the merged graph."""
input_df = cast(pd.DataFrame, input.get_input())
output = merge_graphs_df(input_df, callbacks, column, to, nodes, edges)
return TableContainer(table=output)
def merge_graphs_df(
input: pd.DataFrame,
callbacks: VerbCallbacks,
column: str,

View File

@ -3,31 +3,12 @@
"""A module containing snapshot method definition."""
from typing import cast
import pandas as pd
from datashaper import TableContainer, VerbInput, verb
from graphrag.index.storage import PipelineStorage
@verb(name="snapshot")
async def snapshot(
input: VerbInput,
name: str,
formats: list[str],
storage: PipelineStorage,
**_kwargs: dict,
) -> TableContainer:
"""Take a entire snapshot of the tabular data."""
source = cast(pd.DataFrame, input.get_input())
await snapshot_df(source, name, formats, storage)
return TableContainer(table=source)
async def snapshot_df(
input: pd.DataFrame,
name: str,
formats: list[str],

View File

@ -5,10 +5,9 @@
import json
from dataclasses import dataclass
from typing import Any, cast
from typing import Any
import pandas as pd
from datashaper import TableContainer, VerbInput, verb
from graphrag.index.storage import PipelineStorage
@ -21,30 +20,7 @@ class FormatSpecifier:
extension: str
@verb(name="snapshot_rows")
async def snapshot_rows(
input: VerbInput,
column: str | None,
base_name: str,
storage: PipelineStorage,
formats: list[str | dict[str, Any]],
row_name_column: str | None = None,
**_kwargs: dict,
) -> TableContainer:
"""Take a by-row snapshot of the tabular data."""
source = cast(pd.DataFrame, input.get_input())
await snapshot_rows_df(
source,
column,
base_name,
storage,
formats,
row_name_column,
)
return TableContainer(table=source)
async def snapshot_rows_df(
input: pd.DataFrame,
column: str | None,
base_name: str,

View File

@ -0,0 +1,26 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing the split_text method definition."""
import pandas as pd
def split_text(
input: pd.DataFrame, column: str, to: str, separator: str = ","
) -> pd.DataFrame:
"""Split a column into a list of strings."""
output = input
def _apply_split(row):
if row[column] is None or isinstance(row[column], list):
return row[column]
if row[column] == "":
return []
if not isinstance(row[column], str):
message = f"Expected {column} to be a string, but got {type(row[column])}"
raise TypeError(message)
return row[column].split(separator)
output[to] = output.apply(_apply_split, axis=1)
return output
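As with the other relocated operations, a minimal usage sketch (hypothetical values, not part of this commit) of `split_text`, mirroring how `create_final_entities` turns `source_id` into `text_unit_ids`:
```python
# Hypothetical usage of split_text (illustration only).
import pandas as pd

from graphrag.index.operations.split_text import split_text

df = pd.DataFrame({"source_id": ["t1,t2", "", None]})
out = split_text(df, column="source_id", separator=",", to="text_unit_ids")
# out["text_unit_ids"] -> [["t1", "t2"], [], None]
```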

View File

@ -0,0 +1,16 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Community summarization modules."""
from .prepare_community_reports import prepare_community_reports
from .restore_community_hierarchy import restore_community_hierarchy
from .summarize_communities import summarize_communities
from .typing import CreateCommunityReportsStrategyType
__all__ = [
"CreateCommunityReportsStrategyType",
"prepare_community_reports",
"restore_community_hierarchy",
"summarize_communities",
]

View File

@ -8,11 +8,8 @@ from typing import cast
import pandas as pd
from datashaper import (
TableContainer,
VerbCallbacks,
VerbInput,
progress_iterable,
verb,
)
import graphrag.index.graph.extractors.community_reports.schemas as schemas
@ -25,32 +22,11 @@ from graphrag.index.graph.extractors.community_reports import (
set_context_size,
sort_context,
)
from graphrag.index.utils.ds_util import get_named_input_table, get_required_input_table
log = logging.getLogger(__name__)
@verb(name="prepare_community_reports")
def prepare_community_reports(
input: VerbInput,
callbacks: VerbCallbacks,
max_tokens: int = 16_000,
**_kwargs,
) -> TableContainer:
"""Prep communities for report generation."""
# Prepare Community Reports
nodes = cast(pd.DataFrame, get_required_input_table(input, "nodes").table)
edges = cast(pd.DataFrame, get_required_input_table(input, "edges").table)
claims = get_named_input_table(input, "claims")
if claims:
claims = cast(pd.DataFrame, claims.table)
output = prepare_community_reports_df(nodes, edges, claims, callbacks, max_tokens)
return TableContainer(table=output)
def prepare_community_reports_df(
nodes,
edges,
claims,

View File

@ -4,37 +4,15 @@
"""A module containing create_graph, _get_node_attributes, _get_edge_attributes and _get_attribute_column_mapping methods definition."""
import logging
from typing import cast
import pandas as pd
from datashaper import TableContainer, VerbInput, verb
import graphrag.index.graph.extractors.community_reports.schemas as schemas
log = logging.getLogger(__name__)
@verb(name="restore_community_hierarchy")
def restore_community_hierarchy(
input: VerbInput,
name_column: str = schemas.NODE_NAME,
community_column: str = schemas.NODE_COMMUNITY,
level_column: str = schemas.NODE_LEVEL,
**_kwargs,
) -> TableContainer:
"""Restore the community hierarchy from the node data."""
node_df: pd.DataFrame = cast(pd.DataFrame, input.get_input())
output = restore_community_hierarchy_df(
node_df,
name_column=name_column,
community_column=community_column,
level_column=level_column,
)
return TableContainer(table=output)
def restore_community_hierarchy_df(
input: pd.DataFrame,
name_column: str = schemas.NODE_NAME,
community_column: str = schemas.NODE_COMMUNITY,

View File

@ -9,41 +9,37 @@ import traceback
from datashaper import VerbCallbacks
from graphrag.config.enums import LLMType
from graphrag.index.cache import PipelineCache
from graphrag.index.graph.extractors.community_reports import (
CommunityReportsExtractor,
)
from graphrag.index.llm import load_llm
from graphrag.index.utils.rate_limiter import RateLimiter
from graphrag.index.verbs.graph.report.strategies.typing import (
from graphrag.llm import CompletionLLM
from .typing import (
CommunityReport,
StrategyConfig,
)
from graphrag.llm import CompletionLLM
from .defaults import MOCK_RESPONSES
DEFAULT_CHUNK_SIZE = 3000
log = logging.getLogger(__name__)
async def run(
async def run_graph_intelligence(
community: str | int,
input: str,
level: int,
reporter: VerbCallbacks,
pipeline_cache: PipelineCache,
callbacks: VerbCallbacks,
cache: PipelineCache,
args: StrategyConfig,
) -> CommunityReport | None:
"""Run the graph intelligence entity extraction strategy."""
llm_config = args.get(
"llm", {"type": LLMType.StaticResponse, "responses": MOCK_RESPONSES}
)
llm_type = llm_config.get("type", LLMType.StaticResponse)
llm = load_llm(
"community_reporting", llm_type, reporter, pipeline_cache, llm_config
)
return await _run_extractor(llm, community, input, level, args, reporter)
llm_config = args.get("llm", {})
llm_type = llm_config.get("type")
llm = load_llm("community_reporting", llm_type, callbacks, cache, llm_config)
return await _run_extractor(llm, community, input, level, args, callbacks)
async def _run_extractor(
@ -52,7 +48,7 @@ async def _run_extractor(
input: str,
level: int,
args: StrategyConfig,
reporter: VerbCallbacks,
callbacks: VerbCallbacks,
) -> CommunityReport | None:
# RateLimiter
rate_limiter = RateLimiter(rate=1, per=60)
@ -60,7 +56,7 @@ async def _run_extractor(
llm,
extraction_prompt=args.get("extraction_prompt", None),
max_report_length=args.get("max_report_length", None),
on_error=lambda e, stack, _data: reporter.error(
on_error=lambda e, stack, _data: callbacks.error(
"Community Report Extraction Error", e, stack
),
)
@ -86,7 +82,7 @@ async def _run_extractor(
)
except Exception as e:
log.exception("Error processing community: %s", community)
reporter.error("Community Report Extraction Error", e, traceback.format_exc())
callbacks.error("Community Report Extraction Error", e, traceback.format_exc())
return None
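
With the mock-response default removed, callers must now pass the LLM settings explicitly. A hedged sketch of the args dict this runner expects; only the keys read via .get() above are shown, and the stub response text is illustrative:

```python
from graphrag.config.enums import LLMType  # imported by this module

# Assumed shape, mirroring the StaticResponse pattern the old default used.
args = {
    "llm": {
        "type": LLMType.StaticResponse,
        "responses": [
            '{"title": "Stub report", "summary": "...", "rating": 1,'
            ' "rating_explanation": "...", "findings": []}'
        ],
    },
    "extraction_prompt": None,
    "max_report_length": 1500,
}
```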

View File

@ -4,19 +4,14 @@
"""A module containing create_community_reports and load_strategy methods definition."""
import logging
from enum import Enum
from typing import cast
import pandas as pd
from datashaper import (
AsyncType,
NoopVerbCallbacks,
TableContainer,
VerbCallbacks,
VerbInput,
derive_from_rows,
progress_ticker,
verb,
)
import graphrag.config.defaults as defaults
@ -26,54 +21,17 @@ from graphrag.index.graph.extractors.community_reports import (
get_levels,
prep_community_report_context,
)
from graphrag.index.utils.ds_util import get_required_input_table
from .strategies.typing import CommunityReport, CommunityReportsStrategy
from .typing import (
CommunityReport,
CommunityReportsStrategy,
CreateCommunityReportsStrategyType,
)
log = logging.getLogger(__name__)
class CreateCommunityReportsStrategyType(str, Enum):
"""CreateCommunityReportsStrategyType class definition."""
graph_intelligence = "graph_intelligence"
def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'
@verb(name="create_community_reports")
async def create_community_reports(
input: VerbInput,
callbacks: VerbCallbacks,
cache: PipelineCache,
strategy: dict,
async_mode: AsyncType = AsyncType.AsyncIO,
num_threads: int = 4,
**_kwargs,
) -> TableContainer:
"""Generate community summaries."""
log.debug("create_community_reports strategy=%s", strategy)
local_contexts = cast(pd.DataFrame, input.get_input())
nodes = get_required_input_table(input, "nodes").table
community_hierarchy = get_required_input_table(input, "community_hierarchy").table
output = await create_community_reports_df(
local_contexts,
nodes,
community_hierarchy,
callbacks,
cache,
strategy,
async_mode=async_mode,
num_threads=num_threads,
)
return TableContainer(table=output)
async def create_community_reports_df(
async def summarize_communities(
local_contexts,
nodes,
community_hierarchy,
@ -106,8 +64,8 @@ async def create_community_reports_df(
community_id=record[schemas.NODE_COMMUNITY],
community_level=record[schemas.COMMUNITY_LEVEL],
community_context=record[schemas.CONTEXT_STRING],
cache=cache,
callbacks=callbacks,
cache=cache,
strategy=strategy,
)
tick()
@ -127,8 +85,8 @@ async def create_community_reports_df(
async def _generate_report(
runner: CommunityReportsStrategy,
cache: PipelineCache,
callbacks: VerbCallbacks,
cache: PipelineCache,
strategy: dict,
community_id: int | str,
community_level: int,
@ -146,9 +104,9 @@ def load_strategy(
"""Load strategy method definition."""
match strategy:
case CreateCommunityReportsStrategyType.graph_intelligence:
from .strategies.graph_intelligence import run
from .strategies import run_graph_intelligence
return run
return run_graph_intelligence
case _:
msg = f"Unknown strategy: {strategy}"
raise ValueError(msg)
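
Because the strategy enum subclasses str, a plain string from a YAML/JSON config compares equal to the enum member inside load_strategy's match statement. A hedged usage sketch (the config dict is illustrative, load_strategy is the function shown above):

```python
# Hypothetical caller-side wiring.
strategy = {"type": "graph_intelligence", "max_report_length": 2000}
runner = load_strategy(strategy["type"])  # resolves to run_graph_intelligence
```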

View File

@ -4,6 +4,7 @@
"""A module containing 'Finding' and 'CommunityReport' models."""
from collections.abc import Awaitable, Callable
from enum import Enum
from typing import Any
from datashaper import VerbCallbacks
@ -50,3 +51,13 @@ CommunityReportsStrategy = Callable[
],
Awaitable[CommunityReport | None],
]
class CreateCommunityReportsStrategyType(str, Enum):
"""CreateCommunityReportsStrategyType class definition."""
graph_intelligence = "graph_intelligence"
def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'

View File

@ -0,0 +1,13 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Root package for description summarization."""
from .summarize_descriptions import summarize_descriptions
from .typing import SummarizationStrategy, SummarizeStrategyType
__all__ = [
"SummarizationStrategy",
"SummarizeStrategyType",
"summarize_descriptions",
]

View File

@ -1,38 +1,34 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing run_gi, run_resolve_entities and _create_text_list_splitter methods to run graph intelligence."""
"""A module containing run_graph_intelligence, run_resolve_entities and _create_text_list_splitter methods to run graph intelligence."""
from datashaper import VerbCallbacks
from graphrag.config.enums import LLMType
from graphrag.index.cache import PipelineCache
from graphrag.index.graph.extractors.summarize import SummarizeExtractor
from graphrag.index.llm import load_llm
from graphrag.index.verbs.entities.summarize.strategies.typing import (
from graphrag.llm import CompletionLLM
from .typing import (
StrategyConfig,
SummarizedDescriptionResult,
)
from graphrag.llm import CompletionLLM
from .defaults import DEFAULT_LLM_CONFIG
async def run(
async def run_graph_intelligence(
described_items: str | tuple[str, str],
descriptions: list[str],
reporter: VerbCallbacks,
pipeline_cache: PipelineCache,
callbacks: VerbCallbacks,
cache: PipelineCache,
args: StrategyConfig,
) -> SummarizedDescriptionResult:
"""Run the graph intelligence entity extraction strategy."""
llm_config = args.get("llm", DEFAULT_LLM_CONFIG)
llm_type = llm_config.get("type", LLMType.StaticResponse)
llm = load_llm(
"summarize_descriptions", llm_type, reporter, pipeline_cache, llm_config
)
llm_config = args.get("llm", {})
llm_type = llm_config.get("type")
llm = load_llm("summarize_descriptions", llm_type, callbacks, cache, llm_config)
return await run_summarize_descriptions(
llm, described_items, descriptions, reporter, args
llm, described_items, descriptions, callbacks, args
)
@ -40,7 +36,7 @@ async def run_summarize_descriptions(
llm: CompletionLLM,
items: str | tuple[str, str],
descriptions: list[str],
reporter: VerbCallbacks,
callbacks: VerbCallbacks,
args: StrategyConfig,
) -> SummarizedDescriptionResult:
"""Run the entity extraction chain."""
@ -56,8 +52,8 @@ async def run_summarize_descriptions(
entity_name_key=entity_name_key,
input_descriptions_key=input_descriptions_key,
on_error=lambda e, stack, details: (
reporter.error("Entity Extraction Error", e, stack, details)
if reporter
callbacks.error("Entity Extraction Error", e, stack, details)
if callbacks
else None
),
max_summary_length=args.get("max_summary_length", None),
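
As with the community-report strategy, the defaults are gone and the LLM block must be supplied by the caller. A hedged sketch of the args, using only keys visible in this diff; the static response text is illustrative:

```python
from graphrag.config.enums import LLMType  # imported by this module

args = {
    "llm": {
        "type": LLMType.StaticResponse,
        "responses": ["A single merged description, useful for offline testing."],
    },
    "max_summary_length": 500,
}
```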

View File

@ -5,72 +5,32 @@
import asyncio
import logging
from enum import Enum
from typing import Any, NamedTuple, cast
from typing import Any, cast
import networkx as nx
import pandas as pd
from datashaper import (
ProgressTicker,
TableContainer,
VerbCallbacks,
VerbInput,
progress_ticker,
verb,
)
from graphrag.index.cache import PipelineCache
from graphrag.index.utils import load_graph
from .strategies.typing import SummarizationStrategy
from .typing import (
DescriptionSummarizeRow,
SummarizationStrategy,
SummarizeStrategyType,
)
log = logging.getLogger(__name__)
class DescriptionSummarizeRow(NamedTuple):
"""DescriptionSummarizeRow class definition."""
graph: Any
class SummarizeStrategyType(str, Enum):
"""SummarizeStrategyType class definition."""
graph_intelligence = "graph_intelligence"
def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'
@verb(name="summarize_descriptions")
async def summarize_descriptions(
input: VerbInput,
cache: PipelineCache,
callbacks: VerbCallbacks,
column: str,
to: str,
strategy: dict[str, Any] | None = None,
**kwargs,
) -> TableContainer:
"""Summarize entity and relationship descriptions from an entity graph."""
source = cast(pd.DataFrame, input.get_input())
output = await summarize_descriptions_df(
source,
cache,
callbacks,
column=column,
to=to,
strategy=strategy,
**kwargs,
)
return TableContainer(table=output)
async def summarize_descriptions_df(
input: pd.DataFrame,
cache: PipelineCache,
callbacks: VerbCallbacks,
cache: PipelineCache,
column: str,
to: str,
strategy: dict[str, Any] | None = None,
@ -99,7 +59,6 @@ async def summarize_descriptions_df(
### yaml
```yaml
verb: entity_extract
args:
column: the_document_text_column_to_extract_descriptions_from
to: the_column_to_output_the_summarized_descriptions_to
@ -221,9 +180,9 @@ def load_strategy(strategy_type: SummarizeStrategyType) -> SummarizationStrategy
"""Load strategy method definition."""
match strategy_type:
case SummarizeStrategyType.graph_intelligence:
from .strategies.graph_intelligence import run as run_gi
from .strategies import run_graph_intelligence
return run_gi
return run_graph_intelligence
case _:
msg = f"Unknown strategy: {strategy_type}"
raise ValueError(msg)

View File

@ -5,7 +5,8 @@
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any
from enum import Enum
from typing import Any, NamedTuple
from datashaper import VerbCallbacks
@ -32,3 +33,19 @@ SummarizationStrategy = Callable[
],
Awaitable[SummarizedDescriptionResult],
]
class DescriptionSummarizeRow(NamedTuple):
"""DescriptionSummarizeRow class definition."""
graph: Any
class SummarizeStrategyType(str, Enum):
"""SummarizeStrategyType class definition."""
graph_intelligence = "graph_intelligence"
def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'

View File

@ -7,57 +7,20 @@ from typing import Any, cast
import networkx as nx
import pandas as pd
from datashaper import TableContainer, VerbCallbacks, VerbInput, progress_iterable, verb
from datashaper import VerbCallbacks, progress_iterable
from graphrag.index.utils import load_graph
default_copy = ["level"]
@verb(name="unpack_graph")
def unpack_graph(
input: VerbInput,
callbacks: VerbCallbacks,
column: str,
type: str, # noqa A002
copy: list[str] | None = None,
embeddings_column: str = "embeddings",
**kwargs,
) -> TableContainer:
"""
Unpack nodes or edges from a graphml graph, into a list of nodes or edges.
This verb will create columns for each attribute in a node or edge.
## Usage
```yaml
verb: unpack_graph
args:
type: node # The type of data to unpack, one of: node, edge. node will create a node list, edge will create an edge list
column: <column name> # The name of the column containing the graph, should be a graphml graph
```
"""
input_df = input.get_input()
output_df = unpack_graph_df(
cast(pd.DataFrame, input_df),
callbacks,
column,
type,
copy,
embeddings_column,
kwargs=kwargs,
)
return TableContainer(table=output_df)
def unpack_graph_df(
input_df: pd.DataFrame,
callbacks: VerbCallbacks,
column: str,
type: str, # noqa A002
copy: list[str] | None = None,
embeddings_column: str = "embeddings",
**kwargs,
) -> pd.DataFrame:
"""Unpack nodes or edges from a graphml graph, into a list of nodes or edges."""
if copy is None:
@ -83,7 +46,6 @@ def unpack_graph_df(
cast(str | nx.Graph, row[column]),
type,
embeddings,
kwargs,
)
])
@ -94,19 +56,18 @@ def _run_unpack(
graphml_or_graph: str | nx.Graph,
unpack_type: str,
embeddings: dict[str, list[float]],
args: dict[str, Any],
) -> list[dict[str, Any]]:
graph = load_graph(graphml_or_graph)
if unpack_type == "nodes":
return _unpack_nodes(graph, embeddings, args)
return _unpack_nodes(graph, embeddings)
if unpack_type == "edges":
return _unpack_edges(graph, args)
return _unpack_edges(graph)
msg = f"Unknown type {unpack_type}"
raise ValueError(msg)
def _unpack_nodes(
graph: nx.Graph, embeddings: dict[str, list[float]], _args: dict[str, Any]
graph: nx.Graph, embeddings: dict[str, list[float]]
) -> list[dict[str, Any]]:
return [
{
@ -118,7 +79,7 @@ def _unpack_nodes(
]
def _unpack_edges(graph: nx.Graph, _args: dict[str, Any]) -> list[dict[str, Any]]:
def _unpack_edges(graph: nx.Graph) -> list[dict[str, Any]]:
return [
{
"source": source_id,

View File

@ -47,7 +47,6 @@ from graphrag.index.typing import PipelineRunResult
# Register all verbs
from graphrag.index.update.dataframes import get_delta_docs, update_dataframe_outputs
from graphrag.index.verbs import * # noqa
from graphrag.index.workflows import (
VerbDefinitions,
WorkflowDefinitions,

View File

@ -1,46 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing get_default_verbs method definition."""
from .covariates import extract_covariates
from .entities import entity_extract, summarize_descriptions
from .genid import genid
from .graph import (
cluster_graph,
create_community_reports,
create_graph,
layout_graph,
merge_graphs,
unpack_graph,
)
from .overrides import aggregate, concat
from .snapshot import snapshot
from .snapshot_rows import snapshot_rows
from .spread_json import spread_json
from .text import chunk, text_split, text_translate
from .unzip import unzip
from .zip import zip_verb
__all__ = [
"aggregate",
"chunk",
"cluster_graph",
"concat",
"create_community_reports",
"create_graph",
"entity_extract",
"extract_covariates",
"genid",
"layout_graph",
"merge_graphs",
"snapshot",
"snapshot_rows",
"spread_json",
"summarize_descriptions",
"text_split",
"text_translate",
"unpack_graph",
"unzip",
"zip_verb",
]

View File

@ -1,8 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The Indexing Engine covariates package root."""
from .extract_covariates import extract_covariates
__all__ = ["extract_covariates"]

View File

@ -1,4 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The Indexing Engine text extract claims strategies package root."""

View File

@ -1,8 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The Indexing Engine text extract claims strategies graph intelligence package root."""
from .run_gi_extract_claims import run
__all__ = ["run"]

View File

@ -1,10 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A file containing MOCK_LLM_RESPONSES definition."""
MOCK_LLM_RESPONSES = [
"""
(COMPANY A<|>GOVERNMENT AGENCY B<|>ANTI-COMPETITIVE PRACTICES<|>TRUE<|>2022-01-10T00:00:00<|>2022-01-10T00:00:00<|>Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10<|>According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
""".strip()
]

View File

@ -1,9 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The Indexing Engine entities package root."""
from .extraction import entity_extract
from .summarize import summarize_descriptions
__all__ = ["entity_extract", "summarize_descriptions"]

View File

@ -1,8 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The Indexing Engine graph intelligence package root."""
from .run_graph_intelligence import run_gi
__all__ = ["run_gi"]

View File

@ -1,25 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A file containing some default responses."""
from graphrag.config.enums import LLMType
MOCK_LLM_RESPONSES = [
"""
("entity"<|>COMPANY_A<|>COMPANY<|>Company_A is a test company)
##
("entity"<|>COMPANY_B<|>COMPANY<|>Company_B owns Company_A and also shares an address with Company_A)
##
("entity"<|>PERSON_C<|>PERSON<|>Person_C is director of Company_A)
##
("relationship"<|>COMPANY_A<|>COMPANY_B<|>Company_A and Company_B are related because Company_A is 100% owned by Company_B and the two companies also share the same address)<|>2)
##
("relationship"<|>COMPANY_A<|>PERSON_C<|>Company_A and Person_C are related because Person_C is director of Company_A<|>1))
""".strip()
]
DEFAULT_LLM_CONFIG = {
"type": LLMType.StaticResponse,
"responses": MOCK_LLM_RESPONSES,
}

View File

@ -1,8 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Root package for entity summarization."""
from .description_summarize import SummarizeStrategyType, summarize_descriptions
__all__ = ["SummarizeStrategyType", "summarize_descriptions"]

View File

@ -1,8 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Indexing Engine - Summarization Strategies Package."""
from .typing import SummarizationStrategy
__all__ = ["SummarizationStrategy"]

View File

@ -1,8 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The Entity summarization graph intelligence package root."""
from .run_graph_intelligence import run
__all__ = ["run"]

View File

@ -1,17 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A file containing some default responses."""
from graphrag.config.enums import LLMType
MOCK_LLM_RESPONSES = [
"""
This is a MOCK response for the LLM. It is summarized!
""".strip()
]
DEFAULT_LLM_CONFIG = {
"type": LLMType.StaticResponse,
"responses": MOCK_LLM_RESPONSES,
}

View File

@ -1,80 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing genid method definition."""
from typing import cast
import pandas as pd
from datashaper import TableContainer, VerbInput, verb
from graphrag.index.utils import gen_md5_hash
@verb(name="genid")
def genid(
input: VerbInput,
to: str,
method: str = "md5_hash",
hash: list[str] | None = None, # noqa A002
**_kwargs: dict,
) -> TableContainer:
"""
Generate a unique id for each row in the tabular data.
## Usage
### json
```json
{
"verb": "genid",
"args": {
"to": "id_output_column_name", /* The name of the column to output the id to */
"method": "md5_hash", /* The method to use to generate the id */
"hash": ["list", "of", "column", "names"] /* only if using md5_hash */,
"seed": 034324 /* The random seed to use with UUID */
}
}
```
### yaml
```yaml
verb: genid
args:
to: id_output_column_name
method: md5_hash
hash:
- list
- of
- column
- names
seed: 034324
```
"""
data = cast(pd.DataFrame, input.source.table)
output = genid_df(data, to, method, hash)
return TableContainer(table=output)
def genid_df(
input: pd.DataFrame,
to: str,
method: str = "md5_hash",
hash: list[str] | None = None, # noqa A002
):
"""Generate a unique id for each row in the tabular data."""
data = input
match method:
case "md5_hash":
if not hash:
msg = 'Must specify the "hash" columns to use md5_hash method'
raise ValueError(msg)
data[to] = data.apply(lambda row: gen_md5_hash(row, hash), axis=1)
case "increment":
data[to] = data.index + 1
case _:
msg = f"Unknown method {method}"
raise ValueError(msg)
return data
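
A hedged, hashlib-only sketch of the md5_hash path the removed verb implemented; the frame and column names are illustrative, and gen_md5_hash's exact key format is not reproduced here:

```python
import hashlib
import pandas as pd

df = pd.DataFrame({"title": ["doc one", "doc two"], "source": ["a.txt", "b.txt"]})
hash_columns = ["title", "source"]

def row_md5(row: pd.Series) -> str:
    # One stable id per row, derived from the chosen columns.
    joined = "".join(str(row[col]) for col in hash_columns)
    return hashlib.md5(joined.encode("utf-8")).hexdigest()

df["id"] = df.apply(row_md5, axis=1)
print(df[["id", "title"]])
```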

View File

@ -1,34 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The Indexing Engine graph package root."""
from .clustering import cluster_graph
from .compute_edge_combined_degree import compute_edge_combined_degree
from .create import DEFAULT_EDGE_ATTRIBUTES, DEFAULT_NODE_ATTRIBUTES, create_graph
from .layout import layout_graph
from .merge import merge_graphs
from .report import (
create_community_reports,
prepare_community_reports,
prepare_community_reports_claims,
prepare_community_reports_edges,
restore_community_hierarchy,
)
from .unpack import unpack_graph
__all__ = [
"DEFAULT_EDGE_ATTRIBUTES",
"DEFAULT_NODE_ATTRIBUTES",
"cluster_graph",
"compute_edge_combined_degree",
"create_community_reports",
"create_graph",
"layout_graph",
"merge_graphs",
"prepare_community_reports",
"prepare_community_reports_claims",
"prepare_community_reports_edges",
"restore_community_hierarchy",
"unpack_graph",
]

View File

@ -1,8 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The Indexing Engine graph clustering package root."""
from .cluster_graph import GraphCommunityStrategyType, cluster_graph
__all__ = ["GraphCommunityStrategyType", "cluster_graph"]

View File

@ -1,4 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Graph Clustering Strategies."""

View File

@ -1,69 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing run and _compute_leiden_communities methods definitions."""
import logging
from typing import Any
import networkx as nx
from graspologic.partition import hierarchical_leiden
from graphrag.index.graph.utils import stable_largest_connected_component
log = logging.getLogger(__name__)
def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, list[str]]]:
"""Run method definition."""
max_cluster_size = args.get("max_cluster_size", 10)
use_lcc = args.get("use_lcc", True)
if args.get("verbose", False):
log.info(
"Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc
)
node_id_to_community_map = _compute_leiden_communities(
graph=graph,
max_cluster_size=max_cluster_size,
use_lcc=use_lcc,
seed=args.get("seed", 0xDEADBEEF),
)
levels = args.get("levels")
# If they don't pass in levels, use them all
if levels is None:
levels = sorted(node_id_to_community_map.keys())
results_by_level: dict[int, dict[str, list[str]]] = {}
for level in levels:
result = {}
results_by_level[level] = result
for node_id, raw_community_id in node_id_to_community_map[level].items():
community_id = str(raw_community_id)
if community_id not in result:
result[community_id] = []
result[community_id].append(node_id)
return results_by_level
# Taken from graph_intelligence & adapted
def _compute_leiden_communities(
graph: nx.Graph | nx.DiGraph,
max_cluster_size: int,
use_lcc: bool,
seed=0xDEADBEEF,
) -> dict[int, dict[str, int]]:
"""Return Leiden root communities."""
if use_lcc:
graph = stable_largest_connected_component(graph)
community_mapping = hierarchical_leiden(
graph, max_cluster_size=max_cluster_size, random_seed=seed
)
results: dict[int, dict[str, int]] = {}
for partition in community_mapping:
results[partition.level] = results.get(partition.level, {})
results[partition.level][partition.node] = partition.cluster
return results
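
For intuition, a hedged sketch of the level → community → nodes shape this runner produced, calling graspologic directly on a toy graph (relabeling nodes to strings is a convenience assumption, not taken from this module):

```python
import networkx as nx
from graspologic.partition import hierarchical_leiden

graph = nx.relabel_nodes(nx.karate_club_graph(), str)
partitions = hierarchical_leiden(graph, max_cluster_size=10, random_seed=0xDEADBEEF)

results_by_level: dict[int, dict[str, list[str]]] = {}
for part in partitions:
    level = results_by_level.setdefault(part.level, {})
    level.setdefault(str(part.cluster), []).append(part.node)

# e.g. {0: {'0': ['0', '1', ...], '1': [...]}, 1: {...}}
print({lvl: len(comms) for lvl, comms in results_by_level.items()})
```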

View File

@ -1,6 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing Communities list definition."""
Communities = list[tuple[int, str, list[str]]]

View File

@ -1,93 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing create_graph, _get_node_attributes, _get_edge_attributes and _get_attribute_column_mapping methods definition."""
from typing import cast
import pandas as pd
from datashaper import TableContainer, VerbInput, verb
from graphrag.index.utils.ds_util import get_required_input_table
@verb(name="compute_edge_combined_degree")
def compute_edge_combined_degree(
input: VerbInput,
to: str = "rank",
node_name_column: str = "title",
node_degree_column: str = "degree",
edge_source_column: str = "source",
edge_target_column: str = "target",
**_kwargs,
) -> TableContainer:
"""
Compute the combined degree for each edge in a graph.
Inputs Tables:
- input: The edge table
- nodes: The nodes table.
Args:
- to: The name of the column to output the combined degree to. Default="rank"
"""
edge_df: pd.DataFrame = cast(pd.DataFrame, input.get_input())
node_degree_df = _get_node_degree_table(input, node_name_column, node_degree_column)
output_df = compute_edge_combined_degree_df(
edge_df,
node_degree_df,
to,
node_name_column,
node_degree_column,
edge_source_column,
edge_target_column,
)
return TableContainer(table=output_df)
def compute_edge_combined_degree_df(
edge_df: pd.DataFrame,
node_degree_df: pd.DataFrame,
to: str,
node_name_column: str,
node_degree_column: str,
edge_source_column: str,
edge_target_column: str,
) -> pd.DataFrame:
"""Compute the combined degree for each edge in a graph."""
if to in edge_df.columns:
return edge_df
def join_to_degree(df: pd.DataFrame, column: str) -> pd.DataFrame:
degree_column = _degree_colname(column)
result = df.merge(
node_degree_df.rename(
columns={node_name_column: column, node_degree_column: degree_column}
),
on=column,
how="left",
)
result[degree_column] = result[degree_column].fillna(0)
return result
output_df = join_to_degree(edge_df, edge_source_column)
output_df = join_to_degree(output_df, edge_target_column)
output_df[to] = (
output_df[_degree_colname(edge_source_column)]
+ output_df[_degree_colname(edge_target_column)]
)
return output_df
def _degree_colname(column: str) -> str:
return f"{column}_degree"
def _get_node_degree_table(
input: VerbInput, node_name_column: str, node_degree_column: str
) -> pd.DataFrame:
nodes_container = get_required_input_table(input, "nodes")
nodes = cast(pd.DataFrame, nodes_container.table)
return cast(pd.DataFrame, nodes[[node_name_column, node_degree_column]])
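
A pandas-only sketch of the combined-degree join the removed verb performed, with toy data and the default column names from the signature above:

```python
import pandas as pd

edges = pd.DataFrame({"source": ["A", "B"], "target": ["B", "C"]})
node_degrees = pd.DataFrame({"title": ["A", "B", "C"], "degree": [1, 2, 1]})

def join_degree(df: pd.DataFrame, endpoint: str) -> pd.DataFrame:
    renamed = node_degrees.rename(
        columns={"title": endpoint, "degree": f"{endpoint}_degree"}
    )
    merged = df.merge(renamed, on=endpoint, how="left")
    merged[f"{endpoint}_degree"] = merged[f"{endpoint}_degree"].fillna(0)
    return merged

edges = join_degree(edges, "source")
edges = join_degree(edges, "target")
edges["rank"] = edges["source_degree"] + edges["target_degree"]
print(edges)  # rank = degree(source) + degree(target)
```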

View File

@ -1,135 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing create_graph, _get_node_attributes, _get_edge_attributes and _get_attribute_column_mapping methods definition."""
from typing import Any
import networkx as nx
import pandas as pd
from datashaper import TableContainer, VerbCallbacks, VerbInput, progress_iterable, verb
from graphrag.index.utils import clean_str
DEFAULT_NODE_ATTRIBUTES = ["label", "type", "id", "name", "description", "community"]
DEFAULT_EDGE_ATTRIBUTES = ["label", "type", "name", "source", "target"]
@verb(name="create_graph")
def create_graph(
input: VerbInput,
callbacks: VerbCallbacks,
to: str,
type: str, # noqa A002
graph_type: str = "undirected",
**kwargs,
) -> TableContainer:
"""
Create a graph from a dataframe. The verb outputs a new column containing the graph.
> Note: This will roll up all rows into a single graph.
## Usage
```yaml
verb: create_graph
args:
type: node # The type of graph to create, one of: node, edge
to: <column name> # The name of the column to output the graph to, this will be a graphml graph
attributes: # The attributes for the nodes / edges
# If using the node type, the following attributes are required:
id: <id_column_name>
# If using the edge type, the following attributes are required:
source: <source_column_name>
target: <target_column_name>
# Other attributes can be added as follows:
<attribute_name>: <column_name>
... for each attribute
```
"""
if type != "node" and type != "edge":
msg = f"Unknown type {type}"
raise ValueError(msg)
input_df = input.get_input()
num_total = len(input_df)
out_graph: nx.Graph = _create_nx_graph(graph_type)
in_attributes = (
_get_node_attributes(kwargs) if type == "node" else _get_edge_attributes(kwargs)
)
# At this point, _get_node_attributes and _get_edge_attributes have already validated
id_col = in_attributes.get(
"id", in_attributes.get("label", in_attributes.get("name", None))
)
source_col = in_attributes.get("source", None)
target_col = in_attributes.get("target", None)
for _, row in progress_iterable(input_df.iterrows(), callbacks.progress, num_total):
item_attributes = {
clean_str(key): _clean_value(row[value])
for key, value in in_attributes.items()
if value in row
}
if type == "node":
id = clean_str(row[id_col])
out_graph.add_node(id, **item_attributes)
elif type == "edge":
source = clean_str(row[source_col])
target = clean_str(row[target_col])
out_graph.add_edge(source, target, **item_attributes)
graphml_string = "".join(nx.generate_graphml(out_graph))
output_df = pd.DataFrame([{to: graphml_string}])
return TableContainer(table=output_df)
def _clean_value(value: Any) -> str:
if value is None:
return ""
if isinstance(value, str):
return clean_str(value)
msg = f"Value must be a string or None, got {type(value)}"
raise TypeError(msg)
def _get_node_attributes(args: dict[str, Any]) -> dict[str, Any]:
mapping = _get_attribute_column_mapping(
args.get("attributes", DEFAULT_NODE_ATTRIBUTES)
)
if "id" not in mapping and "label" not in mapping and "name" not in mapping:
msg = "You must specify an id, label, or name column in the node attributes"
raise ValueError(msg)
return mapping
def _get_edge_attributes(args: dict[str, Any]) -> dict[str, Any]:
mapping = _get_attribute_column_mapping(
args.get("attributes", DEFAULT_EDGE_ATTRIBUTES)
)
if "source" not in mapping or "target" not in mapping:
msg = "You must specify a source and target column in the edge attributes"
raise ValueError(msg)
return mapping
def _get_attribute_column_mapping(
in_attributes: dict[str, Any] | list[str],
) -> dict[str, str]:
    # It's already an attribute: column dict
if isinstance(in_attributes, dict):
return {
**in_attributes,
}
return {attrib: attrib for attrib in in_attributes}
def _create_nx_graph(graph_type: str) -> nx.Graph:
if graph_type == "directed":
return nx.DiGraph()
return nx.Graph()
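
A self-contained sketch of the node-mode roll-up the removed verb performed, using networkx directly; the dataframe and attribute names are illustrative:

```python
import networkx as nx
import pandas as pd

nodes = pd.DataFrame({
    "id": ["company_a", "person_c"],
    "type": ["COMPANY", "PERSON"],
    "description": ["A test company", "Director of Company A"],
})

graph = nx.Graph()
for _, row in nodes.iterrows():
    # every mapped column besides the id becomes a node attribute
    graph.add_node(row["id"], type=row["type"], description=row["description"])

# all rows roll up into a single graphml string, as noted in the docstring above
graphml_string = "".join(nx.generate_graphml(graph))
print(graphml_string[:120], "...")
```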

View File

@ -1,21 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A file containing DEFAULT_NODE_OPERATIONS, DEFAULT_EDGE_OPERATIONS and DEFAULT_CONCAT_SEPARATOR values definition."""
from .typing import BasicMergeOperation
DEFAULT_NODE_OPERATIONS = {
"*": {
"operation": BasicMergeOperation.Replace,
}
}
DEFAULT_EDGE_OPERATIONS = {
"*": {
"operation": BasicMergeOperation.Replace,
},
"weight": "sum",
}
DEFAULT_CONCAT_SEPARATOR = ","

View File

@ -1,25 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The Indexing Engine graph report package root."""
from .create_community_reports import (
CreateCommunityReportsStrategyType,
create_community_reports,
)
from .prepare_community_reports import prepare_community_reports
from .prepare_community_reports_claims import prepare_community_reports_claims
from .prepare_community_reports_edges import prepare_community_reports_edges
from .prepare_community_reports_nodes import prepare_community_reports_nodes
from .restore_community_hierarchy import restore_community_hierarchy
__all__ = [
"CreateCommunityReportsStrategyType",
"create_community_reports",
"create_community_reports",
"prepare_community_reports",
"prepare_community_reports_claims",
"prepare_community_reports_edges",
"prepare_community_reports_nodes",
"restore_community_hierarchy",
]

View File

@ -1,50 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing create_graph, _get_node_attributes, _get_edge_attributes and _get_attribute_column_mapping methods definition."""
from typing import cast
import pandas as pd
from datashaper import TableContainer, VerbInput, verb
from graphrag.index.graph.extractors.community_reports.schemas import (
CLAIM_DESCRIPTION,
CLAIM_DETAILS,
CLAIM_ID,
CLAIM_STATUS,
CLAIM_SUBJECT,
CLAIM_TYPE,
)
_MISSING_DESCRIPTION = "No Description"
@verb(name="prepare_community_reports_claims")
def prepare_community_reports_claims(
input: VerbInput,
to: str = CLAIM_DETAILS,
id_column: str = CLAIM_ID,
description_column: str = CLAIM_DESCRIPTION,
subject_column: str = CLAIM_SUBJECT,
type_column: str = CLAIM_TYPE,
status_column: str = CLAIM_STATUS,
**_kwargs,
) -> TableContainer:
"""Merge claim details into an object."""
claim_df: pd.DataFrame = cast(pd.DataFrame, input.get_input())
claim_df = claim_df.fillna(value={description_column: _MISSING_DESCRIPTION})
# merge values of five columns into a map column
claim_df[to] = claim_df.apply(
lambda x: {
id_column: x[id_column],
subject_column: x[subject_column],
type_column: x[type_column],
status_column: x[status_column],
description_column: x[description_column],
},
axis=1,
)
return TableContainer(table=claim_df)

View File

@ -1,48 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing create_graph, _get_node_attributes, _get_edge_attributes and _get_attribute_column_mapping methods definition."""
from typing import cast
import pandas as pd
from datashaper import TableContainer, VerbInput, verb
from graphrag.index.graph.extractors.community_reports.schemas import (
EDGE_DEGREE,
EDGE_DESCRIPTION,
EDGE_DETAILS,
EDGE_ID,
EDGE_SOURCE,
EDGE_TARGET,
)
_MISSING_DESCRIPTION = "No Description"
@verb(name="prepare_community_reports_edges")
def prepare_community_reports_edges(
input: VerbInput,
to: str = EDGE_DETAILS,
id_column: str = EDGE_ID,
source_column: str = EDGE_SOURCE,
target_column: str = EDGE_TARGET,
description_column: str = EDGE_DESCRIPTION,
degree_column: str = EDGE_DEGREE,
**_kwargs,
) -> TableContainer:
"""Merge edge details into an object."""
edge_df: pd.DataFrame = cast(pd.DataFrame, input.get_input()).fillna(
value={description_column: _MISSING_DESCRIPTION}
)
edge_df[to] = edge_df.apply(
lambda x: {
id_column: x[id_column],
source_column: x[source_column],
target_column: x[target_column],
description_column: x[description_column],
degree_column: x[degree_column],
},
axis=1,
)
return TableContainer(table=edge_df)

View File

@ -1,46 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing create_graph, _get_node_attributes, _get_edge_attributes and _get_attribute_column_mapping methods definition."""
from typing import cast
import pandas as pd
from datashaper import TableContainer, VerbInput, verb
from graphrag.index.graph.extractors.community_reports.schemas import (
NODE_DEGREE,
NODE_DESCRIPTION,
NODE_DETAILS,
NODE_ID,
NODE_NAME,
)
_MISSING_DESCRIPTION = "No Description"
@verb(name="prepare_community_reports_nodes")
def prepare_community_reports_nodes(
input: VerbInput,
to: str = NODE_DETAILS,
id_column: str = NODE_ID,
name_column: str = NODE_NAME,
description_column: str = NODE_DESCRIPTION,
degree_column: str = NODE_DEGREE,
**_kwargs,
) -> TableContainer:
"""Merge edge details into an object."""
node_df = cast(pd.DataFrame, input.get_input())
node_df = node_df.fillna(value={description_column: _MISSING_DESCRIPTION})
# merge values of four columns into a map column
node_df[to] = node_df.apply(
lambda x: {
id_column: x[id_column],
name_column: x[name_column],
description_column: x[description_column],
degree_column: x[degree_column],
},
axis=1,
)
return TableContainer(table=node_df)
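
The three prepare_* verbs above share one pattern: backfill missing descriptions, then fold several columns into a single details mapping per row. A hedged pandas sketch with illustrative column names:

```python
import pandas as pd

nodes = pd.DataFrame({
    "human_readable_id": [0, 1],
    "title": ["COMPANY_A", "PERSON_C"],
    "description": ["A test company", None],
    "degree": [2, 1],
})

nodes = nodes.fillna(value={"description": "No Description"})
nodes["node_details"] = nodes.apply(
    lambda x: {
        "human_readable_id": x["human_readable_id"],
        "title": x["title"],
        "description": x["description"],
        "degree": x["degree"],
    },
    axis=1,
)
print(nodes["node_details"].tolist())
```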

View File

@ -1,4 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The Indexing Engine graph report strategies package root."""

View File

@ -1,8 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The Indexing Engine graph report strategies graph intelligence package root."""
from .run_graph_intelligence import run
__all__ = ["run"]

View File

@ -1,27 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A file containing DEFAULT_CHUNK_SIZE and MOCK_RESPONSES definitions."""
import json
DEFAULT_CHUNK_SIZE = 3000
MOCK_RESPONSES = [
json.dumps({
"title": "<report_title>",
"summary": "<executive_summary>",
"rating": 2,
"rating_explanation": "<rating_explanation>",
"findings": [
{
"summary": "<insight_1_summary>",
"explanation": "<insight_1_explanation",
},
{
"summary": "<insight_2_summary>",
"explanation": "<insight_2_explanation",
},
],
})
]

View File

@ -1,9 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The Indexing Engine overrides package root."""
from .aggregate import aggregate
from .concat import concat
__all__ = ["aggregate", "concat"]

View File

@ -1,101 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'Aggregation' model."""
from dataclasses import dataclass
from typing import Any, cast
import pandas as pd
from datashaper import (
FieldAggregateOperation,
Progress,
Table,
TableContainer,
VerbCallbacks,
VerbInput,
aggregate_operation_mapping,
verb,
)
ARRAY_AGGREGATIONS = [
FieldAggregateOperation.ArrayAgg,
FieldAggregateOperation.ArrayAggDistinct,
]
# TODO: This thing is kinda gross
# Also, it diverges from the original aggregate verb, since it doesn't support the same syntax
@verb(name="aggregate_override")
def aggregate(
input: VerbInput,
callbacks: VerbCallbacks,
aggregations: list[dict[str, Any]],
groupby: list[str] | None = None,
**_kwargs: dict,
) -> TableContainer:
"""Aggregate method definition."""
input_table = input.get_input()
callbacks.progress(Progress(percent=0))
output = aggregate_df(input_table, aggregations, groupby)
callbacks.progress(Progress(percent=1))
return TableContainer(table=output)
def aggregate_df(
input_table: Table,
aggregations: list[dict[str, Any]],
groupby: list[str] | None = None,
) -> pd.DataFrame:
"""Aggregate method definition."""
aggregations_to_apply = _load_aggregations(aggregations)
df_aggregations = {
agg.column: _get_pandas_agg_operation(agg)
for agg in aggregations_to_apply.values()
}
if groupby is None:
output_grouped = input_table.groupby(lambda _x: True)
else:
output_grouped = input_table.groupby(groupby, sort=False)
output = cast(pd.DataFrame, output_grouped.agg(df_aggregations))
output.rename(
columns={agg.column: agg.to for agg in aggregations_to_apply.values()},
inplace=True,
)
output.columns = [agg.to for agg in aggregations_to_apply.values()]
return output.reset_index()
@dataclass
class Aggregation:
"""Aggregation class method definition."""
column: str | None
operation: str
to: str
# Only useful for the concat operation
separator: str | None = None
def _get_pandas_agg_operation(agg: Aggregation) -> Any:
# TODO: Merge into datashaper
if agg.operation == "string_concat":
return (agg.separator or ",").join
return aggregate_operation_mapping[FieldAggregateOperation(agg.operation)]
def _load_aggregations(
aggregations: list[dict[str, Any]],
) -> dict[str, Aggregation]:
return {
aggregation["column"]: Aggregation(
aggregation["column"], aggregation["operation"], aggregation["to"]
)
for aggregation in aggregations
}
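
A hedged, pandas-only reduction of the aggregation spec this override accepted (toy data; string_concat is the custom case called out in the TODO above):

```python
import pandas as pd

df = pd.DataFrame({
    "community": [1, 1, 2],
    "weight": [0.5, 1.5, 2.0],
    "description": ["a", "b", "c"],
})

aggregations = [
    {"column": "weight", "operation": "sum", "to": "total_weight"},
    {"column": "description", "operation": "string_concat", "to": "descriptions"},
]

def to_pandas_agg(op: str, separator: str = ","):
    # string_concat joins the grouped strings; anything else is a pandas op name
    return separator.join if op == "string_concat" else op

result = (
    df.groupby("community", sort=False)
      .agg(**{
          a["to"]: pd.NamedAgg(column=a["column"], aggfunc=to_pandas_agg(a["operation"]))
          for a in aggregations
      })
      .reset_index()
)
print(result)
```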

View File

@ -1,27 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing concat method definition."""
from typing import cast
import pandas as pd
from datashaper import TableContainer, VerbInput, verb
@verb(name="concat_override")
def concat(
input: VerbInput,
columnwise: bool = False,
**_kwargs: dict,
) -> TableContainer:
"""Concat method definition."""
input_table = cast(pd.DataFrame, input.get_input())
others = cast(list[pd.DataFrame], input.get_others())
if columnwise:
output = pd.concat([input_table, *others], axis=1)
else:
output = pd.concat([input_table, *others], ignore_index=True)
return TableContainer(table=output)

Some files were not shown because too many files have changed in this diff.