mirror of https://github.com/microsoft/graphrag.git
synced 2026-01-14 00:57:23 +08:00
Input factory (#2168)
Some checks are pending
Python Build and Type Check / python-ci (ubuntu-latest, 3.11) (push) Waiting to run
Python Build and Type Check / python-ci (ubuntu-latest, 3.13) (push) Waiting to run
Python Build and Type Check / python-ci (windows-latest, 3.11) (push) Waiting to run
Python Build and Type Check / python-ci (windows-latest, 3.13) (push) Waiting to run
Python Integration Tests / python-ci (ubuntu-latest, 3.13) (push) Waiting to run
Python Integration Tests / python-ci (windows-latest, 3.13) (push) Waiting to run
Python Notebook Tests / python-ci (ubuntu-latest, 3.13) (push) Waiting to run
Python Notebook Tests / python-ci (windows-latest, 3.13) (push) Waiting to run
Python Smoke Tests / python-ci (ubuntu-latest, 3.13) (push) Waiting to run
Python Smoke Tests / python-ci (windows-latest, 3.13) (push) Waiting to run
Python Unit Tests / python-ci (ubuntu-latest, 3.13) (push) Waiting to run
Python Unit Tests / python-ci (windows-latest, 3.13) (push) Waiting to run
* Update input factory to match other factories
* Move input config alongside input readers
* Move file pattern logic into InputReader
* Set encoding default
* Clean up optional column configs
* Combine structured data extraction
* Remove pandas from input loading
* Throw if empty documents
* Add json lines (jsonl) input support
* Store raw data
* Fix merge imports
* Move metadata handling entirely to chunking
* Nicer automatic title
* Typo
* Add get_property utility for nested dictionary access with dot notation
* Update structured_file_reader to use get_property utility
* Extract input module into new graphrag-input monorepo package
  - Create new graphrag-input package with input loading utilities
  - Move InputConfig, InputFileType, InputReader, TextDocument, and file readers (CSV, JSON, JSONL, Text)
  - Add get_property utility for nested dictionary access with dot notation
  - Include hashing utility for document ID generation
  - Update all imports throughout codebase to use graphrag_input
  - Add package to workspace configuration and release tasks
  - Remove old graphrag.index.input module
* Rename ChunkResult to TextChunk and add transformer support
  - Rename chunk_result.py to text_chunk.py with ChunkResult -> TextChunk
  - Add 'original' field to TextChunk to track pre-transform text
  - Add optional transform callback to chunker.chunk() method
  - Add add_metadata transformer for prepending metadata to chunks
  - Update create_chunk_results to apply transforms and populate original
  - Update sentence_chunker and token_chunker with transform support
  - Refactor create_base_text_units to use new transformer pattern
  - Rename pluck_metadata to get/collect methods on TextDocument
* Back-compat comment
* Align input config type name with other factory configs
* Add MarkItDown support
* Remove pattern default from MarkItDown reader
* Remove plugins flag (implicit disabled)
* Format
* Update verb tests
* Separate storage from input config
* Add empty objects for NaN raw_data
* Fix smoke tests
* Fix BOM in csv smoke
* Format
This commit is contained in:
parent 8fd7730067
commit 710fdad6f0
@@ -87,9 +87,9 @@ Our pipeline can ingest .csv, .txt, or .json data from an input location. See th
- `container_name` **str** - (blob/cosmosdb only) The Azure Storage container name.
- `storage_account_blob_url` **str** - (blob only) The storage account blob URL to use.
- `cosmosdb_account_blob_url` **str** - (cosmosdb only) The CosmosDB account blob URL to use.
- `file_type` **text|csv|json** - The type of input data to load. Default is `text`
- `type` **text|csv|json** - The type of input data to load. Default is `text`
- `encoding` **str** - The encoding of the input file. Default is `utf-8`
- `file_pattern` **str** - A regex to match input files. Default is `.*\.csv$`, `.*\.txt$`, or `.*\.json$` depending on the specified `file_type`, but you can customize it if needed.
- `file_pattern` **str** - A regex to match input files. Default is `.*\.csv$`, `.*\.txt$`, or `.*\.json$` depending on the specified `type`, but you can customize it if needed.
- `text_column` **str** - (CSV/JSON only) The text column name. If unset we expect a column named `text`.
- `title_column` **str** - (CSV/JSON only) The title column name, filename will be used if unset.
- `metadata` **list[str]** - (CSV/JSON only) The additional document attributes fields to keep.
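Putting these options together, a hypothetical `input` block using the renamed `type` key might look like the following (the `headline`/`content` column names are illustrative):

```yaml
input:
  type: csv
  encoding: utf-8
  file_pattern: ".*\\.csv$"
  title_column: headline
  text_column: content
  metadata: [author]
```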
@@ -28,11 +28,10 @@
"from pathlib import Path\n",
"from pprint import pprint\n",
"\n",
"import graphrag.api as api\n",
"import pandas as pd\n",
"from graphrag.config.load_config import load_config\n",
"from graphrag.index.typing.pipeline_run_result import PipelineRunResult\n",
"\n",
"import graphrag.api as api"
"from graphrag.index.typing.pipeline_run_result import PipelineRunResult"
]
},
{

@@ -30,11 +30,10 @@
"from pathlib import Path\n",
"from pprint import pprint\n",
"\n",
"import graphrag.api as api\n",
"import pandas as pd\n",
"from graphrag.config.load_config import load_config\n",
"from graphrag.index.typing.pipeline_run_result import PipelineRunResult\n",
"\n",
"import graphrag.api as api"
"from graphrag.index.typing.pipeline_run_result import PipelineRunResult"
]
},
{
@@ -116,7 +116,7 @@ settings.yaml

```yaml
input:
  file_type: text
  type: text
  metadata: [title]

chunks:

@@ -194,7 +194,7 @@ settings.yaml

```yaml
input:
  file_type: json
  type: json
  title_column: headline
  text_column: content
@@ -26,7 +26,7 @@ def register_cache(
    cache_initializer: Callable[..., Cache],
    scope: ServiceScope = "transient",
) -> None:
    """Register a custom storage implementation.
    """Register a custom cache implementation.

    Args
    ----
@@ -23,12 +23,12 @@ authors = [
license = "MIT"
readme = "README.md"
license-files = ["LICENSE"]
requires-python = ">=3.10,<3.13"
requires-python = ">=3.11,<3.14"
classifiers = [
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Programming Language :: Python :: 3.13",
]
dependencies = [
    "graphrag-common==2.7.0",
@@ -1,19 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'prepend_metadata' function."""


def add_metadata(
    text: str,
    metadata: dict,
    delimiter: str = ": ",
    line_delimiter: str = "\n",
    append: bool = False,
) -> str:
    """Add metadata to the given text, prepending by default. This utility writes the dict as rows of key/value pairs."""
    metadata_str = (
        line_delimiter.join(f"{k}{delimiter}{v}" for k, v in metadata.items())
        + line_delimiter
    )
    return text + metadata_str if append else metadata_str + text
@@ -1,17 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""The ChunkResult dataclass."""

from dataclasses import dataclass


@dataclass
class ChunkResult:
    """Result of chunking a document."""

    text: str
    index: int
    start_char: int
    end_char: int
    token_count: int | None = None
@@ -4,9 +4,10 @@
"""A module containing the 'Chunker' class."""

from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any

from graphrag_chunking.chunk_result import ChunkResult
from graphrag_chunking.text_chunk import TextChunk


class Chunker(ABC):
@@ -17,5 +18,7 @@ class Chunker(ABC):
        """Create a chunker instance."""

    @abstractmethod
    def chunk(self, text: str) -> list[ChunkResult]:
    def chunk(
        self, text: str, transform: Callable[[str], str] | None = None
    ) -> list[TextChunk]:
        """Chunk method definition."""
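To make the new signature concrete, here is a minimal sketch of a custom chunker against this interface (`LineChunker` is hypothetical, and a no-argument base constructor is assumed):

```python
from collections.abc import Callable

from graphrag_chunking.chunker import Chunker
from graphrag_chunking.create_chunk_results import create_chunk_results
from graphrag_chunking.text_chunk import TextChunk


class LineChunker(Chunker):
    """Hypothetical chunker that emits one chunk per line."""

    def chunk(
        self, text: str, transform: Callable[[str], str] | None = None
    ) -> list[TextChunk]:
        # keepends=True keeps character offsets aligned with the source text;
        # the shared helper applies the transform and fills index/offset fields.
        return create_chunk_results(text.splitlines(keepends=True), transform=transform)
```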
@@ -30,7 +30,7 @@ class ChunkingConfig(BaseModel):
        description="The chunk overlap to use.",
        default=100,
    )
    prepend_metadata: bool = Field(
        description="Prepend metadata into each chunk.",
        default=False,
    prepend_metadata: list[str] | None = Field(
        description="Metadata fields from the source document to prepend on each chunk.",
        default=None,
    )
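In settings terms the flag becomes a list of source-document fields; a sketch of the new shape (field names are illustrative):

```yaml
chunking:
  size: 1200
  overlap: 100
  prepend_metadata: [title, author]
```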
@@ -5,26 +5,28 @@

from collections.abc import Callable

from graphrag_chunking.chunk_result import ChunkResult
from graphrag_chunking.text_chunk import TextChunk


def create_chunk_results(
    chunks: list[str],
    transform: Callable[[str], str] | None = None,
    encode: Callable[[str], list[int]] | None = None,
) -> list[ChunkResult]:
    """Create chunk results from a list of text chunks. The index assignments are 0-based and assume chunks we not stripped relative to the source text."""
) -> list[TextChunk]:
    """Create chunk results from a list of text chunks. The index assignments are 0-based and assume chunks were not stripped relative to the source text."""
    results = []
    start_char = 0
    for index, chunk in enumerate(chunks):
        end_char = start_char + len(chunk) - 1  # 0-based indices
        chunk = ChunkResult(
            text=chunk,
        result = TextChunk(
            original=chunk,
            text=transform(chunk) if transform else chunk,
            index=index,
            start_char=start_char,
            end_char=end_char,
        )
        if encode:
            chunk.token_count = len(encode(chunk.text))
        results.append(chunk)
            result.token_count = len(encode(result.text))
        results.append(result)
        start_char = end_char + 1
    return results
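A quick illustration of the new transform/original split, using `str.upper` as a stand-in transformer:

```python
results = create_chunk_results(["abc", "def"], transform=str.upper)
assert results[0].original == "abc" and results[0].text == "ABC"
assert results[1].start_char == 3  # offsets always track the raw, untransformed text
```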
@@ -9,9 +9,9 @@ from typing import Any
import nltk

from graphrag_chunking.bootstrap_nltk import bootstrap
from graphrag_chunking.chunk_result import ChunkResult
from graphrag_chunking.chunker import Chunker
from graphrag_chunking.create_chunk_results import create_chunk_results
from graphrag_chunking.text_chunk import TextChunk


class SentenceChunker(Chunker):
@@ -24,10 +24,14 @@ class SentenceChunker(Chunker):
        self._encode = encode
        bootstrap()

    def chunk(self, text) -> list[ChunkResult]:
    def chunk(
        self, text: str, transform: Callable[[str], str] | None = None
    ) -> list[TextChunk]:
        """Chunk the text into sentence-based chunks."""
        sentences = nltk.sent_tokenize(text.strip())
        results = create_chunk_results(sentences, encode=self._encode)
        results = create_chunk_results(
            sentences, transform=transform, encode=self._encode
        )
        # nltk sentence tokenizer may trim whitespace, so we need to adjust start/end chars
        for index, result in enumerate(results):
            txt = result.text
packages/graphrag-chunking/graphrag_chunking/text_chunk.py (new file, 29 lines)
@@ -0,0 +1,29 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""The TextChunk dataclass."""

from dataclasses import dataclass


@dataclass
class TextChunk:
    """Result of chunking a document."""

    original: str
    """Raw original text chunk before any transformation."""

    text: str
    """The final text content of this chunk."""

    index: int
    """Zero-based index of this chunk within the source document."""

    start_char: int
    """Character index where the raw chunk text begins in the source document."""

    end_char: int
    """Character index where the raw chunk text ends in the source document."""

    token_count: int | None = None
    """Number of tokens in the final chunk text, if computed."""
@@ -6,9 +6,9 @@
from collections.abc import Callable
from typing import Any

from graphrag_chunking.chunk_result import ChunkResult
from graphrag_chunking.chunker import Chunker
from graphrag_chunking.create_chunk_results import create_chunk_results
from graphrag_chunking.text_chunk import TextChunk


class TokenChunker(Chunker):
@@ -28,7 +28,9 @@ class TokenChunker(Chunker):
        self._encode = encode
        self._decode = decode

    def chunk(self, text: str) -> list[ChunkResult]:
    def chunk(
        self, text: str, transform: Callable[[str], str] | None = None
    ) -> list[TextChunk]:
        """Chunk the text into token-based chunks."""
        chunks = split_text_on_tokens(
            text,
@@ -37,7 +39,7 @@ class TokenChunker(Chunker):
            encode=self._encode,
            decode=self._decode,
        )
        return create_chunk_results(chunks, encode=self._encode)
        return create_chunk_results(chunks, transform=transform, encode=self._encode)


def split_text_on_tokens(
packages/graphrag-chunking/graphrag_chunking/transformers.py (new file, 25 lines)
@@ -0,0 +1,25 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A collection of useful built-in transformers you can use for chunking."""

from collections.abc import Callable
from typing import Any


def add_metadata(
    metadata: dict[str, Any],
    delimiter: str = ": ",
    line_delimiter: str = "\n",
    append: bool = False,
) -> Callable[[str], str]:
    """Add metadata to the given text, prepending by default. This utility writes the dict as rows of key/value pairs."""

    def transformer(text: str) -> str:
        metadata_str = (
            line_delimiter.join(f"{k}{delimiter}{v}" for k, v in metadata.items())
            + line_delimiter
        )
        return text + metadata_str if append else metadata_str + text

    return transformer
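A usage sketch: the factory returns a closure that any chunker's `chunk()` can apply per chunk (the metadata values here are made up):

```python
from graphrag_chunking.transformers import add_metadata

transform = add_metadata({"title": "Report 1"})
print(transform("Some chunk text."))
# title: Report 1
# Some chunk text.

# Any Chunker accepts the same callable:
# chunks = chunker.chunk(document_text, transform=transform)
```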
packages/graphrag-input/README.md (new file, 72 lines)
@@ -0,0 +1,72 @@
# GraphRAG Inputs

This package provides input document loading utilities for GraphRAG, supporting multiple file formats including CSV, JSON, JSON Lines, and plain text.

## Supported File Types

The following four standard file formats are supported out of the box:

- **CSV** - Tabular data with configurable column mappings
- **JSON** - JSON files with configurable property paths
- **JSON Lines** - Line-delimited JSON records
- **Text** - Plain text files

### MarkItDown Support

Additionally, we support the `InputType.MarkItDown` format, which uses the [MarkItDown](https://github.com/microsoft/markitdown) library to import any supported file type. The MarkItDown converter can handle a wide variety of file formats including Office documents, PDFs, HTML, and more.

**Note:** Additional optional dependencies may need to be installed depending on the file type you're processing. The choice of converter is determined by MarkItDown's processing logic, which primarily uses the file extension to select the appropriate converter. Please refer to the [MarkItDown repository](https://github.com/microsoft/markitdown) for installation instructions and detailed information about supported formats.

## Examples

Basic usage with the factory:

```python
from graphrag_input import create_input_reader, InputConfig, InputType
from graphrag_storage import StorageConfig, create_storage

config = InputConfig(
    type=InputType.Csv,
    text_column="content",
    title_column="title",
)
storage = create_storage(StorageConfig(base_dir="./input"))
reader = create_input_reader(config, storage)
documents = await reader.read_files()
```

Import a PDF with MarkItDown:

```bash
pip install 'markitdown[pdf]' # required dependency for pdf processing
```

```python
from graphrag_input import create_input_reader, InputConfig, InputType
from graphrag_storage import StorageConfig, create_storage

config = InputConfig(
    type=InputType.MarkItDown,
    file_pattern=".*\\.pdf$"
)
storage = create_storage(StorageConfig(base_dir="./input"))
reader = create_input_reader(config, storage)
documents = await reader.read_files()
```

YAML config example for the above:

```yaml
input:
  type: markitdown
  file_pattern: ".*\\.pdf$$"
input_storage:
  type: file
  base_dir: "input"
```

Note that when specifying column names for data extraction, we can handle nested objects (e.g., in JSON) with dot notation:

```python
from graphrag_input import get_property

data = {"user": {"profile": {"name": "Alice"}}}
name = get_property(data, "user.profile.name") # Returns "Alice"
```
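Since the same `get_property` lookup drives the column settings, a config can point directly at nested values; a sketch assuming records shaped like `{"article": {"body": ..., "meta": {"headline": ...}}}`:

```python
config = InputConfig(
    type=InputType.Json,
    text_column="article.body",
    title_column="article.meta.headline",
)
```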
packages/graphrag-input/graphrag_input/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""GraphRAG input document loading package."""

from graphrag_input.get_property import get_property
from graphrag_input.input_config import InputConfig
from graphrag_input.input_reader import InputReader
from graphrag_input.input_reader_factory import create_input_reader
from graphrag_input.input_type import InputType
from graphrag_input.text_document import TextDocument

__all__ = [
    "InputConfig",
    "InputReader",
    "InputType",
    "TextDocument",
    "create_input_reader",
    "get_property",
]
packages/graphrag-input/graphrag_input/csv.py (new file, 38 lines)
@@ -0,0 +1,38 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'CSVFileReader' model."""

import csv
import logging

from graphrag_input.structured_file_reader import StructuredFileReader
from graphrag_input.text_document import TextDocument

logger = logging.getLogger(__name__)


class CSVFileReader(StructuredFileReader):
    """Reader implementation for csv files."""

    def __init__(self, file_pattern: str | None = None, **kwargs):
        super().__init__(
            file_pattern=file_pattern if file_pattern is not None else ".*\\.csv$",
            **kwargs,
        )

    async def read_file(self, path: str) -> list[TextDocument]:
        """Read a csv file into a list of documents.

        Args:
            - path - The path to read the file from.

        Returns
        -------
        - output - list with a TextDocument for each row in the file.
        """
        file = await self._storage.get(path, encoding=self._encoding)

        reader = csv.DictReader(file.splitlines())
        rows = list(reader)
        return await self.process_data_columns(rows, path)
packages/graphrag-input/graphrag_input/get_property.py (new file, 36 lines)
@@ -0,0 +1,36 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Utility for retrieving properties from nested dictionaries."""

from typing import Any


def get_property(data: dict[str, Any], path: str) -> Any:
    """Retrieve a property from a dictionary using dot notation.

    Parameters
    ----------
    data : dict[str, Any]
        The dictionary to retrieve the property from.
    path : str
        A dot-separated string representing the path to the property (e.g., "foo.bar.baz").

    Returns
    -------
    Any
        The value at the specified path.

    Raises
    ------
    KeyError
        If the path does not exist in the dictionary.
    """
    keys = path.split(".")
    current = data
    for key in keys:
        if not isinstance(current, dict) or key not in current:
            msg = f"Property '{path}' not found"
            raise KeyError(msg)
        current = current[key]
    return current
packages/graphrag-input/graphrag_input/hashing.py (new file, 27 lines)
@@ -0,0 +1,27 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Hashing utilities."""

from collections.abc import Iterable
from hashlib import sha512
from typing import Any


def gen_sha512_hash(item: dict[str, Any], hashcode: Iterable[str]) -> str:
    """Generate a SHA512 hash.

    Parameters
    ----------
    item : dict[str, Any]
        The dictionary containing values to hash.
    hashcode : Iterable[str]
        The keys to include in the hash.

    Returns
    -------
    str
        The SHA512 hash as a hexadecimal string.
    """
    hashed = "".join([str(item[column]) for column in hashcode])
    return f"{sha512(hashed.encode('utf-8'), usedforsecurity=False).hexdigest()}"
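The readers below use this helper to derive stable, content-addressed document IDs, e.g.:

```python
doc_id = gen_sha512_hash({"text": "hello world"}, ["text"])  # 128-char hex digest
```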
packages/graphrag-input/graphrag_input/input_config.py (new file, 40 lines)
@@ -0,0 +1,40 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Parameterization settings for the default configuration."""

from pydantic import BaseModel, ConfigDict, Field

from graphrag_input.input_type import InputType


class InputConfig(BaseModel):
    """The default configuration section for Input."""

    model_config = ConfigDict(extra="allow")
    """Allow extra fields to support custom reader implementations."""

    type: str = Field(
        description="The input file type to use.",
        default=InputType.Text,
    )
    encoding: str | None = Field(
        description="The input file encoding to use.",
        default=None,
    )
    file_pattern: str | None = Field(
        description="The input file pattern to use.",
        default=None,
    )
    id_column: str | None = Field(
        description="The input ID column to use.",
        default=None,
    )
    title_column: str | None = Field(
        description="The input title column to use.",
        default=None,
    )
    text_column: str | None = Field(
        description="The input text column to use.",
        default=None,
    )
packages/graphrag-input/graphrag_input/input_reader.py (new file, 75 lines)
@@ -0,0 +1,75 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'InputReader' model."""

from __future__ import annotations

import logging
import re
from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from graphrag_storage import Storage

    from graphrag_input.text_document import TextDocument

logger = logging.getLogger(__name__)


class InputReader(metaclass=ABCMeta):
    """Provide an input reader interface for the pipeline."""

    def __init__(
        self,
        storage: Storage,
        file_pattern: str,
        encoding: str = "utf-8",
        **kwargs,
    ):
        self._storage = storage
        self._encoding = encoding
        self._file_pattern = file_pattern

    async def read_files(self) -> list[TextDocument]:
        """Load files from storage and apply a loader function based on file type. Process metadata on the results if needed."""
        files = list(self._storage.find(re.compile(self._file_pattern)))
        if len(files) == 0:
            msg = f"No {self._file_pattern} matches found in storage"
            logger.warning(msg)
            files = []

        documents: list[TextDocument] = []

        for file in files:
            try:
                documents.extend(await self.read_file(file))
            except Exception as e:  # noqa: BLE001 (catching Exception is fine here)
                logger.warning("Warning! Error loading file %s. Skipping...", file)
                logger.warning("Error: %s", e)

        logger.info(
            "Found %d %s files, loading %d",
            len(files),
            self._file_pattern,
            len(documents),
        )
        total_files_log = (
            f"Total number of unfiltered {self._file_pattern} rows: {len(documents)}"
        )
        logger.info(total_files_log)

        return documents

    @abstractmethod
    async def read_file(self, path: str) -> list[TextDocument]:
        """Read a file into a list of documents.

        Args:
            - path - The path to read the file from.

        Returns
        -------
        - output - List with an entry for each document in the file.
        """
@@ -0,0 +1,90 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'InputReaderFactory' model."""

import logging
from collections.abc import Callable

from graphrag_common.factory import Factory
from graphrag_common.factory.factory import ServiceScope
from graphrag_storage.storage import Storage

from graphrag_input.input_config import InputConfig
from graphrag_input.input_reader import InputReader
from graphrag_input.input_type import InputType

logger = logging.getLogger(__name__)


class InputReaderFactory(Factory[InputReader]):
    """Factory for creating Input Reader instances."""


input_reader_factory = InputReaderFactory()


def register_input_reader(
    input_reader_type: str,
    input_reader_initializer: Callable[..., InputReader],
    scope: ServiceScope = "transient",
) -> None:
    """Register a custom input reader implementation.

    Args
    ----
    - input_reader_type: str
        The input reader id to register.
    - input_reader_initializer: Callable[..., InputReader]
        The input reader initializer to register.
    """
    input_reader_factory.register(input_reader_type, input_reader_initializer, scope)


def create_input_reader(config: InputConfig, storage: Storage) -> InputReader:
    """Create an input reader implementation based on the given configuration.

    Args
    ----
    - config: InputConfig
        The input reader configuration to use.
    - storage: Storage | None
        The storage implementation to use for reading the files.

    Returns
    -------
    InputReader
        The created input reader implementation.
    """
    config_model = config.model_dump()
    input_strategy = config.type

    if input_strategy not in input_reader_factory:
        match input_strategy:
            case InputType.Csv:
                from graphrag_input.csv import CSVFileReader

                register_input_reader(InputType.Csv, CSVFileReader)
            case InputType.Text:
                from graphrag_input.text import TextFileReader

                register_input_reader(InputType.Text, TextFileReader)
            case InputType.Json:
                from graphrag_input.json import JSONFileReader

                register_input_reader(InputType.Json, JSONFileReader)
            case InputType.JsonLines:
                from graphrag_input.jsonl import JSONLinesFileReader

                register_input_reader(InputType.JsonLines, JSONLinesFileReader)
            case InputType.MarkItDown:
                from graphrag_input.markitdown import MarkItDownFileReader

                register_input_reader(InputType.MarkItDown, MarkItDownFileReader)
            case _:
                msg = f"InputConfig.type '{input_strategy}' is not registered in the InputReaderFactory. Registered types: {', '.join(input_reader_factory.keys())}."
                raise ValueError(msg)

    config_model["storage"] = storage

    return input_reader_factory.create(input_strategy, init_args=config_model)
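The `register_input_reader` hook is how custom readers join this dispatch; a sketch, where `XMLFileReader` and its parsing logic are hypothetical:

```python
from graphrag_input import InputConfig, InputReader, create_input_reader
from graphrag_input.input_reader_factory import register_input_reader
from graphrag_input.text_document import TextDocument


class XMLFileReader(InputReader):
    """Hypothetical reader for XML files."""

    async def read_file(self, path: str) -> list[TextDocument]:
        text = await self._storage.get(path, encoding=self._encoding)
        ...  # parse the XML into one or more TextDocuments


register_input_reader("xml", XMLFileReader)

# InputConfig allows extra fields (extra="allow"), so reader-specific
# options pass through to the reader constructor via init_args.
config = InputConfig(type="xml", file_pattern=".*\\.xml$")
reader = create_input_reader(config, storage)  # storage: a graphrag_storage Storage
```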
packages/graphrag-input/graphrag_input/input_type.py (new file, 25 lines)
@@ -0,0 +1,25 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing input file type enum."""

from enum import StrEnum


class InputType(StrEnum):
    """The input file type for the pipeline."""

    Csv = "csv"
    """The CSV input type."""
    Text = "text"
    """The text input type."""
    Json = "json"
    """The JSON input type."""
    JsonLines = "jsonl"
    """The JSON Lines input type."""
    MarkItDown = "markitdown"
    """The MarkItDown input type."""

    def __repr__(self):
        """Get a string representation."""
        return f'"{self.value}"'
packages/graphrag-input/graphrag_input/json.py (new file, 38 lines)
@@ -0,0 +1,38 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'JSONFileReader' model."""

import json
import logging

from graphrag_input.structured_file_reader import StructuredFileReader
from graphrag_input.text_document import TextDocument

logger = logging.getLogger(__name__)


class JSONFileReader(StructuredFileReader):
    """Reader implementation for json files."""

    def __init__(self, file_pattern: str | None = None, **kwargs):
        super().__init__(
            file_pattern=file_pattern if file_pattern is not None else ".*\\.json$",
            **kwargs,
        )

    async def read_file(self, path: str) -> list[TextDocument]:
        """Read a JSON file into a list of documents.

        Args:
            - path - The path to read the file from.

        Returns
        -------
        - output - list with a TextDocument for each row in the file.
        """
        text = await self._storage.get(path, encoding=self._encoding)
        as_json = json.loads(text)
        # json file could just be a single object, or an array of objects
        rows = as_json if isinstance(as_json, list) else [as_json]
        return await self.process_data_columns(rows, path)
packages/graphrag-input/graphrag_input/jsonl.py (new file, 38 lines)
@@ -0,0 +1,38 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'JSONLinesFileReader' model."""

import json
import logging

from graphrag_input.structured_file_reader import StructuredFileReader
from graphrag_input.text_document import TextDocument

logger = logging.getLogger(__name__)


class JSONLinesFileReader(StructuredFileReader):
    """Reader implementation for json lines files."""

    def __init__(self, file_pattern: str | None = None, **kwargs):
        super().__init__(
            file_pattern=file_pattern if file_pattern is not None else ".*\\.jsonl$",
            **kwargs,
        )

    async def read_file(self, path: str) -> list[TextDocument]:
        """Read a JSON lines file into a list of documents.

        This differs from standard JSON files in that each line is a separate JSON object.

        Args:
            - path - The path to read the file from.

        Returns
        -------
        - output - list with a TextDocument for each row in the file.
        """
        text = await self._storage.get(path, encoding=self._encoding)
        rows = [json.loads(line) for line in text.splitlines()]
        return await self.process_data_columns(rows, path)
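For example, a two-line `.jsonl` input like the following (contents assumed) yields two TextDocuments, one per line:

```jsonl
{"title": "Doc A", "text": "First document."}
{"title": "Doc B", "text": "Second document."}
```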
packages/graphrag-input/graphrag_input/markitdown.py (new file, 49 lines)
@@ -0,0 +1,49 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'MarkItDownFileReader' model."""

import logging
from io import BytesIO
from pathlib import Path

from markitdown import MarkItDown, StreamInfo

from graphrag_input.hashing import gen_sha512_hash
from graphrag_input.input_reader import InputReader
from graphrag_input.text_document import TextDocument

logger = logging.getLogger(__name__)


class MarkItDownFileReader(InputReader):
    """Reader implementation for any file type supported by markitdown.

    https://github.com/microsoft/markitdown
    """

    async def read_file(self, path: str) -> list[TextDocument]:
        """Read a file into a list of documents.

        Args:
            - path - The path to read the file from.

        Returns
        -------
        - output - list with a single TextDocument for the file.
        """
        bytes = await self._storage.get(path, encoding=self._encoding, as_bytes=True)
        md = MarkItDown()
        result = md.convert_stream(
            BytesIO(bytes), stream_info=StreamInfo(extension=Path(path).suffix)
        )
        text = result.markdown

        document = TextDocument(
            id=gen_sha512_hash({"text": text}, ["text"]),
            title=result.title if result.title else str(Path(path).name),
            text=text,
            creation_date=await self._storage.get_creation_date(path),
            raw_data=None,
        )
        return [document]
@@ -0,0 +1,65 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'StructuredFileReader' model."""

import logging
from typing import Any

from graphrag_input.get_property import get_property
from graphrag_input.hashing import gen_sha512_hash
from graphrag_input.input_reader import InputReader
from graphrag_input.text_document import TextDocument

logger = logging.getLogger(__name__)


class StructuredFileReader(InputReader):
    """Base reader implementation for structured files such as csv and json."""

    def __init__(
        self,
        id_column: str | None = None,
        title_column: str | None = None,
        text_column: str = "text",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._id_column = id_column
        self._title_column = title_column
        self._text_column = text_column

    async def process_data_columns(
        self,
        rows: list[dict[str, Any]],
        path: str,
    ) -> list[TextDocument]:
        """Process configured data columns from a list of loaded dicts."""
        documents = []
        for index, row in enumerate(rows):
            # text column is required - harvest from dict
            text = get_property(row, self._text_column)
            # id is optional - harvest from dict or hash from text
            id = (
                get_property(row, self._id_column)
                if self._id_column
                else gen_sha512_hash({"text": text}, ["text"])
            )
            # title is optional - harvest from dict or use filename
            num = f" ({index})" if len(rows) > 1 else ""
            title = (
                get_property(row, self._title_column)
                if self._title_column
                else f"{path}{num}"
            )
            creation_date = await self._storage.get_creation_date(path)
            documents.append(
                TextDocument(
                    id=id,
                    title=title,
                    text=text,
                    creation_date=creation_date,
                    raw_data=row,
                )
            )
        return documents
packages/graphrag-input/graphrag_input/text.py (new file, 43 lines)
@@ -0,0 +1,43 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'TextFileReader' model."""

import logging
from pathlib import Path

from graphrag_input.hashing import gen_sha512_hash
from graphrag_input.input_reader import InputReader
from graphrag_input.text_document import TextDocument

logger = logging.getLogger(__name__)


class TextFileReader(InputReader):
    """Reader implementation for text files."""

    def __init__(self, file_pattern: str | None = None, **kwargs):
        super().__init__(
            file_pattern=file_pattern if file_pattern is not None else ".*\\.txt$",
            **kwargs,
        )

    async def read_file(self, path: str) -> list[TextDocument]:
        """Read a text file into a list of documents.

        Args:
            - path - The path to read the file from.

        Returns
        -------
        - output - list with a single TextDocument for the file.
        """
        text = await self._storage.get(path, encoding=self._encoding)
        document = TextDocument(
            id=gen_sha512_hash({"text": text}, ["text"]),
            title=str(Path(path).name),
            text=text,
            creation_date=await self._storage.get_creation_date(path),
            raw_data=None,
        )
        return [document]
packages/graphrag-input/graphrag_input/text_document.py (new file, 59 lines)
@@ -0,0 +1,59 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""TextDocument dataclass."""

import logging
from dataclasses import dataclass
from typing import Any

from graphrag_input.get_property import get_property

logger = logging.getLogger(__name__)


@dataclass
class TextDocument:
    """The TextDocument holds relevant content for GraphRAG indexing."""

    id: str
    """Unique identifier for the document."""
    text: str
    """The main text content of the document."""
    title: str
    """The title of the document."""
    creation_date: str
    """The creation date of the document, ISO-8601 format."""
    raw_data: dict[str, Any] | None = None
    """Raw data from source document."""

    def get(self, field: str, default_value: Any = None) -> Any:
        """
        Get a single field from the TextDocument.

        Functions like the get method on a dictionary, returning default_value if the field is not found.

        Supports nested fields using dot notation.

        This takes a two-step approach for flexibility:
        1. If the field is one of the standard text document fields (id, title, text, creation_date), just grab it directly. This accommodates unstructured text for example, which just has the standard fields.
        2. Otherwise, try to extract it from the raw_data dict. This allows users to specify any column from the original input file.

        """
        if field in ["id", "title", "text", "creation_date"]:
            return getattr(self, field)

        raw = self.raw_data or {}
        try:
            return get_property(raw, field)
        except KeyError:
            return default_value

    def collect(self, fields: list[str]) -> dict[str, Any]:
        """Extract data fields from a TextDocument into a dict."""
        data = {}
        for field in fields:
            value = self.get(field)
            if value is not None:
                data[field] = value
        return data
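A usage sketch of the two accessors (values are made up):

```python
doc = TextDocument(
    id="1",
    text="...",
    title="Q3 Report",
    creation_date="2024-07-01",
    raw_data={"meta": {"author": "Alice"}},
)
doc.get("title")        # "Q3 Report" -- standard field, read directly
doc.get("meta.author")  # "Alice" -- dot notation into raw_data
doc.collect(["title", "meta.author", "missing"])
# {"title": "Q3 Report", "meta.author": "Alice"} -- missing fields are dropped
```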
packages/graphrag-input/pyproject.toml (new file, 44 lines)
@@ -0,0 +1,44 @@
[project]
name = "graphrag-input"
version = "2.7.0"
description = "Input document loading utilities for GraphRAG"
authors = [
    {name = "Alonso Guevara Fernández", email = "alonsog@microsoft.com"},
    {name = "Andrés Morales Esquivel", email = "andresmor@microsoft.com"},
    {name = "Chris Trevino", email = "chtrevin@microsoft.com"},
    {name = "David Tittsworth", email = "datittsw@microsoft.com"},
    {name = "Dayenne de Souza", email = "ddesouza@microsoft.com"},
    {name = "Derek Worthen", email = "deworthe@microsoft.com"},
    {name = "Gaudy Blanco Meneses", email = "gaudyb@microsoft.com"},
    {name = "Ha Trinh", email = "trinhha@microsoft.com"},
    {name = "Jonathan Larson", email = "jolarso@microsoft.com"},
    {name = "Josh Bradley", email = "joshbradley@microsoft.com"},
    {name = "Kate Lytvynets", email = "kalytv@microsoft.com"},
    {name = "Kenny Zhang", email = "zhangken@microsoft.com"},
    {name = "Mónica Carvajal"},
    {name = "Nathan Evans", email = "naevans@microsoft.com"},
    {name = "Rodrigo Racanicci", email = "rracanicci@microsoft.com"},
    {name = "Sarah Smith", email = "smithsarah@microsoft.com"},
]
license = {text = "MIT"}
readme = "README.md"
requires-python = ">=3.11,<3.14"
classifiers = [
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Programming Language :: Python :: 3.13",
]
dependencies = [
    "graphrag-common==2.7.0",
    "graphrag-storage==2.7.0",
    "pydantic~=2.10",
    "markitdown~=0.1.0"
]

[project.urls]
Source = "https://github.com/microsoft/graphrag"

[build-system]
requires = ["hatchling>=1.27.0,<2.0.0"]
build-backend = "hatchling.build"
@@ -9,9 +9,11 @@ import sys
import warnings
from pathlib import Path

from graphrag_cache.cache_type import CacheType

import graphrag.api as api
from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
from graphrag.config.enums import CacheType, IndexingMethod
from graphrag.config.enums import IndexingMethod
from graphrag.config.load_config import load_config
from graphrag.index.validate_config import validate_config_names
from graphrag.utils.cli import redact
@@ -102,7 +104,7 @@ def _run_index(
    )

    if not cache:
        config.cache.type = CacheType.none
        config.cache.type = CacheType.Noop

    if not skip_validation:
        validate_config_names(config)
@@ -63,7 +63,7 @@ def initialize_project_at(
        raise ValueError(msg)

    input_path = (
        root / (graphrag_config_defaults.input.storage.base_dir or "input")
        root / (graphrag_config_defaults.input_storage.base_dir or "input")
    ).resolve()
    input_path.mkdir(parents=True, exist_ok=True)
    # using replace with custom tokens instead of format here because we have a placeholder for GRAPHRAG_API_KEY that is used later for .env overlay
@@ -9,13 +9,13 @@ from typing import ClassVar

from graphrag_cache import CacheType
from graphrag_chunking.chunk_strategy_type import ChunkerType
from graphrag_input import InputType
from graphrag_storage import StorageType

from graphrag.config.embeddings import default_embeddings
from graphrag.config.enums import (
    AsyncType,
    AuthType,
    InputFileType,
    ModelType,
    NounPhraseExtractorType,
    ReportingType,
@@ -64,7 +64,7 @@ class ChunkingDefaults:
    size: int = 1200
    overlap: int = 100
    encoding_model: str = ENCODING_MODEL
    prepend_metadata: bool = False
    prepend_metadata: None = None


@dataclass
@@ -225,6 +225,18 @@ class StorageDefaults:
    azure_cosmosdb_account_url: None = None


@dataclass
class InputDefaults:
    """Default values for input."""

    type: ClassVar[InputType] = InputType.Text
    encoding: str | None = None
    file_pattern: None = None
    id_column: None = None
    title_column: None = None
    text_column: None = None


@dataclass
class InputStorageDefaults(StorageDefaults):
    """Default values for input storage."""
@@ -232,19 +244,6 @@ class InputStorageDefaults(StorageDefaults):
    base_dir: str | None = DEFAULT_INPUT_BASE_DIR


@dataclass
class InputDefaults:
    """Default values for input."""

    storage: InputStorageDefaults = field(default_factory=InputStorageDefaults)
    file_type: ClassVar[InputFileType] = InputFileType.text
    encoding: str = "utf-8"
    file_pattern: str = ""
    text_column: str = "text"
    title_column: None = None
    metadata: None = None


@dataclass
class CacheStorageDefaults(StorageDefaults):
    """Default values for cache storage."""
@@ -311,7 +310,7 @@ class LocalSearchDefaults:


@dataclass
class OutputDefaults(StorageDefaults):
class OutputStorageDefaults(StorageDefaults):
    """Default values for output."""

    base_dir: str | None = DEFAULT_OUTPUT_BASE_DIR
@@ -362,7 +361,7 @@ class SummarizeDescriptionsDefaults:


@dataclass
class UpdateIndexOutputDefaults(StorageDefaults):
class UpdateOutputStorageDefaults(StorageDefaults):
    """Default values for update index output."""

    base_dir: str | None = DEFAULT_UPDATE_OUTPUT_BASE_DIR
@@ -389,13 +388,14 @@ class GraphRagConfigDefaults:

    models: dict = field(default_factory=dict)
    reporting: ReportingDefaults = field(default_factory=ReportingDefaults)
    storage: StorageDefaults = field(default_factory=StorageDefaults)
    output: OutputDefaults = field(default_factory=OutputDefaults)
    update_index_output: UpdateIndexOutputDefaults = field(
        default_factory=UpdateIndexOutputDefaults
    input_storage: InputStorageDefaults = field(default_factory=InputStorageDefaults)
    output_storage: OutputStorageDefaults = field(default_factory=OutputStorageDefaults)
    update_output_storage: UpdateOutputStorageDefaults = field(
        default_factory=UpdateOutputStorageDefaults
    )
    cache: CacheDefaults = field(default_factory=CacheDefaults)
    input: InputDefaults = field(default_factory=InputDefaults)

    embed_text: EmbedTextDefaults = field(default_factory=EmbedTextDefaults)
    chunking: ChunkingDefaults = field(default_factory=ChunkingDefaults)
    snapshots: SnapshotsDefaults = field(default_factory=SnapshotsDefaults)
@@ -8,57 +8,6 @@ from __future__ import annotations
from enum import Enum


class CacheType(str, Enum):
    """The cache configuration type for the pipeline."""

    file = "file"
    """The file cache configuration type."""
    memory = "memory"
    """The memory cache configuration type."""
    none = "none"
    """The none cache configuration type."""
    blob = "blob"
    """The blob cache configuration type."""
    cosmosdb = "cosmosdb"
    """The cosmosdb cache configuration type"""

    def __repr__(self):
        """Get a string representation."""
        return f'"{self.value}"'


class InputFileType(str, Enum):
    """The input file type for the pipeline."""

    csv = "csv"
    """The CSV input type."""
    text = "text"
    """The text input type."""
    json = "json"
    """The JSON input type."""

    def __repr__(self):
        """Get a string representation."""
        return f'"{self.value}"'


class StorageType(str, Enum):
    """The output type for the pipeline."""

    file = "file"
    """The file output type."""
    memory = "memory"
    """The memory output type."""
    blob = "blob"
    """The blob output type."""
    cosmosdb = "cosmosdb"
    """The cosmosdb output type"""

    def __repr__(self):
        """Get a string representation."""
        return f'"{self.value}"'


class VectorStoreType(str, Enum):
    """The supported vector store types."""
@@ -46,13 +46,10 @@ models:
    tokens_per_minute: null
    requests_per_minute: null

### Input settings ###
### Document processing settings ###

input:
  storage:
    type: {graphrag_config_defaults.input.storage.type} # [file, blob, cosmosdb]
    base_dir: "{graphrag_config_defaults.input.storage.base_dir}"
  file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json]
  type: {graphrag_config_defaults.input.type.value} # [csv, text, json, jsonl]

chunking:
  type: {graphrag_config_defaults.chunking.type}
@@ -60,24 +57,28 @@ chunking:
  overlap: {graphrag_config_defaults.chunking.overlap}
  encoding_model: {graphrag_config_defaults.chunking.encoding_model}

### Output/storage settings ###
### Storage settings ###
## If blob storage is specified in the following four sections,
## connection_string and container_name must be provided

output:
  type: {graphrag_config_defaults.output.type} # [file, blob, cosmosdb]
  base_dir: "{graphrag_config_defaults.output.base_dir}"
input_storage:
  type: {graphrag_config_defaults.input_storage.type} # [file, blob, cosmosdb]
  base_dir: "{graphrag_config_defaults.input_storage.base_dir}"

output_storage:
  type: {graphrag_config_defaults.output_storage.type} # [file, blob, cosmosdb]
  base_dir: "{graphrag_config_defaults.output_storage.base_dir}"

reporting:
  type: {graphrag_config_defaults.reporting.type.value} # [file, blob]
  base_dir: "{graphrag_config_defaults.reporting.base_dir}"

cache:
  type: {graphrag_config_defaults.cache.type} # [json, memory, none]
  storage:
    type: {graphrag_config_defaults.cache.storage.type} # [file, blob, cosmosdb]
    base_dir: "{graphrag_config_defaults.cache.storage.base_dir}"

reporting:
  type: {graphrag_config_defaults.reporting.type.value} # [file, blob]
  base_dir: "{graphrag_config_defaults.reporting.base_dir}"


vector_store:
  type: {vector_store_defaults.type}
  db_uri: {vector_store_defaults.db_uri}
@ -9,12 +9,12 @@ from pathlib import Path
|
||||
from devtools import pformat
|
||||
from graphrag_cache import CacheConfig
|
||||
from graphrag_chunking.chunking_config import ChunkingConfig
|
||||
from graphrag_input import InputConfig
|
||||
from graphrag_storage import StorageConfig, StorageType
|
||||
 from pydantic import BaseModel, Field, model_validator

 import graphrag.config.defaults as defs
 from graphrag.config.defaults import graphrag_config_defaults
-from graphrag.config.enums import VectorStoreType
+from graphrag.config.enums import ReportingType, VectorStoreType
 from graphrag.config.models.basic_search_config import BasicSearchConfig
 from graphrag.config.models.cluster_graph_config import ClusterGraphConfig
 from graphrag.config.models.community_reports_config import CommunityReportsConfig
@@ -24,7 +24,6 @@ from graphrag.config.models.extract_claims_config import ExtractClaimsConfig
 from graphrag.config.models.extract_graph_config import ExtractGraphConfig
 from graphrag.config.models.extract_graph_nlp_config import ExtractGraphNLPConfig
 from graphrag.config.models.global_search_config import GlobalSearchConfig
-from graphrag.config.models.input_config import InputConfig
 from graphrag.config.models.language_model_config import LanguageModelConfig
 from graphrag.config.models.local_search_config import LocalSearchConfig
 from graphrag.config.models.prune_graph_config import PruneGraphConfig
@@ -107,22 +106,22 @@ class GraphRagConfig(BaseModel):
     )
     """The input configuration."""

-    def _validate_input_pattern(self) -> None:
-        """Validate the input file pattern based on the specified type."""
-        if len(self.input.file_pattern) == 0:
-            if self.input.file_type == defs.InputFileType.text:
-                self.input.file_pattern = ".*\\.txt$"
-            else:
-                self.input.file_pattern = f".*\\.{self.input.file_type.value}$"
+    input_storage: StorageConfig = Field(
+        description="The input storage configuration.",
+        default=StorageConfig(
+            base_dir=graphrag_config_defaults.input_storage.base_dir,
+        ),
+    )
+    """The input storage configuration."""

     def _validate_input_base_dir(self) -> None:
         """Validate the input base directory."""
-        if self.input.storage.type == StorageType.File:
-            if not self.input.storage.base_dir:
+        if self.input_storage.type == StorageType.File:
+            if not self.input_storage.base_dir:
                 msg = "input storage base directory is required for file input storage. Please rerun `graphrag init` and set the input storage configuration."
                 raise ValueError(msg)
-            self.input.storage.base_dir = str(
-                Path(self.input.storage.base_dir).resolve()
+            self.input_storage.base_dir = str(
+                Path(self.input_storage.base_dir).resolve()
             )

     chunking: ChunkingConfig = Field(
@@ -140,7 +139,7 @@ class GraphRagConfig(BaseModel):
     output: StorageConfig = Field(
         description="The output configuration.",
         default=StorageConfig(
-            base_dir=graphrag_config_defaults.output.base_dir,
+            base_dir=graphrag_config_defaults.output_storage.base_dir,
         ),
     )
     """The output configuration."""
@@ -153,22 +152,22 @@ class GraphRagConfig(BaseModel):
                 raise ValueError(msg)
             self.output.base_dir = str(Path(self.output.base_dir).resolve())

-    update_index_output: StorageConfig = Field(
+    update_output_storage: StorageConfig = Field(
         description="The output configuration for the updated index.",
         default=StorageConfig(
-            base_dir=graphrag_config_defaults.update_index_output.base_dir,
+            base_dir=graphrag_config_defaults.update_output_storage.base_dir,
         ),
     )
     """The output configuration for the updated index."""

-    def _validate_update_index_output_base_dir(self) -> None:
-        """Validate the update index output base directory."""
-        if self.update_index_output.type == StorageType.File:
-            if not self.update_index_output.base_dir:
-                msg = "update_index_output base directory is required for file output. Please rerun `graphrag init` and set the update_index_output configuration."
+    def _validate_update_output_storage_base_dir(self) -> None:
+        """Validate the update output base directory."""
+        if self.update_output_storage.type == StorageType.File:
+            if not self.update_output_storage.base_dir:
+                msg = "update_output_storage base directory is required for file output. Please rerun `graphrag init` and set the update_output_storage configuration."
                 raise ValueError(msg)
-            self.update_index_output.base_dir = str(
-                Path(self.update_index_output.base_dir).resolve()
+            self.update_output_storage.base_dir = str(
+                Path(self.update_output_storage.base_dir).resolve()
             )

     cache: CacheConfig = Field(
@@ -184,7 +183,7 @@ class GraphRagConfig(BaseModel):

     def _validate_reporting_base_dir(self) -> None:
         """Validate the reporting base directory."""
-        if self.reporting.type == defs.ReportingType.file:
+        if self.reporting.type == ReportingType.file:
             if self.reporting.base_dir.strip() == "":
                 msg = "Reporting base directory is required for file reporting. Please rerun `graphrag init` and set the reporting configuration."
                 raise ValueError(msg)
@@ -318,11 +317,10 @@ class GraphRagConfig(BaseModel):
     @model_validator(mode="after")
     def _validate_model(self):
         """Validate the model configuration."""
-        self._validate_input_pattern()
         self._validate_input_base_dir()
         self._validate_reporting_base_dir()
         self._validate_output_base_dir()
-        self._validate_update_index_output_base_dir()
+        self._validate_update_output_storage_base_dir()
         self._validate_vector_store_db_uri()
         self._validate_factories()
         return self

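The practical upshot for config consumers: storage settings move from nested attributes to top-level fields. A hypothetical before/after access sketch (assuming a valid `models` mapping is available for construction; not part of this diff):

```python
from graphrag.config.models.graph_rag_config import GraphRagConfig

config = GraphRagConfig(models=models)  # `models` assumed in scope
# old: config.input.storage.base_dir         new: config.input_storage.base_dir
# old: config.update_index_output.base_dir   new: config.update_output_storage.base_dir
print(config.input_storage.base_dir)
print(config.update_output_storage.base_dir)
```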
@@ -1,46 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""Parameterization settings for the default configuration."""
-
-from graphrag_storage import StorageConfig
-from pydantic import BaseModel, Field
-
-import graphrag.config.defaults as defs
-from graphrag.config.defaults import graphrag_config_defaults
-from graphrag.config.enums import InputFileType
-
-
-class InputConfig(BaseModel):
-    """The default configuration section for Input."""
-
-    storage: StorageConfig = Field(
-        description="The storage configuration to use for reading input documents.",
-        default=StorageConfig(
-            base_dir=graphrag_config_defaults.input.storage.base_dir,
-        ),
-    )
-    file_type: InputFileType = Field(
-        description="The input file type to use.",
-        default=graphrag_config_defaults.input.file_type,
-    )
-    encoding: str = Field(
-        description="The input file encoding to use.",
-        default=defs.graphrag_config_defaults.input.encoding,
-    )
-    file_pattern: str = Field(
-        description="The input file pattern to use.",
-        default=graphrag_config_defaults.input.file_pattern,
-    )
-    text_column: str = Field(
-        description="The input text column to use.",
-        default=graphrag_config_defaults.input.text_column,
-    )
-    title_column: str | None = Field(
-        description="The input title column to use.",
-        default=graphrag_config_defaults.input.title_column,
-    )
-    metadata: list[str] | None = Field(
-        description="The document attribute columns to use.",
-        default=graphrag_config_defaults.input.metadata,
-    )

@@ -64,7 +64,7 @@ TEXT = "text"
 N_TOKENS = "n_tokens"

 CREATION_DATE = "creation_date"
 METADATA = "metadata"
+RAW_DATA = "raw_data"

 # the following lists define the final content and ordering of columns in the data model parquet outputs
 ENTITIES_FINAL_COLUMNS = [
@@ -155,5 +155,5 @@ DOCUMENTS_FINAL_COLUMNS = [
     TEXT,
     TEXT_UNIT_IDS,
     CREATION_DATE,
     METADATA,
+    RAW_DATA,
 ]

@@ -1,4 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""The Indexing Engine input package root."""

@@ -1,35 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""A module containing 'CSVFileReader' model."""
-
-import logging
-from io import BytesIO
-
-import pandas as pd
-
-from graphrag.index.input.input_reader import InputReader
-from graphrag.index.input.util import process_data_columns
-
-logger = logging.getLogger(__name__)
-
-
-class CSVFileReader(InputReader):
-    """Reader implementation for csv files."""
-
-    async def read_file(self, path: str) -> pd.DataFrame:
-        """Read a csv file into a DataFrame of documents.
-
-        Args:
-            - path - The path to read the file from.
-
-        Returns
-        -------
-            - output - DataFrame with a row for each document in the file.
-        """
-        buffer = BytesIO(await self._storage.get(path, as_bytes=True))
-        data = pd.read_csv(buffer, encoding=self._config.encoding)
-        data = process_data_columns(data, self._config, path)
-        creation_date = await self._storage.get_creation_date(path)
-        data["creation_date"] = data.apply(lambda _: creation_date, axis=1)
-        return data

@@ -1,26 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""A module containing 'InputReaderFactory' model."""
-
-import logging
-
-from graphrag_common.factory import Factory
-
-from graphrag.config.enums import InputFileType
-from graphrag.index.input.csv import CSVFileReader
-from graphrag.index.input.input_reader import InputReader
-from graphrag.index.input.json import JSONFileReader
-from graphrag.index.input.text import TextFileReader
-
-logger = logging.getLogger(__name__)
-
-
-class InputReaderFactory(Factory[InputReader]):
-    """Factory for creating Input Reader instances."""
-
-
-input_reader_factory = InputReaderFactory()
-input_reader_factory.register(InputFileType.text, TextFileReader)
-input_reader_factory.register(InputFileType.csv, CSVFileReader)
-input_reader_factory.register(InputFileType.json, JSONFileReader)

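The deleted registry above is superseded by `create_input_reader` from the new `graphrag_input` package. A minimal sketch of equivalent usage, based on the updated call sites later in this diff (the `base_dir` value is illustrative):

```python
from graphrag_input import InputConfig, InputType, create_input_reader
from graphrag_storage import StorageConfig, create_storage

config = InputConfig(type=InputType.Csv)
storage = create_storage(StorageConfig(base_dir="input"))  # illustrative path
reader = create_input_reader(config, storage)
# documents = await reader.read_files()  # a list of TextDocument rows
```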
@@ -1,85 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""A module containing 'InputReader' model."""
-
-from __future__ import annotations
-
-import logging
-import re
-from abc import ABCMeta, abstractmethod
-from typing import TYPE_CHECKING
-
-import pandas as pd
-
-if TYPE_CHECKING:
-    from graphrag_storage import Storage
-
-    from graphrag.config.models.input_config import InputConfig
-
-logger = logging.getLogger(__name__)
-
-
-class InputReader(metaclass=ABCMeta):
-    """Provide a cache interface for the pipeline."""
-
-    def __init__(self, storage: Storage, config: InputConfig, **kwargs):
-        self._storage = storage
-        self._config = config
-
-    async def read_files(self) -> pd.DataFrame:
-        """Load files from storage and apply a loader function based on file type. Process metadata on the results if needed."""
-        files = list(self._storage.find(re.compile(self._config.file_pattern)))
-
-        if len(files) == 0:
-            msg = f"No {self._config.file_type} files found in {self._config.storage.base_dir}"
-            raise ValueError(msg)
-
-        files_loaded = []
-
-        for file in files:
-            try:
-                files_loaded.append(await self.read_file(file))
-            except Exception as e:  # noqa: BLE001 (catching Exception is fine here)
-                logger.warning("Warning! Error loading file %s. Skipping...", file)
-                logger.warning("Error: %s", e)
-
-        logger.info(
-            "Found %d %s files, loading %d",
-            len(files),
-            self._config.file_type,
-            len(files_loaded),
-        )
-        result = pd.concat(files_loaded)
-        total_files_log = (
-            f"Total number of unfiltered {self._config.file_type} rows: {len(result)}"
-        )
-        logger.info(total_files_log)
-        # Convert metadata columns to strings and collapse them into a JSON object
-        if self._config.metadata:
-            if all(col in result.columns for col in self._config.metadata):
-                # Collapse the metadata columns into a single JSON object column
-                result["metadata"] = result[self._config.metadata].apply(
-                    lambda row: row.to_dict(), axis=1
-                )
-            else:
-                value_error_msg = (
-                    "One or more metadata columns not found in the DataFrame."
-                )
-                raise ValueError(value_error_msg)
-
-            result[self._config.metadata] = result[self._config.metadata].astype(str)
-
-        return result
-
-    @abstractmethod
-    async def read_file(self, path: str) -> pd.DataFrame:
-        """Read a file into a DataFrame of documents.
-
-        Args:
-            - path - The path to read the file from.
-
-        Returns
-        -------
-            - output - DataFrame with a row for each document in the file.
-        """

@@ -1,39 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""A module containing 'JSONFileReader' model."""
-
-import json
-import logging
-
-import pandas as pd
-
-from graphrag.index.input.input_reader import InputReader
-from graphrag.index.input.util import process_data_columns
-
-logger = logging.getLogger(__name__)
-
-
-class JSONFileReader(InputReader):
-    """Reader implementation for json files."""
-
-    async def read_file(self, path: str) -> pd.DataFrame:
-        """Read a JSON file into a DataFrame of documents.
-
-        Args:
-            - path - The path to read the file from.
-
-        Returns
-        -------
-            - output - DataFrame with a row for each document in the file.
-        """
-        text = await self._storage.get(path, encoding=self._config.encoding)
-        as_json = json.loads(text)
-        # json file could just be a single object, or an array of objects
-        rows = as_json if isinstance(as_json, list) else [as_json]
-        data = pd.DataFrame(rows)
-        data = process_data_columns(data, self._config, path)
-        creation_date = await self._storage.get_creation_date(path)
-        data["creation_date"] = data.apply(lambda _: creation_date, axis=1)
-
-        return data

@@ -1,35 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""A module containing 'TextFileReader' model."""
-
-import logging
-from pathlib import Path
-
-import pandas as pd
-
-from graphrag.index.input.input_reader import InputReader
-from graphrag.index.utils.hashing import gen_sha512_hash
-
-logger = logging.getLogger(__name__)
-
-
-class TextFileReader(InputReader):
-    """Reader implementation for text files."""
-
-    async def read_file(self, path: str) -> pd.DataFrame:
-        """Read a text file into a DataFrame of documents.
-
-        Args:
-            - path - The path to read the file from.
-
-        Returns
-        -------
-            - output - DataFrame with a row for each document in the file.
-        """
-        text = await self._storage.get(path, encoding=self._config.encoding)
-        new_item = {"text": text}
-        new_item["id"] = gen_sha512_hash(new_item, new_item.keys())
-        new_item["title"] = str(Path(path).name)
-        new_item["creation_date"] = await self._storage.get_creation_date(path)
-        return pd.DataFrame([new_item])

@@ -1,46 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""Shared column processing for structured input files."""
-
-import logging
-
-import pandas as pd
-
-from graphrag.config.models.input_config import InputConfig
-from graphrag.index.utils.hashing import gen_sha512_hash
-
-logger = logging.getLogger(__name__)
-
-
-def process_data_columns(
-    documents: pd.DataFrame, config: InputConfig, path: str
-) -> pd.DataFrame:
-    """Process configured data columns of a DataFrame."""
-    if "id" not in documents.columns:
-        documents["id"] = documents.apply(
-            lambda x: gen_sha512_hash(x, x.keys()), axis=1
-        )
-    if config.text_column is not None and "text" not in documents.columns:
-        if config.text_column not in documents.columns:
-            logger.warning(
-                "text_column %s not found in csv file %s",
-                config.text_column,
-                path,
-            )
-        else:
-            documents["text"] = documents.apply(lambda x: x[config.text_column], axis=1)
-    if config.title_column is not None:
-        if config.title_column not in documents.columns:
-            logger.warning(
-                "title_column %s not found in csv file %s",
-                config.title_column,
-                path,
-            )
-        else:
-            documents["title"] = documents.apply(
-                lambda x: x[config.title_column], axis=1
-            )
-    else:
-        documents["title"] = documents.apply(lambda _: path, axis=1)
-    return documents

@@ -50,7 +50,8 @@ def prune_graph(
     graph.remove_nodes_from([
         node
         for node, data in graph.nodes(data=True)
-        if data[schemas.NODE_FREQUENCY] < min_node_freq
+        if schemas.NODE_FREQUENCY not in data
+        or data[schemas.NODE_FREQUENCY] < min_node_freq
     ])
     if max_node_freq_std is not None:
         upper_threshold = _get_upper_threshold_by_std(

@@ -35,7 +35,7 @@ async def run_pipeline(
     input_documents: pd.DataFrame | None = None,
 ) -> AsyncIterable[PipelineRunResult]:
     """Run all workflows using a simplified pipeline."""
-    input_storage = create_storage(config.input.storage)
+    input_storage = create_storage(config.input_storage)
     output_storage = create_storage(config.output)
     cache = create_cache(config.cache)

@@ -49,7 +49,7 @@ async def run_pipeline(
     if is_update_run:
         logger.info("Running incremental indexing.")

-        update_storage = create_storage(config.update_index_output)
+        update_storage = create_storage(config.update_output_storage)
         # we use this to store the new subset index, and will merge its content with the previous index
         update_timestamp = time.strftime("%Y%m%d-%H%M%S")
         timestamped_storage = update_storage.child(update_timestamp)

@@ -53,7 +53,7 @@ def get_update_storages(
 ) -> tuple[Storage, Storage, Storage]:
     """Get storage objects for the update index run."""
     output_storage = create_storage(config.output)
-    update_storage = create_storage(config.update_index_output)
+    update_storage = create_storage(config.update_output_storage)
     timestamped_storage = update_storage.child(timestamp)
     delta_storage = timestamped_storage.child("delta")
     previous_storage = timestamped_storage.child("previous")

@@ -3,14 +3,14 @@

 """A module containing run_workflow method definition."""

 import json
 import logging
 from typing import Any, cast

 import pandas as pd
-from graphrag_chunking.add_metadata import add_metadata
 from graphrag_chunking.chunker import Chunker
 from graphrag_chunking.chunker_factory import create_chunker
+from graphrag_chunking.transformers import add_metadata
+from graphrag_input import TextDocument

 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.models.graph_rag_config import GraphRagConfig
@@ -54,7 +54,7 @@ def create_base_text_units(
     callbacks: WorkflowCallbacks,
     tokenizer: Tokenizer,
     chunker: Chunker,
-    prepend_metadata: bool | None = False,
+    prepend_metadata: list[str] | None = None,
 ) -> pd.DataFrame:
     """All the steps to transform base text_units."""
     documents.sort_values(by=["id"], ascending=[True], inplace=True)
@@ -66,15 +66,27 @@ def create_base_text_units(
     logger.info("Starting chunking process for %d documents", total_rows)

     def chunker_with_logging(row: pd.Series, row_index: int) -> Any:
-        row["chunks"] = [chunk.text for chunk in chunker.chunk(row["text"])]
+        if prepend_metadata:
+            # create a standard text document for metadata plucking
+            # ignore any additional fields in case the input dataframe has extra columns
+            document = TextDocument(
+                id=row["id"],
+                title=row["title"],
+                text=row["text"],
+                creation_date=row["creation_date"],
+                raw_data=row["raw_data"],
+            )
+            metadata = document.collect(prepend_metadata)
+            transformer = add_metadata(
+                metadata=metadata, line_delimiter=".\n"
+            )  # delim with . for back-compat older indexes
+        else:
+            transformer = None
+
+        row["chunks"] = [
+            chunk.text for chunk in chunker.chunk(row["text"], transform=transformer)
+        ]

-        metadata = row.get("metadata", None)
-        if prepend_metadata and metadata is not None:
-            metadata = json.loads(metadata) if isinstance(metadata, str) else metadata
-            row["chunks"] = [
-                add_metadata(chunk, metadata, line_delimiter=".\n")
-                for chunk in row["chunks"]
-            ]
         tick()
         logger.info("chunker progress: %d/%d", row_index + 1, total_rows)
         return row

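In isolation, the new chunk-transform flow looks roughly like this (a sketch assuming a `chunker` built via `create_chunker` and an in-scope `document_text` string, as in the surrounding code):

```python
from graphrag_chunking.transformers import add_metadata

# build a transformer that prepends "key: value" metadata lines to each chunk;
# the ".\n" delimiter keeps parity with chunks produced by older indexes
transformer = add_metadata(metadata={"title": "Doc 1"}, line_delimiter=".\n")
chunks = chunker.chunk(document_text, transform=transformer)
texts = [chunk.text for chunk in chunks]
```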
@@ -67,7 +67,7 @@ def create_final_documents(
     rejoined["id"] = rejoined["id"].astype(str)
     rejoined["human_readable_id"] = rejoined.index

     if "metadata" not in rejoined.columns:
         rejoined["metadata"] = pd.Series(dtype="object")
+    if "raw_data" not in rejoined.columns:
+        rejoined["raw_data"] = pd.Series(dtype="object")

     return rejoined.loc[:, DOCUMENTS_FINAL_COLUMNS]

@@ -6,10 +6,9 @@

 import logging

 import pandas as pd
+from graphrag_input import InputReader, create_input_reader

 from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.index.input.factory import InputReaderFactory
-from graphrag.index.input.input_reader import InputReader
 from graphrag.index.typing.context import PipelineRunContext
 from graphrag.index.typing.workflow import WorkflowFunctionOutput
 from graphrag.utils.storage import write_table_to_storage
@@ -22,13 +21,15 @@ async def run_workflow(
     context: PipelineRunContext,
 ) -> WorkflowFunctionOutput:
     """Load and parse input documents into a standard format."""
-    input_reader = InputReaderFactory().create(
-        config.input.file_type,
-        {"storage": context.input_storage, "config": config.input},
-    )
+    input_reader = create_input_reader(config.input, context.input_storage)

     output = await load_input_documents(input_reader)

+    if len(output) == 0:
+        msg = "Error reading documents, please see logs."
+        logger.error(msg)
+        raise ValueError(msg)
+
     logger.info("Final # of rows loaded: %s", len(output))
     context.stats.num_documents = len(output)

@@ -39,4 +40,4 @@ async def run_workflow(

 async def load_input_documents(input_reader: InputReader) -> pd.DataFrame:
     """Load and parse input documents into a standard format."""
-    return await input_reader.read_files()
+    return pd.DataFrame(await input_reader.read_files())

@@ -6,11 +6,11 @@

 import logging

 import pandas as pd
+from graphrag_input.input_reader import InputReader
+from graphrag_input.input_reader_factory import create_input_reader
 from graphrag_storage import Storage

 from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.index.input.factory import InputReaderFactory
-from graphrag.index.input.input_reader import InputReader
 from graphrag.index.typing.context import PipelineRunContext
 from graphrag.index.typing.workflow import WorkflowFunctionOutput
 from graphrag.index.update.incremental_index import get_delta_docs
@@ -24,10 +24,7 @@ async def run_workflow(
     context: PipelineRunContext,
 ) -> WorkflowFunctionOutput:
     """Load and parse update-only input documents into a standard format."""
-    input_reader = InputReaderFactory().create(
-        config.input.file_type,
-        {"storage": context.input_storage, "config": config.input},
-    )
+    input_reader = create_input_reader(config.input, context.input_storage)
     output = await load_update_documents(
         input_reader,
         context.previous_storage,
@@ -50,7 +47,7 @@ async def load_update_documents(
     previous_storage: Storage,
 ) -> pd.DataFrame:
     """Load and parse update-only input documents into a standard format."""
-    input_documents = await input_reader.read_files()
+    input_documents = pd.DataFrame(await input_reader.read_files())
     # previous storage is the output of the previous run
     # we'll use this to diff the input from the prior
     delta_documents = await get_delta_docs(input_documents, previous_storage)

@@ -58,6 +58,7 @@ def prune_graph(
     """Prune a full graph based on graph statistics."""
     # create a temporary graph to prune, then turn it back into dataframes
     graph = create_graph(relationships, edge_attr=["weight"], nodes=entities)
+
     pruned = prune_graph_operation(
         graph,
         min_node_freq=pruning_config.min_node_freq,

@@ -10,11 +10,11 @@ import numpy as np
 import pandas as pd
 from graphrag_cache.noop_cache import NoopCache
 from graphrag_chunking.chunker_factory import create_chunker
+from graphrag_input import create_input_reader
 from graphrag_storage import create_storage

 from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.index.input.factory import InputReaderFactory
 from graphrag.index.operations.embed_text.run_embed_text import (
     run_embed_text,
 )
@@ -63,14 +63,11 @@ async def load_docs_in_chunks(
     )
     tokenizer = get_tokenizer(embeddings_llm_settings)
     chunker = create_chunker(config.chunking, tokenizer.encode, tokenizer.decode)
-    input_storage = create_storage(config.input.storage)
-    input_reader = InputReaderFactory().create(
-        config.input.file_type,
-        {"storage": input_storage, "config": config.input},
-    )
+    input_storage = create_storage(config.input_storage)
+    input_reader = create_input_reader(config.input, input_storage)
     dataset = await input_reader.read_files()
     chunks_df = create_base_text_units(
-        documents=dataset,
+        documents=pd.DataFrame(dataset),
         callbacks=NoopWorkflowCallbacks(),
         tokenizer=tokenizer,
         chunker=chunker,

@@ -41,6 +41,7 @@ dependencies = [
     "environs~=11.0",
     "graphrag-cache==2.7.0",
     "graphrag-common==2.7.0",
+    "graphrag-input==2.7.0",
     "graphrag-storage==2.7.0",
     "graspologic-native~=1.2",
     "json-repair~=0.30",

@@ -55,6 +55,7 @@ members = ["packages/*"]
 [tool.uv.sources]
 graphrag-chunking = { workspace = true }
 graphrag-common = { workspace = true }
+graphrag-input = { workspace = true }
 graphrag-storage = { workspace = true }
 graphrag-cache = { workspace = true }

@@ -75,6 +76,7 @@ _semversioner_update_graphrag_chunking_toml_version = "update-toml update --file
 _semversioner_update_graphrag_common_toml_version = "update-toml update --file packages/graphrag-common/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
 _semversioner_update_graphrag_storage_toml_version = "update-toml update --file packages/graphrag-storage/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
 _semversioner_update_graphrag_cache_toml_version = "update-toml update --file packages/graphrag-cache/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
+_semversioner_update_graphrag_input_toml_version = "update-toml update --file packages/graphrag-input/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
 _semversioner_update_workspace_dependency_versions = "python -m scripts.update_workspace_dependency_versions"
 semversioner_add = "semversioner add-change"
 coverage_report = 'coverage report --omit "**/tests/**" --show-missing'
@@ -110,6 +112,7 @@ sequence = [
     '_semversioner_update_graphrag_toml_version',
     '_semversioner_update_graphrag_common_toml_version',
     '_semversioner_update_graphrag_chunking_toml_version',
    '_semversioner_update_graphrag_input_toml_version',
     '_semversioner_update_graphrag_storage_toml_version',
     '_semversioner_update_graphrag_cache_toml_version',
     '_semversioner_update_workspace_dependency_versions',

tests/fixtures/azure/config.json:
@@ -1,6 +1,6 @@
 {
     "input_path": "./tests/fixtures/azure",
-    "input_file_type": "text",
+    "input_type": "text",
     "index_method": "standard",
     "workflow_config": {
         "skip_assert": true,

tests/fixtures/azure/settings.yml:
@@ -14,7 +14,7 @@ input:
     connection_string: ${LOCAL_BLOB_STORAGE_CONNECTION_STRING}
     container_name: azurefixture
     base_dir: input
-  file_type: text
+  type: text

 cache:

tests/fixtures/min-csv/config.json:
@@ -1,6 +1,6 @@
 {
     "input_path": "./tests/fixtures/min-csv",
-    "input_file_type": "text",
+    "input_type": "text",
     "index_method": "standard",
     "workflow_config": {
         "load_input_documents": {
@@ -73,7 +73,7 @@
             15
         ],
         "nan_allowed_columns": [
-            "metadata"
+            "raw_data"
         ],
         "max_runtime": 30,
         "expected_artifacts": ["documents.parquet"]

tests/fixtures/min-csv/settings.yml:
@@ -35,7 +35,8 @@ vector_store:
     overwrite: True

 input:
-  file_type: csv
+  type: csv
+  encoding: utf-8-sig

 snapshots:
   embeddings: true

tests/fixtures/text/config.json:
@@ -1,6 +1,6 @@
 {
     "input_path": "./tests/fixtures/text",
-    "input_file_type": "text",
+    "input_type": "text",
     "index_method": "fast",
     "workflow_config": {
         "load_input_documents": {
@@ -72,7 +72,7 @@
             1
         ],
         "nan_allowed_columns": [
-            "metadata"
+            "raw_data"
         ],
         "max_runtime": 30,
         "expected_artifacts": ["documents.parquet"]

@@ -126,7 +126,7 @@ class TestIndexer:
     def __run_indexer(
         self,
         root: Path,
-        input_file_type: str,
+        input_type: str,
         index_method: str,
     ):
         command = [
@@ -232,7 +232,7 @@ class TestIndexer:
     def test_fixture(
         self,
         input_path: str,
-        input_file_type: str,
+        input_type: str,
         index_method: str,
         workflow_config: dict[str, dict[str, Any]],
         query_config: list[dict[str, str]],
@@ -248,7 +248,7 @@ class TestIndexer:
         dispose = asyncio.run(prepare_azurite_data(input_path, azure))

         print("running indexer")
-        self.__run_indexer(root, input_file_type, index_method)
+        self.__run_indexer(root, input_type, index_method)
         print("indexer complete")

         if dispose is not None:

@@ -27,7 +27,7 @@ class TestRunSentences:
         bootstrap()

     def test_basic_functionality(self):
-        """Test basic sentence splitting without metadata"""
+        """Test basic sentence splitting"""
         input = "This is a test. Another sentence. And a third one!"
         chunker = create_chunker(ChunkingConfig(type=ChunkerType.Sentence))
         chunks = chunker.chunk(input)

@@ -1,14 +1,15 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-from graphrag_chunking.add_metadata import add_metadata
+from graphrag_chunking.transformers import add_metadata


 def test_add_metadata_one_row():
     """Test prepending metadata to chunks"""
     chunks = ["This is a test.", "Another sentence."]
     metadata = {"message": "hello"}
-    results = [add_metadata(chunk, metadata) for chunk in chunks]
+    transformer = add_metadata(metadata)
+    results = [transformer(chunk) for chunk in chunks]
     assert results[0] == "message: hello\nThis is a test."
     assert results[1] == "message: hello\nAnother sentence."

@@ -17,7 +18,8 @@ def test_add_metadata_one_row_append():
     """Test prepending metadata to chunks"""
     chunks = ["This is a test.", "Another sentence."]
     metadata = {"message": "hello"}
-    results = [add_metadata(chunk, metadata, append=True) for chunk in chunks]
+    transformer = add_metadata(metadata, append=True)
+    results = [transformer(chunk) for chunk in chunks]
     assert results[0] == "This is a test.message: hello\n"
     assert results[1] == "Another sentence.message: hello\n"

@@ -26,7 +28,8 @@ def test_add_metadata_multiple_rows():
     """Test prepending metadata to chunks"""
     chunks = ["This is a test.", "Another sentence."]
     metadata = {"message": "hello", "tag": "first"}
-    results = [add_metadata(chunk, metadata) for chunk in chunks]
+    transformer = add_metadata(metadata)
+    results = [transformer(chunk) for chunk in chunks]
     assert results[0] == "message: hello\ntag: first\nThis is a test."
     assert results[1] == "message: hello\ntag: first\nAnother sentence."

@@ -35,9 +38,7 @@ def test_add_metadata_custom_delimiters():
     """Test prepending metadata to chunks"""
     chunks = ["This is a test.", "Another sentence."]
     metadata = {"message": "hello", "tag": "first"}
-    results = [
-        add_metadata(chunk, metadata, delimiter="-", line_delimiter="_")
-        for chunk in chunks
-    ]
+    transformer = add_metadata(metadata, delimiter="-", line_delimiter="_")
+    results = [transformer(chunk) for chunk in chunks]
     assert results[0] == "message-hello_tag-first_This is a test."
     assert results[1] == "message-hello_tag-first_Another sentence."

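The assertions above pin down the transformer's contract; one possible implementation consistent with them (the actual `graphrag_chunking.transformers` code may differ):

```python
from collections.abc import Callable


def add_metadata(
    metadata: dict,
    delimiter: str = ": ",
    line_delimiter: str = "\n",
    append: bool = False,
) -> Callable[[str], str]:
    """Build a transformer that attaches rendered metadata to each chunk."""
    # render each pair as "key<delimiter>value<line_delimiter>"
    rendered = "".join(
        f"{key}{delimiter}{value}{line_delimiter}" for key, value in metadata.items()
    )

    def transform(chunk: str) -> str:
        return chunk + rendered if append else rendered + chunk

    return transform
```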
@@ -17,7 +17,6 @@ from graphrag.config.models.extract_graph_nlp_config import (
 )
 from graphrag.config.models.global_search_config import GlobalSearchConfig
 from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.config.models.input_config import InputConfig
 from graphrag.config.models.language_model_config import LanguageModelConfig
 from graphrag.config.models.local_search_config import LocalSearchConfig
 from graphrag.config.models.prune_graph_config import PruneGraphConfig
@@ -29,6 +28,7 @@ from graphrag.config.models.summarize_descriptions_config import (
 from graphrag.config.models.vector_store_config import VectorStoreConfig
 from graphrag_cache import CacheConfig
 from graphrag_chunking.chunking_config import ChunkingConfig
+from graphrag_input import InputConfig
 from graphrag_storage import StorageConfig
 from pydantic import BaseModel

@@ -143,13 +143,11 @@ def assert_cache_configs(actual: CacheConfig, expected: CacheConfig) -> None:


 def assert_input_configs(actual: InputConfig, expected: InputConfig) -> None:
-    assert_storage_config(actual.storage, expected.storage)
-    assert actual.file_type == expected.file_type
+    assert actual.type == expected.type
     assert actual.encoding == expected.encoding
     assert actual.file_pattern == expected.file_pattern
     assert actual.text_column == expected.text_column
     assert actual.title_column == expected.title_column
     assert actual.metadata == expected.metadata


 def assert_text_embedding_configs(
@@ -338,8 +336,8 @@ def assert_graphrag_configs(actual: GraphRagConfig, expected: GraphRagConfig) ->
     assert_vector_store_configs(actual.vector_store, expected.vector_store)
     assert_reporting_configs(actual.reporting, expected.reporting)
     assert_storage_config(actual.output, expected.output)
-    assert_storage_config(actual.update_index_output, expected.update_index_output)
+    assert_storage_config(actual.input_storage, expected.input_storage)
+    assert_storage_config(actual.update_output_storage, expected.update_output_storage)
     assert_cache_configs(actual.cache, expected.cache)
     assert_input_configs(actual.input, expected.input)

tests/unit/indexing/input/data/one-html/input.html (new file):
@@ -0,0 +1,8 @@
+<html>
+<head>
+<title>Test</title>
+<head>
+<body>
+Hi how are you today?
+</body>
+</html>

tests/unit/indexing/input/data/one-jsonl/input.jsonl (new file):
@@ -0,0 +1,3 @@
+{ "title": "Hello", "text": "Hi how are you today?"}
+{ "title": "Goodbye", "text": "I'm outta here"}
+{ "title": "Adios", "text": "See you later"}

@@ -1,82 +1,56 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-from graphrag.config.enums import InputFileType
-from graphrag.config.models.input_config import InputConfig
-from graphrag.index.input.factory import InputReaderFactory
+from graphrag_input import InputConfig, InputType, create_input_reader
 from graphrag_storage import StorageConfig, create_storage


 async def test_csv_loader_one_file():
     config = InputConfig(
-        storage=StorageConfig(
-            base_dir="tests/unit/indexing/input/data/one-csv",
-        ),
-        file_type=InputFileType.csv,
+        type=InputType.Csv,
         file_pattern=".*\\.csv$",
     )
-    storage = create_storage(config.storage)
-    documents = (
-        await InputReaderFactory()
-        .create(config.file_type, {"storage": storage, "config": config})
-        .read_files()
-    )
-    assert documents.shape == (2, 4)
-    assert documents["title"].iloc[0] == "input.csv"
+    storage = create_storage(
+        StorageConfig(
+            base_dir="tests/unit/indexing/input/data/one-csv",
+        )
+    )
+    reader = create_input_reader(config, storage)
+    documents = await reader.read_files()
+    assert len(documents) == 2
+    assert documents[0].title == "input.csv (0)"
+    assert documents[0].raw_data == {
+        "title": "Hello",
+        "text": "Hi how are you today?",
+    }
+    assert documents[1].title == "input.csv (1)"


 async def test_csv_loader_one_file_with_title():
     config = InputConfig(
-        storage=StorageConfig(
-            base_dir="tests/unit/indexing/input/data/one-csv",
-        ),
-        file_type=InputFileType.csv,
-        file_pattern=".*\\.csv$",
+        type=InputType.Csv,
         title_column="title",
     )
-    storage = create_storage(config.storage)
-    documents = (
-        await InputReaderFactory()
-        .create(config.file_type, {"storage": storage, "config": config})
-        .read_files()
-    )
-    assert documents.shape == (2, 4)
-    assert documents["title"].iloc[0] == "Hello"
-
-
-async def test_csv_loader_one_file_with_metadata():
-    config = InputConfig(
-        storage=StorageConfig(
-            base_dir="tests/unit/indexing/input/data/one-csv",
-        ),
-        file_type=InputFileType.csv,
-        file_pattern=".*\\.csv$",
-        title_column="title",
-        metadata=["title"],
-    )
-    storage = create_storage(config.storage)
-    documents = (
-        await InputReaderFactory()
-        .create(config.file_type, {"storage": storage, "config": config})
-        .read_files()
-    )
-    print(documents)
-    assert documents.shape == (2, 5)
-    assert documents["metadata"][0] == {"title": "Hello"}
+    storage = create_storage(
+        StorageConfig(
+            base_dir="tests/unit/indexing/input/data/one-csv",
+        )
+    )
+    reader = create_input_reader(config, storage)
+    documents = await reader.read_files()
+    assert len(documents) == 2
+    assert documents[0].title == "Hello"


 async def test_csv_loader_multiple_files():
     config = InputConfig(
-        storage=StorageConfig(
-            base_dir="tests/unit/indexing/input/data/multiple-csvs",
-        ),
-        file_type=InputFileType.csv,
-        file_pattern=".*\\.csv$",
+        type=InputType.Csv,
     )
-    storage = create_storage(config.storage)
-    documents = (
-        await InputReaderFactory()
-        .create(config.file_type, {"storage": storage, "config": config})
-        .read_files()
-    )
-    assert documents.shape == (4, 4)
+    storage = create_storage(
+        StorageConfig(
+            base_dir="tests/unit/indexing/input/data/multiple-csvs",
+        )
+    )
+    reader = create_input_reader(config, storage)
+    documents = await reader.read_files()
+    assert len(documents) == 4

@@ -1,100 +1,71 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-from graphrag.config.enums import InputFileType
-from graphrag.config.models.input_config import InputConfig
-from graphrag.index.input.factory import InputReaderFactory
+from graphrag_input import InputConfig, InputType, create_input_reader
 from graphrag_storage import StorageConfig, create_storage


 async def test_json_loader_one_file_one_object():
     config = InputConfig(
-        storage=StorageConfig(
-            base_dir="tests/unit/indexing/input/data/one-json-one-object",
-        ),
-        file_type=InputFileType.json,
+        type=InputType.Json,
         file_pattern=".*\\.json$",
     )
-    storage = create_storage(config.storage)
-    documents = (
-        await InputReaderFactory()
-        .create(config.file_type, {"storage": storage, "config": config})
-        .read_files()
-    )
-    assert documents.shape == (1, 4)
-    assert documents["title"].iloc[0] == "input.json"
+    storage = create_storage(
+        StorageConfig(
+            base_dir="tests/unit/indexing/input/data/one-json-one-object",
+        )
+    )
+    reader = create_input_reader(config, storage)
+    documents = await reader.read_files()
+    assert len(documents) == 1
+    assert documents[0].title == "input.json"
+    assert documents[0].raw_data == {
+        "title": "Hello",
+        "text": "Hi how are you today?",
+    }


 async def test_json_loader_one_file_multiple_objects():
     config = InputConfig(
-        storage=StorageConfig(
-            base_dir="tests/unit/indexing/input/data/one-json-multiple-objects",
-        ),
-        file_type=InputFileType.json,
-        file_pattern=".*\\.json$",
+        type=InputType.Json,
     )
-    storage = create_storage(config.storage)
-    documents = (
-        await InputReaderFactory()
-        .create(config.file_type, {"storage": storage, "config": config})
-        .read_files()
-    )
-    print(documents)
-    assert documents.shape == (3, 4)
-    assert documents["title"].iloc[0] == "input.json"
+    storage = create_storage(
+        StorageConfig(
+            base_dir="tests/unit/indexing/input/data/one-json-multiple-objects",
+        )
+    )
+    reader = create_input_reader(config, storage)
+    documents = await reader.read_files()
+    assert len(documents) == 3
+    assert documents[0].title == "input.json (0)"
+    assert documents[1].title == "input.json (1)"


 async def test_json_loader_one_file_with_title():
     config = InputConfig(
-        storage=StorageConfig(
-            base_dir="tests/unit/indexing/input/data/one-json-one-object",
-        ),
-        file_type=InputFileType.json,
-        file_pattern=".*\\.json$",
+        type=InputType.Json,
         title_column="title",
     )
-    storage = create_storage(config.storage)
-    documents = (
-        await InputReaderFactory()
-        .create(config.file_type, {"storage": storage, "config": config})
-        .read_files()
-    )
-    assert documents.shape == (1, 4)
-    assert documents["title"].iloc[0] == "Hello"
-
-
-async def test_json_loader_one_file_with_metadata():
-    config = InputConfig(
-        storage=StorageConfig(
-            base_dir="tests/unit/indexing/input/data/one-json-one-object",
-        ),
-        file_type=InputFileType.json,
-        file_pattern=".*\\.json$",
-        title_column="title",
-        metadata=["title"],
-    )
-    storage = create_storage(config.storage)
-    documents = (
-        await InputReaderFactory()
-        .create(config.file_type, {"storage": storage, "config": config})
-        .read_files()
-    )
-    assert documents.shape == (1, 5)
-    assert documents["metadata"][0] == {"title": "Hello"}
+    storage = create_storage(
+        StorageConfig(
+            base_dir="tests/unit/indexing/input/data/one-json-one-object",
+        )
+    )
+    reader = create_input_reader(config, storage)
+    documents = await reader.read_files()
+    assert len(documents) == 1
+    assert documents[0].title == "Hello"


 async def test_json_loader_multiple_files():
     config = InputConfig(
-        storage=StorageConfig(
-            base_dir="tests/unit/indexing/input/data/multiple-jsons",
-        ),
-        file_type=InputFileType.json,
-        file_pattern=".*\\.json$",
+        type=InputType.Json,
     )
-    storage = create_storage(config.storage)
-    documents = (
-        await InputReaderFactory()
-        .create(config.file_type, {"storage": storage, "config": config})
-        .read_files()
-    )
-    assert documents.shape == (4, 4)
+    storage = create_storage(
+        StorageConfig(
+            base_dir="tests/unit/indexing/input/data/multiple-jsons",
+        )
+    )
+    reader = create_input_reader(config, storage)
+    documents = await reader.read_files()
+    assert len(documents) == 4

tests/unit/indexing/input/test_jsonl_loader.py (new file):
@@ -0,0 +1,42 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+from graphrag_input import InputConfig, InputType, create_input_reader
+from graphrag_storage import StorageConfig, create_storage
+
+
+async def test_jsonl_loader_one_file_multiple_objects():
+    config = InputConfig(
+        type=InputType.JsonLines,
+        file_pattern=".*\\.jsonl$",
+    )
+    storage = create_storage(
+        StorageConfig(
+            base_dir="tests/unit/indexing/input/data/one-jsonl",
+        )
+    )
+    reader = create_input_reader(config, storage)
+    documents = await reader.read_files()
+    assert len(documents) == 3
+    assert documents[0].title == "input.jsonl (0)"
+    assert documents[0].raw_data == {
+        "title": "Hello",
+        "text": "Hi how are you today?",
+    }
+    assert documents[1].title == "input.jsonl (1)"
+
+
+async def test_jsonl_loader_one_file_with_title():
+    config = InputConfig(
+        type=InputType.JsonLines,
+        title_column="title",
+    )
+    storage = create_storage(
+        StorageConfig(
+            base_dir="tests/unit/indexing/input/data/one-jsonl",
+        )
+    )
+    reader = create_input_reader(config, storage)
+    documents = await reader.read_files()
+    assert len(documents) == 3
+    assert documents[0].title == "Hello"

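The JSONL reader behavior implied by these tests is line-per-document parsing; a minimal sketch (a hypothetical helper, not the package's actual code):

```python
import json


def parse_jsonl(text: str) -> list[dict]:
    """Parse one JSON object per non-empty line into document rows."""
    return [json.loads(line) for line in text.splitlines() if line.strip()]
```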
tests/unit/indexing/input/test_markitdown_loader.py (new file):
@@ -0,0 +1,26 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+from graphrag_input import InputConfig, InputType, create_input_reader
+from graphrag_storage import StorageConfig, create_storage
+
+
+# these tests just confirm we can load files with MarkItDown,
+# and use html specifically because it requires no additional dependency installation
+async def test_markitdown_loader_one_file():
+    config = InputConfig(
+        type=InputType.MarkItDown,
+        file_pattern=".*\\.html$",
+    )
+    storage = create_storage(
+        StorageConfig(
+            base_dir="tests/unit/indexing/input/data/one-html",
+        )
+    )
+    reader = create_input_reader(config, storage)
+    documents = await reader.read_files()
+    assert len(documents) == 1
+    # markitdown will extract the title and body from the HTML if present and clean them
+    assert documents[0].title == "Test"
+    assert documents[0].text == "Hi how are you today?"
+    assert documents[0].raw_data is None

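For context, a sketch of direct MarkItDown library usage on the HTML fixture above (the reader presumably wraps something similar; treat the exact result attributes as an assumption):

```python
from markitdown import MarkItDown

# convert the fixture and read back the extracted pieces
result = MarkItDown().convert("tests/unit/indexing/input/data/one-html/input.html")
print(result.title)         # extracted <title>, e.g. "Test"
print(result.text_content)  # cleaned body text
```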
tests/unit/indexing/input/test_text_document.py (new file):
@@ -0,0 +1,76 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+import pytest
+from graphrag_input import get_property
+
+
+def test_get_property_single_level():
+    data = {"foo": "bar"}
+    assert get_property(data, "foo") == "bar"
+
+
+def test_get_property_two_levels():
+    data = {"foo": {"bar": "baz"}}
+    assert get_property(data, "foo.bar") == "baz"
+
+
+def test_get_property_three_levels():
+    data = {"a": {"b": {"c": "value"}}}
+    assert get_property(data, "a.b.c") == "value"
+
+
+def test_get_property_returns_dict():
+    data = {"foo": {"bar": {"baz": "qux"}}}
+    result = get_property(data, "foo.bar")
+    assert result == {"baz": "qux"}
+
+
+def test_get_property_missing_key_raises():
+    data = {"foo": "bar"}
+    with pytest.raises(KeyError):
+        get_property(data, "missing")
+
+
+def test_get_property_missing_nested_key_raises():
+    data = {"foo": {"bar": "baz"}}
+    with pytest.raises(KeyError):
+        get_property(data, "foo.missing")
+
+
+def test_get_property_non_dict_intermediate_raises():
+    data = {"foo": "bar"}
+    with pytest.raises(KeyError):
+        get_property(data, "foo.bar")
+
+
+def test_get_property_empty_dict_raises():
+    data = {}
+    with pytest.raises(KeyError):
+        get_property(data, "foo")
+
+
+def test_get_property_with_none_value():
+    data = {"foo": None}
+    assert get_property(data, "foo") is None
+
+
+def test_get_property_with_list_value():
+    data = {"foo": [1, 2, 3]}
+    assert get_property(data, "foo") == [1, 2, 3]
+
+
+def test_get_property_list_intermediate_raises():
+    data = {"foo": [{"bar": "baz"}]}
+    with pytest.raises(KeyError):
+        get_property(data, "foo.bar")
+
+
+def test_get_property_numeric_value():
+    data = {"count": 42}
+    assert get_property(data, "count") == 42
+
+
+def test_get_property_boolean_value():
+    data = {"enabled": True}
+    assert get_property(data, "enabled") is True

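These tests fully specify `get_property`'s contract: dot-delimited traversal of nested dicts, with `KeyError` for missing keys and for non-dict intermediates. A sketch implementation that satisfies them (the shipped `graphrag_input` version may differ):

```python
from typing import Any


def get_property(data: dict, path: str) -> Any:
    """Resolve a dot-delimited path against nested dictionaries."""
    current: Any = data
    for key in path.split("."):
        # non-dict intermediates (strings, lists, ...) cannot be traversed
        if not isinstance(current, dict) or key not in current:
            raise KeyError(path)
        current = current[key]
    return current
```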
tests/unit/indexing/input/test_text_loader.py (new file):
@@ -0,0 +1,36 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+from graphrag_input import InputConfig, InputType, create_input_reader
+from graphrag_storage import StorageConfig, create_storage
+
+
+async def test_text_loader_one_file():
+    config = InputConfig(
+        type=InputType.Text,
+        file_pattern=".*\\.txt$",
+    )
+    storage = create_storage(
+        StorageConfig(
+            base_dir="tests/unit/indexing/input/data/one-txt",
+        )
+    )
+    reader = create_input_reader(config, storage)
+    documents = await reader.read_files()
+    assert len(documents) == 1
+    assert documents[0].title == "input.txt"
+    assert documents[0].raw_data is None
+
+
+async def test_text_loader_multiple_files():
+    config = InputConfig(
+        type=InputType.Text,
+    )
+    storage = create_storage(
+        StorageConfig(
+            base_dir="tests/unit/indexing/input/data/multiple-txts",
+        )
+    )
+    reader = create_input_reader(config, storage)
+    documents = await reader.read_files()
+    assert len(documents) == 2

@@ -1,62 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-from graphrag.config.enums import InputFileType
-from graphrag.config.models.input_config import InputConfig
-from graphrag.index.input.factory import InputReaderFactory
-from graphrag_storage import StorageConfig, create_storage
-
-
-async def test_txt_loader_one_file():
-    config = InputConfig(
-        storage=StorageConfig(
-            base_dir="tests/unit/indexing/input/data/one-txt",
-        ),
-        file_type=InputFileType.text,
-        file_pattern=".*\\.txt$",
-    )
-    storage = create_storage(config.storage)
-    documents = (
-        await InputReaderFactory()
-        .create(config.file_type, {"storage": storage, "config": config})
-        .read_files()
-    )
-    assert documents.shape == (1, 4)
-    assert documents["title"].iloc[0] == "input.txt"
-
-
-async def test_txt_loader_one_file_with_metadata():
-    config = InputConfig(
-        storage=StorageConfig(
-            base_dir="tests/unit/indexing/input/data/one-txt",
-        ),
-        file_type=InputFileType.text,
-        file_pattern=".*\\.txt$",
-        metadata=["title"],
-    )
-    storage = create_storage(config.storage)
-    documents = (
-        await InputReaderFactory()
-        .create(config.file_type, {"storage": storage, "config": config})
-        .read_files()
-    )
-    assert documents.shape == (1, 5)
-    # unlike csv, we cannot set the title to anything other than the filename
-    assert documents["metadata"][0] == {"title": "input.txt"}
-
-
-async def test_txt_loader_multiple_files():
-    config = InputConfig(
-        storage=StorageConfig(
-            base_dir="tests/unit/indexing/input/data/multiple-txts",
-        ),
-        file_type=InputFileType.text,
-        file_pattern=".*\\.txt$",
-    )
-    storage = create_storage(config.storage)
-    documents = (
-        await InputReaderFactory()
-        .create(config.file_type, {"storage": storage, "config": config})
-        .read_files()
-    )
-    assert documents.shape == (2, 4)

Binary files changed (9 files; names and contents not shown in this view).

@@ -10,7 +10,6 @@ from .util import (
     compare_outputs,
     create_test_context,
     load_test_table,
-    update_document_metadata,
 )


@@ -20,26 +19,18 @@ async def test_create_base_text_units():
     context = await create_test_context()

     config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG)  # type: ignore
+    config.chunking.prepend_metadata = ["title"]

     await run_workflow(config, context)

     actual = await load_table_from_storage("text_units", context.output_storage)

-    print("EXPECTED")
-    print(expected.columns)
-    print(expected)
-
-    print("ACTUAL")
-    print(actual.columns)
-    print(actual)
-
     compare_outputs(actual, expected, columns=["text", "document_id", "n_tokens"])

-
-async def test_create_base_text_units_metadata():
-    expected = load_test_table("text_units_metadata")
-
-    context = await create_test_context()
-
-    config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG)  # type: ignore
-    config.input.metadata = ["title"]
-    config.chunking.prepend_metadata = True
-
-    await update_document_metadata(config.input.metadata, context)
-
-    await run_workflow(config, context)
-
-    actual = await load_table_from_storage("text_units", context.output_storage)
-    compare_outputs(actual, expected, ["text", "document_id", "n_tokens"])

@@ -4,15 +4,14 @@

 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.data_model.schemas import COMMUNITY_REPORTS_FINAL_COLUMNS
-from graphrag.index.workflows.create_community_reports import (
-    run_workflow,
-)
-from graphrag.utils.storage import load_table_from_storage
-
 from graphrag.index.operations.summarize_communities.community_reports_extractor import (
     CommunityReportResponse,
     FindingModel,
 )
+from graphrag.index.workflows.create_community_reports import (
+    run_workflow,
+)
+from graphrag.utils.storage import load_table_from_storage

 from .util import (
     DEFAULT_MODEL_CONFIG,

@@ -13,7 +13,6 @@ from .util import (
     compare_outputs,
     create_test_context,
     load_test_table,
-    update_document_metadata,
 )

@@ -34,26 +33,3 @@ async def test_create_final_documents():

     for column in DOCUMENTS_FINAL_COLUMNS:
         assert column in actual.columns
-
-
-async def test_create_final_documents_with_metadata_column():
-    context = await create_test_context(
-        storage=["text_units"],
-    )
-
-    config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG)  # type: ignore
-    config.input.metadata = ["title"]
-
-    # simulate the metadata construction during initial input loading
-    await update_document_metadata(config.input.metadata, context)
-
-    expected = await load_table_from_storage("documents", context.output_storage)
-
-    await run_workflow(config, context)
-
-    actual = await load_table_from_storage("documents", context.output_storage)
-
-    compare_outputs(actual, expected)
-
-    for column in DOCUMENTS_FINAL_COLUMNS:
-        assert column in actual.columns

@@ -28,4 +28,4 @@ async def test_prune_graph():

     nodes_actual = await load_table_from_storage("entities", context.output_storage)

-    assert len(nodes_actual) == 21
+    assert len(nodes_actual) == 29

@@ -5,7 +5,7 @@ import graphrag.config.defaults as defs
 import pandas as pd
 from graphrag.index.run.utils import create_run_context
 from graphrag.index.typing.context import PipelineRunContext
-from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
+from graphrag.utils.storage import write_table_to_storage
 from pandas.testing import assert_series_equal

 pd.set_option("display.max_columns", None)
@@ -87,12 +87,3 @@ def compare_outputs(
         print("Actual:")
         print(actual[column])
         raise
-
-
-async def update_document_metadata(metadata: list[str], context: PipelineRunContext):
-    """Takes the default documents and adds the configured metadata columns for later parsing by the text units and final documents workflows."""
-    documents = await load_table_from_storage("documents", context.output_storage)
-    documents["metadata"] = documents[metadata].apply(lambda row: row.to_dict(), axis=1)
-    await write_table_to_storage(
-        documents, "documents", context.output_storage
-    )  # write to the runtime context storage only

@@ -7,6 +7,7 @@ import asyncio
 import logging
 from typing import TYPE_CHECKING

+import graphrag.api as api
 import streamlit as st
 from knowledge_loader.data_sources.loader import (
     create_datasource,
@@ -17,8 +18,6 @@ from rag.typing import SearchResult, SearchType
 from state.session_variables import SessionVariables
 from ui.search import display_search_result

-import graphrag.api as api
-
 if TYPE_CHECKING:
     import pandas as pd
