Update input factory to match other factories

This commit is contained in:
Nathan Evans 2026-01-05 16:35:14 -08:00
parent fde14b63e5
commit 99aea5226e
18 changed files with 176 additions and 147 deletions

View File

@ -26,7 +26,7 @@ def register_cache(
cache_initializer: Callable[..., Cache],
scope: ServiceScope = "transient",
) -> None:
"""Register a custom storage implementation.
"""Register a custom cache implementation.
Args
----

View File

@ -5,7 +5,7 @@
from __future__ import annotations
from enum import Enum
from enum import Enum, StrEnum
class CacheType(str, Enum):
@ -27,7 +27,7 @@ class CacheType(str, Enum):
return f'"{self.value}"'
class InputFileType(str, Enum):
class InputFileType(StrEnum):
"""The input file type for the pipeline."""
csv = "csv"
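Aside (illustrative only, not part of this diff): because InputFileType now extends StrEnum, each member formats directly as its plain string value, which is why the ".value" access is dropped in the GraphRagConfig hunk below. A minimal sketch:

from enum import StrEnum

class InputFileType(StrEnum):
    csv = "csv"

# A StrEnum member formats as its underlying value, so no .value is needed:
assert f".*\\.{InputFileType.csv}$" == ".*\\.csv$"
assert InputFileType.csv == "csv"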

View File

@ -113,7 +113,7 @@ class GraphRagConfig(BaseModel):
if self.input.file_type == defs.InputFileType.text:
self.input.file_pattern = ".*\\.txt$"
else:
self.input.file_pattern = f".*\\.{self.input.file_type.value}$"
self.input.file_pattern = f".*\\.{self.input.file_type}$"
def _validate_input_base_dir(self) -> None:
"""Validate the input base directory."""

View File

@ -4,23 +4,25 @@
"""Parameterization settings for the default configuration."""
from graphrag_storage import StorageConfig
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
import graphrag.config.defaults as defs
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import InputFileType
class InputConfig(BaseModel):
"""The default configuration section for Input."""
model_config = ConfigDict(extra="allow")
"""Allow extra fields to support custom cache implementations."""
storage: StorageConfig = Field(
description="The storage configuration to use for reading input documents.",
default=StorageConfig(
base_dir=graphrag_config_defaults.input.storage.base_dir,
),
)
file_type: InputFileType = Field(
file_type: str = Field(
description="The input file type to use.",
default=graphrag_config_defaults.input.file_type,
)
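Aside (illustrative only, not part of this diff): with extra="allow" and file_type widened to a plain str, a custom reader can carry its own settings through InputConfig. A hedged sketch of what that permits, assuming the remaining fields keep their defaults (the "parquet" type and sheet_name field are hypothetical):

from graphrag.config.models.input_config import InputConfig

# A custom file type and an extra option both validate, because
# file_type is now a str and the model allows extra fields.
config = InputConfig(file_type="parquet", sheet_name="docs")
print(config.file_type)    # "parquet"
print(config.model_extra)  # {"sheet_name": "docs"}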

View File

@ -28,8 +28,8 @@ class CSVFileReader(InputReader):
- output - DataFrame with a row for each document in the file.
"""
buffer = BytesIO(await self._storage.get(path, as_bytes=True))
data = pd.read_csv(buffer, encoding=self._config.encoding)
data = process_data_columns(data, self._config, path)
data = pd.read_csv(buffer, encoding=self._encoding)
data = process_data_columns(data, path, self._text_column, self._title_column)
creation_date = await self._storage.get_creation_date(path)
data["creation_date"] = data.apply(lambda _: creation_date, axis=1)
return data

View File

@ -1,26 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'InputReaderFactory' model."""
import logging
from graphrag_common.factory import Factory
from graphrag.config.enums import InputFileType
from graphrag.index.input.csv import CSVFileReader
from graphrag.index.input.input_reader import InputReader
from graphrag.index.input.json import JSONFileReader
from graphrag.index.input.text import TextFileReader
logger = logging.getLogger(__name__)
class InputReaderFactory(Factory[InputReader]):
"""Factory for creating Input Reader instances."""
input_reader_factory = InputReaderFactory()
input_reader_factory.register(InputFileType.text, TextFileReader)
input_reader_factory.register(InputFileType.csv, CSVFileReader)
input_reader_factory.register(InputFileType.json, JSONFileReader)

View File

@ -15,25 +15,39 @@ import pandas as pd
if TYPE_CHECKING:
from graphrag_storage import Storage
from graphrag.config.models.input_config import InputConfig
logger = logging.getLogger(__name__)
class InputReader(metaclass=ABCMeta):
"""Provide a cache interface for the pipeline."""
def __init__(self, storage: Storage, config: InputConfig, **kwargs):
def __init__(
self,
storage: Storage,
file_type: str,
file_pattern: str,
encoding: str = "utf-8",
text_column: str | None = None,
title_column: str | None = None,
metadata: list[str] | None = None,
**kwargs,
):
self._storage = storage
self._config = config
self._file_type = file_type
self._file_pattern = file_pattern
self._encoding = encoding
self._text_column = text_column
self._title_column = title_column
self._metadata = metadata
async def read_files(self) -> pd.DataFrame:
"""Load files from storage and apply a loader function based on file type. Process metadata on the results if needed."""
files = list(self._storage.find(re.compile(self._config.file_pattern)))
files = list(self._storage.find(re.compile(self._file_pattern)))
if len(files) == 0:
msg = f"No {self._config.file_type} files found in {self._config.storage.base_dir}"
raise ValueError(msg)
msg = f"No {self._file_type} files found in storage" # TODO: use a storage __str__ to define it per impl
logger.warning(msg)
files = []
files_loaded = []
@ -47,19 +61,19 @@ class InputReader(metaclass=ABCMeta):
logger.info(
"Found %d %s files, loading %d",
len(files),
self._config.file_type,
self._file_type,
len(files_loaded),
)
result = pd.concat(files_loaded)
total_files_log = (
f"Total number of unfiltered {self._config.file_type} rows: {len(result)}"
f"Total number of unfiltered {self._file_type} rows: {len(result)}"
)
logger.info(total_files_log)
# Convert metadata columns to strings and collapse them into a JSON object
if self._config.metadata:
if all(col in result.columns for col in self._config.metadata):
if self._metadata:
if all(col in result.columns for col in self._metadata):
# Collapse the metadata columns into a single JSON object column
result["metadata"] = result[self._config.metadata].apply(
result["metadata"] = result[self._metadata].apply(
lambda row: row.to_dict(), axis=1
)
else:
@ -68,7 +82,7 @@ class InputReader(metaclass=ABCMeta):
)
raise ValueError(value_error_msg)
result[self._config.metadata] = result[self._config.metadata].astype(str)
result[self._metadata] = result[self._metadata].astype(str)
return result
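Aside (illustrative only, not part of this diff): because the base class now takes explicit keyword arguments rather than a whole InputConfig, a custom reader works against the same protected attributes the built-in readers use. A rough sketch of a subclass under those assumptions (the reader, and the name of the per-file load method, are hypothetical):

import pandas as pd

from graphrag.index.input.input_reader import InputReader


class TSVFileReader(InputReader):
    """Hypothetical tab-separated reader using the new constructor args."""

    async def load_file(self, path: str) -> pd.DataFrame:  # method name assumed
        # _storage, _encoding, _text_column, etc. are populated by
        # InputReader.__init__ from the keyword arguments shown above.
        text = await self._storage.get(path, encoding=self._encoding)
        rows = [line.split("\t") for line in text.splitlines()]
        return pd.DataFrame(rows[1:], columns=rows[0])  # assumes a header row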

View File

@ -0,0 +1,82 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'InputReaderFactory' model."""
import logging
from collections.abc import Callable
from graphrag_common.factory import Factory
from graphrag_common.factory.factory import ServiceScope
from graphrag_storage.storage import Storage
from graphrag.config.enums import InputFileType
from graphrag.config.models.input_config import InputConfig
from graphrag.index.input.input_reader import InputReader
logger = logging.getLogger(__name__)
class InputReaderFactory(Factory[InputReader]):
"""Factory for creating Input Reader instances."""
input_reader_factory = InputReaderFactory()
def register_input_reader(
input_reader_type: str,
input_reader_initializer: Callable[..., InputReader],
scope: ServiceScope = "transient",
) -> None:
"""Register a custom input reader implementation.
Args
----
- input_reader_type: str
The input reader id to register.
- input_reader_initializer: Callable[..., InputReader]
The input reader initializer to register.
"""
input_reader_factory.register(input_reader_type, input_reader_initializer, scope)
def create_input_reader(config: InputConfig, storage: Storage) -> InputReader:
"""Create an input reader implementation based on the given configuration.
Args
----
- config: InputConfig
The input reader configuration to use.
- storage: Storage
The storage implementation to use for reading the files.
Returns
-------
InputReader
The created input reader implementation.
"""
config_model = config.model_dump()
input_strategy = config.file_type
if input_strategy not in input_reader_factory:
match input_strategy:
case InputFileType.csv:
from graphrag.index.input.csv import CSVFileReader
register_input_reader(InputFileType.csv, CSVFileReader)
case InputFileType.text:
from graphrag.index.input.text import TextFileReader
register_input_reader(InputFileType.text, TextFileReader)
case InputFileType.json:
from graphrag.index.input.json import JSONFileReader
register_input_reader(InputFileType.json, JSONFileReader)
case _:
msg = f"InputConfig.file_type '{input_strategy}' is not registered in the InputReaderFactory. Registered types: {', '.join(input_reader_factory.keys())}."
raise ValueError(msg)
config_model["storage"] = storage
return input_reader_factory.create(input_strategy, init_args=config_model)
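Aside (illustrative only, not part of this diff): registration now mirrors the cache factory, with built-in readers lazily registered on first use and callers going through module-level helpers. A usage sketch for wiring up a custom type, reusing the hypothetical TSVFileReader sketched above (the "tsv" id and file pattern are likewise hypothetical):

from graphrag_storage import create_storage

from graphrag.config.models.input_config import InputConfig
from graphrag.index.input.input_reader_factory import (
    create_input_reader,
    register_input_reader,
)

# Register the hypothetical reader under a new file_type id.
register_input_reader("tsv", TSVFileReader)

# An InputConfig whose file_type is "tsv" now resolves through the factory;
# the dumped config is forwarded to the reader's constructor as init_args.
config = InputConfig(file_type="tsv", file_pattern=r".*\.tsv$")
storage = create_storage(config.storage)
reader = create_input_reader(config, storage)
documents = await reader.read_files()  # inside an async context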

View File

@ -27,12 +27,12 @@ class JSONFileReader(InputReader):
-------
- output - DataFrame with a row for each document in the file.
"""
text = await self._storage.get(path, encoding=self._config.encoding)
text = await self._storage.get(path, encoding=self._encoding)
as_json = json.loads(text)
# json file could just be a single object, or an array of objects
rows = as_json if isinstance(as_json, list) else [as_json]
data = pd.DataFrame(rows)
data = process_data_columns(data, self._config, path)
data = process_data_columns(data, path, self._text_column, self._title_column)
creation_date = await self._storage.get_creation_date(path)
data["creation_date"] = data.apply(lambda _: creation_date, axis=1)

View File

@ -27,7 +27,7 @@ class TextFileReader(InputReader):
-------
- output - DataFrame with a row for each document in the file.
"""
text = await self._storage.get(path, encoding=self._config.encoding)
text = await self._storage.get(path, encoding=self._encoding)
new_item = {"text": text}
new_item["id"] = gen_sha512_hash(new_item, new_item.keys())
new_item["title"] = str(Path(path).name)

View File

@ -7,40 +7,40 @@ import logging
import pandas as pd
from graphrag.config.models.input_config import InputConfig
from graphrag.index.utils.hashing import gen_sha512_hash
logger = logging.getLogger(__name__)
def process_data_columns(
documents: pd.DataFrame, config: InputConfig, path: str
documents: pd.DataFrame,
path: str,
text_column: str | None,
title_column: str | None,
) -> pd.DataFrame:
"""Process configured data columns of a DataFrame."""
if "id" not in documents.columns:
documents["id"] = documents.apply(
lambda x: gen_sha512_hash(x, x.keys()), axis=1
)
if config.text_column is not None and "text" not in documents.columns:
if config.text_column not in documents.columns:
if text_column is not None and "text" not in documents.columns:
if text_column not in documents.columns:
logger.warning(
"text_column %s not found in csv file %s",
config.text_column,
text_column,
path,
)
else:
documents["text"] = documents.apply(lambda x: x[config.text_column], axis=1)
if config.title_column is not None:
if config.title_column not in documents.columns:
documents["text"] = documents.apply(lambda x: x[text_column], axis=1)
if title_column is not None:
if title_column not in documents.columns:
logger.warning(
"title_column %s not found in csv file %s",
config.title_column,
title_column,
path,
)
else:
documents["title"] = documents.apply(
lambda x: x[config.title_column], axis=1
)
documents["title"] = documents.apply(lambda x: x[title_column], axis=1)
else:
documents["title"] = documents.apply(lambda _: path, axis=1)
return documents

View File

@ -8,8 +8,10 @@ import logging
import pandas as pd
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.input.factory import InputReaderFactory
from graphrag.index.input.input_reader import InputReader
from graphrag.index.input.input_reader_factory import (
create_input_reader,
)
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import write_table_to_storage
@ -22,10 +24,7 @@ async def run_workflow(
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""Load and parse input documents into a standard format."""
input_reader = InputReaderFactory().create(
config.input.file_type,
{"storage": context.input_storage, "config": config.input},
)
input_reader = create_input_reader(config.input, context.input_storage)
output = await load_input_documents(input_reader)

View File

@ -9,8 +9,8 @@ import pandas as pd
from graphrag_storage import Storage
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.input.factory import InputReaderFactory
from graphrag.index.input.input_reader import InputReader
from graphrag.index.input.input_reader_factory import create_input_reader
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.index.update.incremental_index import get_delta_docs
@ -24,10 +24,7 @@ async def run_workflow(
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""Load and parse update-only input documents into a standard format."""
input_reader = InputReaderFactory().create(
config.input.file_type,
{"storage": context.input_storage, "config": config.input},
)
input_reader = create_input_reader(config.input, context.input_storage)
output = await load_update_documents(
input_reader,
context.previous_storage,

View File

@ -13,7 +13,7 @@ from graphrag_storage import create_storage
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.input.factory import InputReaderFactory
from graphrag.index.input.input_reader_factory import create_input_reader
from graphrag.index.operations.embed_text.run_embed_text import (
run_embed_text,
)
@ -64,10 +64,7 @@ async def load_docs_in_chunks(
)
tokenizer = get_tokenizer(embeddings_llm_settings)
input_storage = create_storage(config.input.storage)
input_reader = InputReaderFactory().create(
config.input.file_type,
{"storage": input_storage, "config": config.input},
)
input_reader = create_input_reader(config.input, input_storage)
dataset = await input_reader.read_files()
chunk_config = config.chunks
chunks_df = create_base_text_units(

View File

@ -3,7 +3,9 @@
from graphrag.config.enums import InputFileType
from graphrag.config.models.input_config import InputConfig
from graphrag.index.input.factory import InputReaderFactory
from graphrag.index.input.input_reader_factory import (
create_input_reader,
)
from graphrag_storage import StorageConfig, create_storage
@ -16,11 +18,8 @@ async def test_csv_loader_one_file():
file_pattern=".*\\.csv$",
)
storage = create_storage(config.storage)
documents = (
await InputReaderFactory()
.create(config.file_type, {"storage": storage, "config": config})
.read_files()
)
reader = create_input_reader(config, storage)
documents = await reader.read_files()
assert documents.shape == (2, 4)
assert documents["title"].iloc[0] == "input.csv"
@ -35,11 +34,8 @@ async def test_csv_loader_one_file_with_title():
title_column="title",
)
storage = create_storage(config.storage)
documents = (
await InputReaderFactory()
.create(config.file_type, {"storage": storage, "config": config})
.read_files()
)
reader = create_input_reader(config, storage)
documents = await reader.read_files()
assert documents.shape == (2, 4)
assert documents["title"].iloc[0] == "Hello"
@ -55,12 +51,8 @@ async def test_csv_loader_one_file_with_metadata():
metadata=["title"],
)
storage = create_storage(config.storage)
documents = (
await InputReaderFactory()
.create(config.file_type, {"storage": storage, "config": config})
.read_files()
)
print(documents)
reader = create_input_reader(config, storage)
documents = await reader.read_files()
assert documents.shape == (2, 5)
assert documents["metadata"][0] == {"title": "Hello"}
@ -74,9 +66,6 @@ async def test_csv_loader_multiple_files():
file_pattern=".*\\.csv$",
)
storage = create_storage(config.storage)
documents = (
await InputReaderFactory()
.create(config.file_type, {"storage": storage, "config": config})
.read_files()
)
reader = create_input_reader(config, storage)
documents = await reader.read_files()
assert documents.shape == (4, 4)

View File

@ -3,7 +3,7 @@
from graphrag.config.enums import InputFileType
from graphrag.config.models.input_config import InputConfig
from graphrag.index.input.factory import InputReaderFactory
from graphrag.index.input.input_reader_factory import create_input_reader
from graphrag_storage import StorageConfig, create_storage
@ -16,11 +16,8 @@ async def test_json_loader_one_file_one_object():
file_pattern=".*\\.json$",
)
storage = create_storage(config.storage)
documents = (
await InputReaderFactory()
.create(config.file_type, {"storage": storage, "config": config})
.read_files()
)
reader = create_input_reader(config, storage)
documents = await reader.read_files()
assert documents.shape == (1, 4)
assert documents["title"].iloc[0] == "input.json"
@ -34,11 +31,8 @@ async def test_json_loader_one_file_multiple_objects():
file_pattern=".*\\.json$",
)
storage = create_storage(config.storage)
documents = (
await InputReaderFactory()
.create(config.file_type, {"storage": storage, "config": config})
.read_files()
)
reader = create_input_reader(config, storage)
documents = await reader.read_files()
print(documents)
assert documents.shape == (3, 4)
assert documents["title"].iloc[0] == "input.json"
@ -54,11 +48,8 @@ async def test_json_loader_one_file_with_title():
title_column="title",
)
storage = create_storage(config.storage)
documents = (
await InputReaderFactory()
.create(config.file_type, {"storage": storage, "config": config})
.read_files()
)
reader = create_input_reader(config, storage)
documents = await reader.read_files()
assert documents.shape == (1, 4)
assert documents["title"].iloc[0] == "Hello"
@ -74,11 +65,8 @@ async def test_json_loader_one_file_with_metadata():
metadata=["title"],
)
storage = create_storage(config.storage)
documents = (
await InputReaderFactory()
.create(config.file_type, {"storage": storage, "config": config})
.read_files()
)
reader = create_input_reader(config, storage)
documents = await reader.read_files()
assert documents.shape == (1, 5)
assert documents["metadata"][0] == {"title": "Hello"}
@ -92,9 +80,6 @@ async def test_json_loader_multiple_files():
file_pattern=".*\\.json$",
)
storage = create_storage(config.storage)
documents = (
await InputReaderFactory()
.create(config.file_type, {"storage": storage, "config": config})
.read_files()
)
reader = create_input_reader(config, storage)
documents = await reader.read_files()
assert documents.shape == (4, 4)

View File

@ -3,7 +3,7 @@
from graphrag.config.enums import InputFileType
from graphrag.config.models.input_config import InputConfig
from graphrag.index.input.factory import InputReaderFactory
from graphrag.index.input.input_reader_factory import create_input_reader
from graphrag_storage import StorageConfig, create_storage
@ -16,11 +16,8 @@ async def test_txt_loader_one_file():
file_pattern=".*\\.txt$",
)
storage = create_storage(config.storage)
documents = (
await InputReaderFactory()
.create(config.file_type, {"storage": storage, "config": config})
.read_files()
)
reader = create_input_reader(config, storage)
documents = await reader.read_files()
assert documents.shape == (1, 4)
assert documents["title"].iloc[0] == "input.txt"
@ -35,11 +32,8 @@ async def test_txt_loader_one_file_with_metadata():
metadata=["title"],
)
storage = create_storage(config.storage)
documents = (
await InputReaderFactory()
.create(config.file_type, {"storage": storage, "config": config})
.read_files()
)
reader = create_input_reader(config, storage)
documents = await reader.read_files()
assert documents.shape == (1, 5)
# unlike csv, we cannot set the title to anything other than the filename
assert documents["metadata"][0] == {"title": "input.txt"}
@ -54,9 +48,6 @@ async def test_txt_loader_multiple_files():
file_pattern=".*\\.txt$",
)
storage = create_storage(config.storage)
documents = (
await InputReaderFactory()
.create(config.file_type, {"storage": storage, "config": config})
.read_files()
)
reader = create_input_reader(config, storage)
documents = await reader.read_files()
assert documents.shape == (2, 4)

View File

@ -4,15 +4,14 @@
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.data_model.schemas import COMMUNITY_REPORTS_FINAL_COLUMNS
from graphrag.index.workflows.create_community_reports import (
run_workflow,
)
from graphrag.utils.storage import load_table_from_storage
from graphrag.index.operations.summarize_communities.community_reports_extractor import (
CommunityReportResponse,
FindingModel,
)
from graphrag.index.workflows.create_community_reports import (
run_workflow,
)
from graphrag.utils.storage import load_table_from_storage
from .util import (
DEFAULT_MODEL_CONFIG,