Cleanup factory methods (#1482)

* cleanup factory methods to have a similar design pattern across the codebase

* add semversioner file

* cleanup logging factory

* update developer guide

* add comment

* typo fix

* cleanup reporter terminology

* rename reporter to logger

* fix comments

* update comment

* instantiate factory classes correctly and update index api callback parameter

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
Josh Bradley 2024-12-10 17:11:11 -05:00 committed by GitHub
parent 04405803db
commit 823342188d
51 changed files with 1246 additions and 1149 deletions

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "cleanup and refactor factory classes."
}

@ -10,24 +10,56 @@
# Getting Started
## Install Dependencies
```sh
# Install Python dependencies.
```shell
# install python dependencies
poetry install
```
## Executing the Indexing Engine
```sh
## Execute the indexing engine
```shell
poetry run poe index <...args>
```
## Executing Queries
## Execute prompt tuning
```shell
poetry run poe prompt_tune <...args>
```
```sh
## Execute Queries
```shell
poetry run poe query <...args>
```
## Repository Structure
An overview of the repository's top-level folder structure is provided below, detailing the overall design and purpose.
We leverage a factory design pattern where possible, enabling a variety of implementations for each core component of graphrag.
```shell
graphrag
├── api # library API definitions
├── cache # cache module supporting several options
│   └─ factory.py # └─ main entrypoint to create a cache
├── callbacks # a collection of commonly used callback functions
├── cli # library CLI
│   └─ main.py # └─ primary CLI entrypoint
├── config # configuration management
├── index # indexing engine
│   └─ run/run.py # └─ main entrypoint to build an index
├── llm # generic llm interfaces
├── logger # logger module supporting several options
│   └─ factory.py # └─ main entrypoint to create a logger
├── model # data model definitions associated with the knowledge graph
├── prompt_tune # prompt tuning module
├── prompts # a collection of all the system prompts used by graphrag
├── query # query engine
├── storage # storage module supporting several options
│   └─ factory.py # └─ main entrypoint to create/load a storage endpoint
├── utils # helper functions used throughout the library
└── vector_stores # vector store module containing a few options
    └─ factory.py # └─ main entrypoint to create a vector store
```
Where appropriate, the factories expose a registration method for users to provide their own custom implementations if desired.
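For illustration, registering a custom implementation is a one-liner against the factory class. A minimal sketch (the `MyCache` class and the `"my_cache"` type name are hypothetical; the `register`/`create_cache` signatures follow the `CacheFactory` shown later in this diff):
```python
from graphrag.cache.factory import CacheFactory
from graphrag.cache.memory_pipeline_cache import InMemoryCache

# hypothetical custom cache that reuses the in-memory implementation
class MyCache(InMemoryCache):
    ...

# register under a custom type name, then create it through the factory
CacheFactory.register("my_cache", MyCache)
cache = CacheFactory().create_cache(cache_type="my_cache", root_dir=".", kwargs={})
```
The storage, logger, and vector store factories follow the same register/create shape.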
## Versioning
We use [semversioner](https://github.com/raulgomis/semversioner) to automate and enforce semantic versioning in the release process. Our CI/CD pipeline checks that all PRs include a JSON file generated by semversioner. When submitting a PR, please run:
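The hunk is truncated here; assuming semversioner's standard CLI, the invocation looks like:
```shell
# record a change file for the release notes (flags per semversioner's documented CLI)
semversioner add-change --type patch --description "describe your change"
```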

@ -156,7 +156,7 @@ This section controls the storage mechanism used by the pipeline used for export
| Parameter | Description | Type | Required or Optional | Default |
| ------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----- | -------------------- | ------- |
| `GRAPHRAG_STORAGE_TYPE` | The type of reporter to use. Options are `file`, `memory`, or `blob` | `str` | optional | `file` |
| `GRAPHRAG_STORAGE_TYPE` | The type of storage to use. Options are `file`, `memory`, or `blob` | `str` | optional | `file` |
| `GRAPHRAG_STORAGE_STORAGE_ACCOUNT_BLOB_URL` | The Azure Storage blob endpoint to use when in `blob` mode and using managed identity. Will have the format `https://<storage_account_name>.blob.core.windows.net` | `str` | optional | None |
| `GRAPHRAG_STORAGE_CONNECTION_STRING` | The Azure Storage connection string to use when in `blob` mode. | `str` | optional | None |
| `GRAPHRAG_STORAGE_CONTAINER_NAME` | The Azure Storage container name to use when in `blob` mode. | `str` | optional | None |

@ -17,7 +17,7 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.create_pipeline_config import create_pipeline_config
from graphrag.index.run import run_pipeline_with_config
from graphrag.index.typing import PipelineRunResult
from graphrag.logging.base import ProgressReporter
from graphrag.logger.base import ProgressLogger
async def build_index(
@ -26,7 +26,7 @@ async def build_index(
is_resume_run: bool = False,
memory_profile: bool = False,
callbacks: list[WorkflowCallbacks] | None = None,
progress_reporter: ProgressReporter | None = None,
progress_logger: ProgressLogger | None = None,
) -> list[PipelineRunResult]:
"""Run the pipeline with the given configuration.
@ -42,8 +42,8 @@ async def build_index(
Whether to enable memory profiling.
callbacks : list[WorkflowCallbacks] | None default=None
A list of callbacks to register.
progress_reporter : ProgressReporter | None default=None
The progress reporter.
progress_logger : ProgressLogger | None default=None
The progress logger.
Returns
-------
@ -60,10 +60,10 @@ async def build_index(
pipeline_cache = (
NoopPipelineCache() if config.cache.type == CacheType.none is None else None
)
# create a pipeline reporter and add to any additional callbacks
# TODO: remove the type ignore once the new config engine has been refactored
callbacks = (
[create_pipeline_reporter(config.reporting, None)] if config.reporting else None # type: ignore
) # type: ignore
callbacks = callbacks or []
callbacks.append(create_pipeline_reporter(config.reporting, None)) # type: ignore
outputs: list[PipelineRunResult] = []
async for output in run_pipeline_with_config(
pipeline_config,
@ -71,15 +71,15 @@ async def build_index(
memory_profile=memory_profile,
cache=pipeline_cache,
callbacks=callbacks,
progress_reporter=progress_reporter,
logger=progress_logger,
is_resume_run=is_resume_run,
is_update_run=is_update_run,
):
outputs.append(output)
if progress_reporter:
if progress_logger:
if output.errors and len(output.errors) > 0:
progress_reporter.error(output.workflow)
progress_logger.error(output.workflow)
else:
progress_reporter.success(output.workflow)
progress_reporter.info(str(output.result))
progress_logger.success(output.workflow)
progress_logger.info(str(output.result))
return outputs

@ -16,7 +16,7 @@ from pydantic import PositiveInt, validate_call
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.llm.load_llm import load_llm
from graphrag.logging.print_progress import PrintProgressReporter
from graphrag.logger.print_progress import PrintProgressLogger
from graphrag.prompt_tune.defaults import MAX_TOKEN_COUNT
from graphrag.prompt_tune.generator.community_report_rating import (
generate_community_report_rating,
@ -80,7 +80,7 @@ async def generate_indexing_prompts(
-------
tuple[str, str, str]: entity extraction prompt, entity summarization prompt, community summarization prompt
"""
reporter = PrintProgressReporter("")
logger = PrintProgressLogger("")
# Retrieve documents
doc_list = await load_docs_in_chunks(
@ -88,7 +88,7 @@ async def generate_indexing_prompts(
config=config,
limit=limit,
select_method=selection_method,
reporter=reporter,
logger=logger,
chunk_size=chunk_size,
n_subset_max=n_subset_max,
k=k,
@ -103,25 +103,25 @@ async def generate_indexing_prompts(
)
if not domain:
reporter.info("Generating domain...")
logger.info("Generating domain...")
domain = await generate_domain(llm, doc_list)
reporter.info(f"Generated domain: {domain}")
logger.info(f"Generated domain: {domain}") # noqa
if not language:
reporter.info("Detecting language...")
logger.info("Detecting language...")
language = await detect_language(llm, doc_list)
reporter.info("Generating persona...")
logger.info("Generating persona...")
persona = await generate_persona(llm, domain)
reporter.info("Generating community report ranking description...")
logger.info("Generating community report ranking description...")
community_report_ranking = await generate_community_report_rating(
llm, domain=domain, persona=persona, docs=doc_list
)
entity_types = None
if discover_entity_types:
reporter.info("Generating entity types...")
logger.info("Generating entity types...")
entity_types = await generate_entity_types(
llm,
domain=domain,
@ -130,7 +130,7 @@ async def generate_indexing_prompts(
json_mode=config.llm.model_supports_json or False,
)
reporter.info("Generating entity relationship examples...")
logger.info("Generating entity relationship examples...")
examples = await generate_entity_relationship_examples(
llm,
persona=persona,
@ -140,7 +140,7 @@ async def generate_indexing_prompts(
json_mode=False, # config.llm.model_supports_json should be used, but these prompts are used in non-json mode by the index engine
)
reporter.info("Generating entity extraction prompt...")
logger.info("Generating entity extraction prompt...")
entity_extraction_prompt = create_entity_extraction_prompt(
entity_types=entity_types,
docs=doc_list,
@ -152,18 +152,18 @@ async def generate_indexing_prompts(
min_examples_required=min_examples_required,
)
reporter.info("Generating entity summarization prompt...")
logger.info("Generating entity summarization prompt...")
entity_summarization_prompt = create_entity_summarization_prompt(
persona=persona,
language=language,
)
reporter.info("Generating community reporter role...")
logger.info("Generating community reporter role...")
community_reporter_role = await generate_community_reporter_role(
llm, domain=domain, persona=persona, docs=doc_list
)
reporter.info("Generating community summarization prompt...")
logger.info("Generating community summarization prompt...")
community_summarization_prompt = create_community_summarization_prompt(
persona=persona,
role=community_reporter_role,

@ -19,7 +19,7 @@ Backwards compatibility is not guaranteed at this time.
from collections.abc import AsyncGenerator
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any
import pandas as pd
from pydantic import validate_call
@ -29,7 +29,7 @@ from graphrag.index.config.embeddings import (
community_full_content_embedding,
entity_description_embedding,
)
from graphrag.logging.print_progress import PrintProgressReporter
from graphrag.logger.print_progress import PrintProgressLogger
from graphrag.query.factory import (
get_drift_search_engine,
get_global_search_engine,
@ -44,13 +44,15 @@ from graphrag.query.indexer_adapters import (
read_indexer_reports,
read_indexer_text_units,
)
from graphrag.query.structured_search.base import SearchResult # noqa: TC001
from graphrag.utils.cli import redact
from graphrag.utils.embeddings import create_collection_name
from graphrag.vector_stores.base import BaseVectorStore
from graphrag.vector_stores.factory import VectorStoreFactory
reporter = PrintProgressReporter("")
if TYPE_CHECKING:
from graphrag.query.structured_search.base import SearchResult
logger = PrintProgressLogger("")
@validate_call(config={"arbitrary_types_allowed": True})
@ -241,7 +243,7 @@ async def local_search(
TODO: Document any exceptions to expect.
"""
vector_store_args = config.embeddings.vector_store
reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
description_embedding_store = _get_embedding_store(
config_args=vector_store_args, # type: ignore
@ -307,7 +309,7 @@ async def local_search_streaming(
TODO: Document any exceptions to expect.
"""
vector_store_args = config.embeddings.vector_store
reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
description_embedding_store = _get_embedding_store(
config_args=vector_store_args, # type: ignore
@ -380,7 +382,7 @@ async def drift_search(
TODO: Document any exceptions to expect.
"""
vector_store_args = config.embeddings.vector_store
reporter.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
description_embedding_store = _get_embedding_store(
config_args=vector_store_args, # type: ignore
@ -430,7 +432,7 @@ def _get_embedding_store(
collection_name = create_collection_name(
config_args.get("container_name", "default"), embedding_name
)
embedding_store = VectorStoreFactory.get_vector_store(
embedding_store = VectorStoreFactory().create_vector_store(
vector_store_type=vector_store_type,
kwargs={**config_args, "collection_name": collection_name},
)

@ -5,7 +5,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, ClassVar
from graphrag.config.enums import CacheType
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
@ -13,41 +13,45 @@ from graphrag.storage.file_pipeline_storage import FilePipelineStorage
if TYPE_CHECKING:
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.index.config.cache import (
PipelineBlobCacheConfig,
PipelineCacheConfig,
PipelineFileCacheConfig,
)
from graphrag.cache.json_pipeline_cache import JsonPipelineCache
from graphrag.cache.memory_pipeline_cache import InMemoryCache
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
def create_cache(
config: PipelineCacheConfig | None, root_dir: str | None
) -> PipelineCache:
"""Create a cache from the given config."""
if config is None:
return NoopPipelineCache()
class CacheFactory:
"""A factory class for cache implementations.
match config.type:
case CacheType.none:
Includes a method for users to register a custom cache implementation.
"""
cache_types: ClassVar[dict[str, type]] = {}
@classmethod
def register(cls, cache_type: str, cache: type):
"""Register a custom cache implementation."""
cls.cache_types[cache_type] = cache
@classmethod
def create_cache(
cls, cache_type: CacheType | str | None, root_dir: str, kwargs: dict
) -> PipelineCache:
"""Create or get a cache from the provided type."""
if not cache_type:
return NoopPipelineCache()
case CacheType.memory:
return InMemoryCache()
case CacheType.file:
config = cast("PipelineFileCacheConfig", config)
storage = FilePipelineStorage(root_dir).child(config.base_dir)
return JsonPipelineCache(storage)
case CacheType.blob:
config = cast("PipelineBlobCacheConfig", config)
storage = BlobPipelineStorage(
config.connection_string,
config.container_name,
storage_account_blob_url=config.storage_account_blob_url,
).child(config.base_dir)
return JsonPipelineCache(storage)
case _:
msg = f"Unknown cache type: {config.type}"
raise ValueError(msg)
match cache_type:
case CacheType.none:
return NoopPipelineCache()
case CacheType.memory:
return InMemoryCache()
case CacheType.file:
return JsonPipelineCache(
FilePipelineStorage(root_dir=root_dir).child(kwargs["base_dir"])
)
case CacheType.blob:
return JsonPipelineCache(BlobPipelineStorage(**kwargs))
case _:
if cache_type in cls.cache_types:
return cls.cache_types[cache_type](**kwargs)
msg = f"Unknown cache type: {cache_type}"
raise ValueError(msg)
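To make the new calling convention concrete, a hedged sketch of creating a built-in file cache through the refactored factory (the `root_dir` and `base_dir` values are illustrative):
```python
from graphrag.cache.factory import CacheFactory
from graphrag.config.enums import CacheType

# file cache: root_dir anchors the FilePipelineStorage,
# kwargs["base_dir"] selects the cache subfolder beneath it
cache = CacheFactory().create_cache(
    cache_type=CacheType.file,
    root_dir="./ragtest",
    kwargs={"base_dir": "cache"},
)
```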

@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'FilePipelineCache' model."""
"""A module containing 'JsonPipelineCache' model."""
import json
from typing import Any

@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Create a pipeline reporter."""
"""Create a pipeline logger."""
from pathlib import Path
from typing import cast
@ -22,7 +22,7 @@ from graphrag.index.config.reporting import (
def create_pipeline_reporter(
config: PipelineReportingConfig | None, root_dir: str | None
) -> WorkflowCallbacks:
"""Create a reporter for the given pipeline config."""
"""Create a logger for the given pipeline config."""
config = config or PipelineFileReportingConfig(base_dir="logs")
match config.type:

@ -7,16 +7,16 @@ from typing import Any
from datashaper import ExecutionNode, NoopWorkflowCallbacks, Progress, TableContainer
from graphrag.logging.base import ProgressReporter
from graphrag.logger.base import ProgressLogger
class ProgressWorkflowCallbacks(NoopWorkflowCallbacks):
"""A callbackmanager that delegates to a ProgressReporter."""
"""A callbackmanager that delegates to a ProgressLogger."""
_root_progress: ProgressReporter
_progress_stack: list[ProgressReporter]
_root_progress: ProgressLogger
_progress_stack: list[ProgressLogger]
def __init__(self, progress: ProgressReporter) -> None:
def __init__(self, progress: ProgressLogger) -> None:
"""Create a new ProgressWorkflowCallbacks."""
self._progress = progress
self._progress_stack = [progress]
@ -28,7 +28,7 @@ class ProgressWorkflowCallbacks(NoopWorkflowCallbacks):
self._progress_stack.append(self._latest.child(name))
@property
def _latest(self) -> ProgressReporter:
def _latest(self) -> ProgressLogger:
return self._progress_stack[-1]
def on_workflow_start(self, name: str, instance: object) -> None:

@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""CLI implementation of index subcommand."""
"""CLI implementation of the index subcommand."""
import asyncio
import logging
@ -16,9 +16,8 @@ from graphrag.config.load_config import load_config
from graphrag.config.logging import enable_logging_with_config
from graphrag.config.resolve_path import resolve_paths
from graphrag.index.validate_config import validate_config_names
from graphrag.logging.base import ProgressReporter
from graphrag.logging.factory import create_progress_reporter
from graphrag.logging.types import ReporterType
from graphrag.logger.base import ProgressLogger
from graphrag.logger.factory import LoggerFactory, LoggerType
from graphrag.utils.cli import redact
# Ignore warnings from numba
@ -27,35 +26,35 @@ warnings.filterwarnings("ignore", message=".*NumbaDeprecationWarning.*")
log = logging.getLogger(__name__)
def _logger(reporter: ProgressReporter):
def _logger(logger: ProgressLogger):
def info(msg: str, verbose: bool = False):
log.info(msg)
if verbose:
reporter.info(msg)
logger.info(msg)
def error(msg: str, verbose: bool = False):
log.error(msg)
if verbose:
reporter.error(msg)
logger.error(msg)
def success(msg: str, verbose: bool = False):
log.info(msg)
if verbose:
reporter.success(msg)
logger.success(msg)
return info, error, success
def _register_signal_handlers(reporter: ProgressReporter):
def _register_signal_handlers(logger: ProgressLogger):
import signal
def handle_signal(signum, _):
# Handle the signal here
reporter.info(f"Received signal {signum}, exiting...")
reporter.dispose()
logger.info(f"Received signal {signum}, exiting...") # noqa: G004
logger.dispose()
for task in asyncio.all_tasks():
task.cancel()
reporter.info("All tasks cancelled. Exiting...")
logger.info("All tasks cancelled. Exiting...")
# Register signal handlers for SIGINT and SIGHUP
signal.signal(signal.SIGINT, handle_signal)
@ -70,7 +69,7 @@ def index_cli(
resume: str | None,
memprofile: bool,
cache: bool,
reporter: ReporterType,
logger: LoggerType,
config_filepath: Path | None,
dry_run: bool,
skip_validation: bool,
@ -85,7 +84,7 @@ def index_cli(
resume=resume,
memprofile=memprofile,
cache=cache,
reporter=reporter,
logger=logger,
dry_run=dry_run,
skip_validation=skip_validation,
output_dir=output_dir,
@ -97,7 +96,7 @@ def update_cli(
verbose: bool,
memprofile: bool,
cache: bool,
reporter: ReporterType,
logger: LoggerType,
config_filepath: Path | None,
skip_validation: bool,
output_dir: Path | None,
@ -121,7 +120,7 @@ def update_cli(
resume=False,
memprofile=memprofile,
cache=cache,
reporter=reporter,
logger=logger,
dry_run=False,
skip_validation=skip_validation,
output_dir=output_dir,
@ -134,13 +133,13 @@ def _run_index(
resume,
memprofile,
cache,
reporter,
logger,
dry_run,
skip_validation,
output_dir,
):
progress_reporter = create_progress_reporter(reporter)
info, error, success = _logger(progress_reporter)
progress_logger = LoggerFactory().create_logger(logger)
info, error, success = _logger(progress_logger)
run_id = resume or time.strftime("%Y%m%d-%H%M%S")
config.storage.base_dir = str(output_dir) if output_dir else config.storage.base_dir
@ -162,7 +161,7 @@ def _run_index(
)
if skip_validation:
validate_config_names(progress_reporter, config)
validate_config_names(progress_logger, config)
info(f"Starting pipeline run for: {run_id}, {dry_run=}", verbose)
info(
@ -174,7 +173,7 @@ def _run_index(
info("Dry run complete, exiting...", True)
sys.exit(0)
_register_signal_handlers(progress_reporter)
_register_signal_handlers(progress_logger)
outputs = asyncio.run(
api.build_index(
@ -182,14 +181,14 @@ def _run_index(
run_id=run_id,
is_resume_run=bool(resume),
memory_profile=memprofile,
progress_reporter=progress_reporter,
progress_logger=progress_logger,
)
)
encountered_errors = any(
output.errors and len(output.errors) > 0 for output in outputs
)
progress_reporter.stop()
progress_logger.stop()
if encountered_errors:
error(
"Errors occurred during the pipeline run, see logs for more details.", True

@ -1,13 +1,12 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""CLI implementation of initialization subcommand."""
"""CLI implementation of the initialization subcommand."""
from pathlib import Path
from graphrag.config.init_content import INIT_DOTENV, INIT_YAML
from graphrag.logging.factory import create_progress_reporter
from graphrag.logging.types import ReporterType
from graphrag.logger.factory import LoggerFactory, LoggerType
from graphrag.prompts.index.claim_extraction import CLAIM_EXTRACTION_PROMPT
from graphrag.prompts.index.community_report import (
COMMUNITY_REPORT_PROMPT,
@ -28,8 +27,8 @@ from graphrag.prompts.query.question_gen_system_prompt import QUESTION_SYSTEM_PR
def initialize_project_at(path: Path) -> None:
"""Initialize the project at the given path."""
progress_reporter = create_progress_reporter(ReporterType.RICH)
progress_reporter.info(f"Initializing project at {path}")
progress_logger = LoggerFactory().create_logger(LoggerType.RICH)
progress_logger.info(f"Initializing project at {path}") # noqa: G004
root = Path(path)
if not root.exists():
root.mkdir(parents=True, exist_ok=True)

@ -12,7 +12,7 @@ from typing import Annotated
import typer
from graphrag.logging.types import ReporterType
from graphrag.logger.types import LoggerType
from graphrag.prompt_tune.defaults import (
MAX_TOKEN_COUNT,
MIN_CHUNK_SIZE,
@ -145,9 +145,9 @@ def _index_cli(
resume: Annotated[
str | None, typer.Option(help="Resume a given indexing run")
] = None,
reporter: Annotated[
ReporterType, typer.Option(help="The progress reporter to use.")
] = ReporterType.RICH,
logger: Annotated[
LoggerType, typer.Option(help="The progress logger to use.")
] = LoggerType.RICH,
dry_run: Annotated[
bool,
typer.Option(
@ -180,7 +180,7 @@ def _index_cli(
resume=resume,
memprofile=memprofile,
cache=cache,
reporter=ReporterType(reporter),
logger=LoggerType(logger),
config_filepath=config,
dry_run=dry_run,
skip_validation=skip_validation,
@ -212,9 +212,9 @@ def _update_cli(
memprofile: Annotated[
bool, typer.Option(help="Run the indexing pipeline with memory profiling")
] = False,
reporter: Annotated[
ReporterType, typer.Option(help="The progress reporter to use.")
] = ReporterType.RICH,
logger: Annotated[
LoggerType, typer.Option(help="The progress logger to use.")
] = LoggerType.RICH,
cache: Annotated[bool, typer.Option(help="Use LLM cache.")] = True,
skip_validation: Annotated[
bool,
@ -244,7 +244,7 @@ def _update_cli(
verbose=verbose,
memprofile=memprofile,
cache=cache,
reporter=ReporterType(reporter),
logger=LoggerType(logger),
config_filepath=config,
skip_validation=skip_validation,
output_dir=output,

@ -1,13 +1,13 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""CLI implementation of prompt-tune subcommand."""
"""CLI implementation of the prompt-tune subcommand."""
from pathlib import Path
import graphrag.api as api
from graphrag.config.load_config import load_config
from graphrag.logging.print_progress import PrintProgressReporter
from graphrag.logger.print_progress import PrintProgressLogger
from graphrag.prompt_tune.generator.community_report_summarization import (
COMMUNITY_SUMMARIZATION_FILENAME,
)
@ -52,7 +52,7 @@ async def prompt_tune(
- k: The number of documents to select when using auto selection method.
- min_examples_required: The minimum number of examples required for entity extraction prompts.
"""
reporter = PrintProgressReporter("")
logger = PrintProgressLogger("")
root_path = Path(root).resolve()
graph_config = load_config(root_path, config)
@ -73,7 +73,7 @@ async def prompt_tune(
output_path = output.resolve()
if output_path:
reporter.info(f"Writing prompts to {output_path}")
logger.info(f"Writing prompts to {output_path}") # noqa: G004
output_path.mkdir(parents=True, exist_ok=True)
entity_extraction_prompt_path = output_path / ENTITY_EXTRACTION_FILENAME
entity_summarization_prompt_path = output_path / ENTITY_SUMMARIZATION_FILENAME

@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""CLI implementation of query subcommand."""
"""CLI implementation of the query subcommand."""
import asyncio
import sys
@ -14,11 +14,11 @@ from graphrag.config.load_config import load_config
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.resolve_path import resolve_paths
from graphrag.index.create_pipeline_config import create_pipeline_config
from graphrag.logging.print_progress import PrintProgressReporter
from graphrag.storage.factory import create_storage
from graphrag.logger.print_progress import PrintProgressLogger
from graphrag.storage.factory import StorageFactory
from graphrag.utils.storage import load_table_from_storage
reporter = PrintProgressReporter("")
logger = PrintProgressLogger("")
def run_global_search(
@ -40,9 +40,9 @@ def run_global_search(
config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
resolve_paths(config)
dataframe_dict = _resolve_parquet_files(
dataframe_dict = _resolve_output_files(
config=config,
parquet_list=[
output_list=[
"create_final_nodes.parquet",
"create_final_entities.parquet",
"create_final_communities.parquet",
@ -100,7 +100,7 @@ def run_global_search(
query=query,
)
)
reporter.success(f"Global Search Response:\n{response}")
logger.success(f"Global Search Response:\n{response}")
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
# External users should use the API directly to get the response and context data.
return response, context_data
@ -124,9 +124,9 @@ def run_local_search(
config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
resolve_paths(config)
dataframe_dict = _resolve_parquet_files(
dataframe_dict = _resolve_output_files(
config=config,
parquet_list=[
output_list=[
"create_final_nodes.parquet",
"create_final_community_reports.parquet",
"create_final_text_units.parquet",
@ -191,7 +191,7 @@ def run_local_search(
query=query,
)
)
reporter.success(f"Local Search Response:\n{response}")
logger.success(f"Local Search Response:\n{response}")
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
# External users should use the API directly to get the response and context data.
return response, context_data
@ -214,9 +214,9 @@ def run_drift_search(
config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
resolve_paths(config)
dataframe_dict = _resolve_parquet_files(
dataframe_dict = _resolve_output_files(
config=config,
parquet_list=[
output_list=[
"create_final_nodes.parquet",
"create_final_community_reports.parquet",
"create_final_text_units.parquet",
@ -250,30 +250,33 @@ def run_drift_search(
query=query,
)
)
reporter.success(f"DRIFT Search Response:\n{response}")
logger.success(f"DRIFT Search Response:\n{response}")
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
# External users should use the API directly to get the response and context data.
# TODO: Map/Reduce Drift Search answer to a single response
return response, context_data
def _resolve_parquet_files(
def _resolve_output_files(
config: GraphRagConfig,
parquet_list: list[str],
output_list: list[str],
optional_list: list[str] | None = None,
) -> dict[str, pd.DataFrame]:
"""Read parquet files to a dataframe dict."""
"""Read indexing output files to a dataframe dict."""
dataframe_dict = {}
pipeline_config = create_pipeline_config(config)
storage_obj = create_storage(pipeline_config.storage) # type: ignore
for parquet_file in parquet_list:
df_key = parquet_file.split(".")[0]
storage_config = pipeline_config.storage.model_dump() # type: ignore
storage_obj = StorageFactory().create_storage(
storage_type=storage_config["type"], kwargs=storage_config
)
for output_file in output_list:
df_key = output_file.split(".")[0]
df_value = asyncio.run(
load_table_from_storage(name=parquet_file, storage=storage_obj)
load_table_from_storage(name=output_file, storage=storage_obj)
)
dataframe_dict[df_key] = df_value
# for optional parquet files, set the dict entry to None instead of erroring out if it does not exist
# for optional output files, set the dict entry to None instead of erroring out if it does not exist
if optional_list:
for optional_file in optional_list:
file_exists = asyncio.run(storage_obj.has(optional_file))

@ -33,7 +33,7 @@ async_mode: {defs.ASYNC_MODE.value} # or asyncio
embeddings:
async_mode: {defs.ASYNC_MODE.value} # or asyncio
vector_store:{defs.VECTOR_STORE}
vector_store: {defs.VECTOR_STORE}
llm:
api_key: ${{GRAPHRAG_API_KEY}}
type: {defs.EMBEDDING_TYPE.value} # or azure_openai_embedding

@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'PipelineCacheConfig', 'PipelineFileCacheConfig' and 'PipelineMemoryCacheConfig' models."""
"""A module containing 'PipelineCacheConfig', 'PipelineFileCacheConfig', 'PipelineMemoryCacheConfig', 'PipelineBlobCacheConfig' models."""
from __future__ import annotations

@ -12,7 +12,7 @@ import pandas as pd
from graphrag.index.config.input import PipelineCSVInputConfig, PipelineInputConfig
from graphrag.index.utils.hashing import gen_sha512_hash
from graphrag.logging.base import ProgressReporter
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage
log = logging.getLogger(__name__)
@ -24,7 +24,7 @@ input_type = "csv"
async def load(
config: PipelineInputConfig,
progress: ProgressReporter | None,
progress: ProgressLogger | None,
storage: PipelineStorage,
) -> pd.DataFrame:
"""Load csv inputs from a directory."""

@ -17,8 +17,8 @@ from graphrag.index.input.csv import input_type as csv
from graphrag.index.input.csv import load as load_csv
from graphrag.index.input.text import input_type as text
from graphrag.index.input.text import load as load_text
from graphrag.logging.base import ProgressReporter
from graphrag.logging.null_progress import NullProgressReporter
from graphrag.logger.base import ProgressLogger
from graphrag.logger.null_progress import NullProgressLogger
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
@ -31,13 +31,13 @@ loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = {
async def create_input(
config: PipelineInputConfig | InputConfig,
progress_reporter: ProgressReporter | None = None,
progress_reporter: ProgressLogger | None = None,
root_dir: str | None = None,
) -> pd.DataFrame:
"""Instantiate input data for a pipeline."""
root_dir = root_dir or ""
log.info("loading input from root_dir=%s", config.base_dir)
progress_reporter = progress_reporter or NullProgressReporter()
progress_reporter = progress_reporter or NullProgressLogger()
match config.type:
case InputType.blob:

@ -12,7 +12,7 @@ import pandas as pd
from graphrag.index.config.input import PipelineInputConfig
from graphrag.index.utils.hashing import gen_sha512_hash
from graphrag.logging.base import ProgressReporter
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage
DEFAULT_FILE_PATTERN = re.compile(
@ -24,7 +24,7 @@ log = logging.getLogger(__name__)
async def load(
config: PipelineInputConfig,
progress: ProgressReporter | None,
progress: ProgressLogger | None,
storage: PipelineStorage,
) -> pd.DataFrame:
"""Load text inputs from a directory."""

@ -217,7 +217,7 @@ def _create_vector_store(
if collection_name:
vector_store_config.update({"collection_name": collection_name})
vector_store = VectorStoreFactory.get_vector_store(
vector_store = VectorStoreFactory().create_vector_store(
vector_store_type, kwargs=vector_store_config
)

@ -8,21 +8,18 @@ import logging
import time
import traceback
from collections.abc import AsyncIterable
from pathlib import Path
from typing import cast
import pandas as pd
from datashaper import NoopVerbCallbacks, WorkflowCallbacks
from graphrag.cache.factory import create_cache
from graphrag.cache.factory import CacheFactory
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
from graphrag.index.config.cache import PipelineMemoryCacheConfig
from graphrag.index.config.pipeline import (
PipelineConfig,
PipelineWorkflowReference,
)
from graphrag.index.config.storage import PipelineFileStorageConfig
from graphrag.index.config.workflow import PipelineWorkflowStep
from graphrag.index.exporter import ParquetExporter
from graphrag.index.input.factory import create_input
@ -51,9 +48,9 @@ from graphrag.index.workflows import (
WorkflowDefinitions,
load_workflows,
)
from graphrag.logging.base import ProgressReporter
from graphrag.logging.null_progress import NullProgressReporter
from graphrag.storage.factory import create_storage
from graphrag.logger.base import ProgressLogger
from graphrag.logger.null_progress import NullProgressLogger
from graphrag.storage.factory import StorageFactory
from graphrag.storage.pipeline_storage import PipelineStorage
log = logging.getLogger(__name__)
@ -67,7 +64,7 @@ async def run_pipeline_with_config(
update_index_storage: PipelineStorage | None = None,
cache: PipelineCache | None = None,
callbacks: list[WorkflowCallbacks] | None = None,
progress_reporter: ProgressReporter | None = None,
logger: ProgressLogger | None = None,
input_post_process_steps: list[PipelineWorkflowStep] | None = None,
additional_verbs: VerbDefinitions | None = None,
additional_workflows: WorkflowDefinitions | None = None,
@ -85,7 +82,7 @@ async def run_pipeline_with_config(
- dataset - The dataset to run the pipeline on (this overrides the config)
- storage - The storage to use for the pipeline (this overrides the config)
- cache - The cache to use for the pipeline (this overrides the config)
- reporter - The reporter to use for the pipeline (this overrides the config)
- logger - The logger to use for the pipeline (this overrides the config)
- input_post_process_steps - The post process steps to run on the input data (this overrides the config)
- additional_verbs - The custom verbs to use for the pipeline.
- additional_workflows - The custom workflows to use for the pipeline.
@ -102,23 +99,32 @@ async def run_pipeline_with_config(
config = _apply_substitutions(config, run_id)
root_dir = config.root_dir or ""
progress_reporter = progress_reporter or NullProgressReporter()
storage = storage = create_storage(config.storage) # type: ignore
progress_logger = logger or NullProgressLogger()
storage_config = config.storage.model_dump() # type: ignore
storage = storage or StorageFactory().create_storage(
storage_type=storage_config["type"], # type: ignore
kwargs=storage_config,
)
if is_update_run:
# TODO: remove the default choice (PipelineFileStorageConfig) once the new config system enforces a correct update-index-storage config when used.
update_index_storage = update_index_storage or create_storage(
config.update_index_storage
or PipelineFileStorageConfig(base_dir=str(Path(root_dir) / "output"))
update_storage_config = config.update_index_storage.model_dump() # type: ignore
update_index_storage = update_index_storage or StorageFactory().create_storage(
storage_type=update_storage_config["type"], # type: ignore
kwargs=update_storage_config,
)
# TODO: remove the default choice (PipelineMemoryCacheConfig) when the new config system guarantees the existence of a cache config
cache = cache or create_cache(config.cache or PipelineMemoryCacheConfig(), root_dir)
# TODO: remove the type ignore when the new config system guarantees the existence of a cache config
cache_config = config.cache.model_dump() # type: ignore
cache = cache or CacheFactory().create_cache(
cache_type=cache_config["type"], # type: ignore
root_dir=root_dir,
kwargs=cache_config,
)
# TODO: remove the type ignore when the new config system guarantees the existence of an input config
dataset = (
dataset
if dataset is not None
else await create_input(config.input, progress_reporter, root_dir) # type: ignore
else await create_input(config.input, progress_logger, root_dir) # type: ignore
)
post_process_steps = input_post_process_steps or _create_postprocess_steps(
@ -148,12 +154,12 @@ async def run_pipeline_with_config(
memory_profile=memory_profile,
additional_verbs=additional_verbs,
additional_workflows=additional_workflows,
progress_reporter=progress_reporter,
progress_logger=progress_logger,
is_resume_run=False,
):
tables_dict[table.workflow] = table.result
progress_reporter.success("Finished running workflows on new documents.")
progress_logger.success("Finished running workflows on new documents.")
await update_dataframe_outputs(
dataframe_dict=tables_dict,
storage=storage,
@ -161,7 +167,7 @@ async def run_pipeline_with_config(
config=config,
cache=cache,
callbacks=NoopVerbCallbacks(),
progress_reporter=progress_reporter,
progress_logger=progress_logger,
)
else:
@ -175,7 +181,7 @@ async def run_pipeline_with_config(
memory_profile=memory_profile,
additional_verbs=additional_verbs,
additional_workflows=additional_workflows,
progress_reporter=progress_reporter,
progress_logger=progress_logger,
is_resume_run=is_resume_run,
):
yield table
@ -187,7 +193,7 @@ async def run_pipeline(
storage: PipelineStorage | None = None,
cache: PipelineCache | None = None,
callbacks: list[WorkflowCallbacks] | None = None,
progress_reporter: ProgressReporter | None = None,
progress_logger: ProgressLogger | None = None,
input_post_process_steps: list[PipelineWorkflowStep] | None = None,
additional_verbs: VerbDefinitions | None = None,
additional_workflows: WorkflowDefinitions | None = None,
@ -206,7 +212,7 @@ async def run_pipeline(
These must exist after any post process steps are run if there are any!
- storage - The storage to use for the pipeline
- cache - The cache to use for the pipeline
- reporter - The reporter to use for the pipeline
- progress_logger - The logger to use for the pipeline
- input_post_process_steps - The post process steps to run on the input data
- additional_verbs - The custom verbs to use for the pipeline
- additional_workflows - The custom workflows to use for the pipeline
@ -216,7 +222,7 @@ async def run_pipeline(
"""
start_time = time.time()
progress_reporter = progress_reporter or NullProgressReporter()
progress_reporter = progress_logger or NullProgressLogger()
callbacks = callbacks or [ConsoleWorkflowCallbacks()]
callback_chain = _create_callback_chain(callbacks, progress_reporter)
context = create_run_context(storage=storage, cache=cache, stats=None)

@ -21,7 +21,7 @@ from graphrag.index.context import PipelineRunContext
from graphrag.index.exporter import ParquetExporter
from graphrag.index.run.profiling import _write_workflow_stats
from graphrag.index.typing import PipelineRunResult
from graphrag.logging.base import ProgressReporter
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage
from graphrag.utils.storage import load_table_from_storage
@ -68,7 +68,7 @@ async def _export_workflow_output(
def _create_callback_chain(
callbacks: list[WorkflowCallbacks] | None, progress: ProgressReporter | None
callbacks: list[WorkflowCallbacks] | None, progress: ProgressLogger | None
) -> WorkflowCallbacks:
"""Create a callback manager that encompasses multiple callbacks."""
manager = WorkflowCallbacksManager()

@ -23,7 +23,7 @@ from graphrag.index.update.entities import (
_run_entity_summarization,
)
from graphrag.index.update.relationships import _update_and_merge_relationships
from graphrag.logging.print_progress import ProgressReporter
from graphrag.logger.print_progress import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage
from graphrag.utils.storage import load_table_from_storage
@ -85,7 +85,7 @@ async def update_dataframe_outputs(
config: PipelineConfig,
cache: PipelineCache,
callbacks: VerbCallbacks,
progress_reporter: ProgressReporter,
progress_logger: ProgressLogger,
) -> None:
"""Update the mergeable outputs.
@ -96,25 +96,25 @@ async def update_dataframe_outputs(
storage : PipelineStorage
The storage used to store the dataframes.
"""
progress_reporter.info("Updating Final Documents")
progress_logger.info("Updating Final Documents")
final_documents_df = await _concat_dataframes(
"create_final_documents", dataframe_dict, storage, update_storage
)
# Update entities and merge them
progress_reporter.info("Updating Final Entities")
progress_logger.info("Updating Final Entities")
merged_entities_df, entity_id_mapping = await _update_entities(
dataframe_dict, storage, update_storage, config, cache, callbacks
)
# Update relationships with the entities id mapping
progress_reporter.info("Updating Final Relationships")
progress_logger.info("Updating Final Relationships")
merged_relationships_df = await _update_relationships(
dataframe_dict, storage, update_storage
)
# Update and merge final text units
progress_reporter.info("Updating Final Text Units")
progress_logger.info("Updating Final Text Units")
merged_text_units = await _update_text_units(
dataframe_dict, storage, update_storage, entity_id_mapping
)
@ -124,23 +124,23 @@ async def update_dataframe_outputs(
await storage.has("create_final_covariates.parquet")
and "create_final_covariates" in dataframe_dict
):
progress_reporter.info("Updating Final Covariates")
progress_logger.info("Updating Final Covariates")
await _update_covariates(dataframe_dict, storage, update_storage)
# Merge final nodes and update community ids
progress_reporter.info("Updating Final Nodes")
progress_logger.info("Updating Final Nodes")
_, community_id_mapping = await _update_nodes(
dataframe_dict, storage, update_storage, merged_entities_df
)
# Merge final communities
progress_reporter.info("Updating Final Communities")
progress_logger.info("Updating Final Communities")
await _update_communities(
dataframe_dict, storage, update_storage, community_id_mapping
)
# Merge community reports
progress_reporter.info("Updating Final Community Reports")
progress_logger.info("Updating Final Community Reports")
merged_community_reports = await _update_community_reports(
dataframe_dict, storage, update_storage, community_id_mapping
)
@ -151,7 +151,7 @@ async def update_dataframe_outputs(
)
# Generate text embeddings
progress_reporter.info("Updating Text Embeddings")
progress_logger.info("Updating Text Embeddings")
await generate_text_embeddings(
final_documents=final_documents_df,
final_relationships=merged_relationships_df,

@ -10,12 +10,10 @@ from datashaper import NoopVerbCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.llm.load_llm import load_llm, load_llm_embeddings
from graphrag.logging.print_progress import ProgressReporter
from graphrag.logger.print_progress import ProgressLogger
def validate_config_names(
reporter: ProgressReporter, parameters: GraphRagConfig
) -> None:
def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) -> None:
"""Validate config file for LLM deployment name typos."""
# Validate Chat LLM configs
llm = load_llm(
@ -26,9 +24,9 @@ def validate_config_names(
)
try:
asyncio.run(llm("This is an LLM connectivity test. Say Hello World"))
reporter.success("LLM Config Params Validated")
logger.success("LLM Config Params Validated")
except Exception as e: # noqa: BLE001
reporter.error(f"LLM configuration error detected. Exiting...\n{e}")
logger.error(f"LLM configuration error detected. Exiting...\n{e}") # noqa
sys.exit(1)
# Validate Embeddings LLM configs
@ -40,7 +38,7 @@ def validate_config_names(
)
try:
asyncio.run(embed_llm(["This is an LLM Embedding Test String"]))
reporter.success("Embedding LLM Config Params Validated")
logger.success("Embedding LLM Config Params Validated")
except Exception as e: # noqa: BLE001
reporter.error(f"Embedding LLM configuration error detected. Exiting...\n{e}")
logger.error(f"Embedding LLM configuration error detected. Exiting...\n{e}") # noqa
sys.exit(1)

@ -1,4 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Logging utilities and implementations."""
"""Logger utilities and implementations."""

@ -10,24 +10,24 @@ from datashaper.progress.types import Progress
class StatusLogger(ABC):
"""Provides a way to report status updates from the pipeline."""
"""Provides a way to log status updates from the pipeline."""
@abstractmethod
def error(self, message: str, details: dict[str, Any] | None = None):
"""Report an error."""
"""Log an error."""
@abstractmethod
def warning(self, message: str, details: dict[str, Any] | None = None):
"""Report a warning."""
"""Log a warning."""
@abstractmethod
def log(self, message: str, details: dict[str, Any] | None = None):
"""Report a log."""
class ProgressReporter(ABC):
class ProgressLogger(ABC):
"""
Abstract base class for progress reporters.
Abstract base class for progress loggers.
This is used to report workflow processing progress via mechanisms like progress-bars.
"""
@ -38,10 +38,10 @@ class ProgressReporter(ABC):
@abstractmethod
def dispose(self):
"""Dispose of the progress reporter."""
"""Dispose of the progress logger."""
@abstractmethod
def child(self, prefix: str, transient=True) -> "ProgressReporter":
def child(self, prefix: str, transient=True) -> "ProgressLogger":
"""Create a child progress bar."""
@abstractmethod
@ -50,20 +50,20 @@ class ProgressReporter(ABC):
@abstractmethod
def stop(self) -> None:
"""Stop the progress reporter."""
"""Stop the progress logger."""
@abstractmethod
def error(self, message: str) -> None:
"""Report an error."""
"""Log an error."""
@abstractmethod
def warning(self, message: str) -> None:
"""Report a warning."""
"""Log a warning."""
@abstractmethod
def info(self, message: str) -> None:
"""Report information."""
"""Log information."""
@abstractmethod
def success(self, message: str) -> None:
"""Report success."""
"""Log success."""

@ -1,26 +1,26 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Console Reporter."""
"""Console Log."""
from typing import Any
from graphrag.logging.base import StatusLogger
from graphrag.logger.base import StatusLogger
class ConsoleReporter(StatusLogger):
"""A reporter that writes to a console."""
"""A logger that writes to a console."""
def error(self, message: str, details: dict[str, Any] | None = None):
"""Report an error."""
"""Log an error."""
print(message, details) # noqa T201
def warning(self, message: str, details: dict[str, Any] | None = None):
"""Report a warning."""
"""Log a warning."""
_print_warning(message)
def log(self, message: str, details: dict[str, Any] | None = None):
"""Report a log."""
"""Log a log."""
print(message, details) # noqa T201

@ -0,0 +1,43 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Factory functions for creating loggers."""
from typing import ClassVar
from graphrag.logger.base import ProgressLogger
from graphrag.logger.null_progress import NullProgressLogger
from graphrag.logger.print_progress import PrintProgressLogger
from graphrag.logger.rich_progress import RichProgressLogger
from graphrag.logger.types import LoggerType
class LoggerFactory:
"""A factory class for loggers."""
logger_types: ClassVar[dict[str, type]] = {}
@classmethod
def register(cls, logger_type: str, logger: type):
"""Register a custom logger implementation."""
cls.logger_types[logger_type] = logger
@classmethod
def create_logger(
cls, logger_type: LoggerType | str, kwargs: dict | None = None
) -> ProgressLogger:
"""Create a logger based on the provided type."""
if kwargs is None:
kwargs = {}
match logger_type:
case LoggerType.RICH:
return RichProgressLogger("GraphRAG Indexer ")
case LoggerType.PRINT:
return PrintProgressLogger("GraphRAG Indexer ")
case LoggerType.NONE:
return NullProgressLogger()
case _:
if logger_type in cls.logger_types:
return cls.logger_types[logger_type](**kwargs)
# default to null logger if no other logger is found
return NullProgressLogger()
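A hedged usage sketch of the registration hook (the `"emoji"` type name and `EmojiLogger` are hypothetical; `register`/`create_logger` are as defined above):
```python
from graphrag.logger.factory import LoggerFactory
from graphrag.logger.print_progress import PrintProgressLogger

# hypothetical custom logger that reuses the print-based implementation
class EmojiLogger(PrintProgressLogger):
    def success(self, message: str) -> None:
        print(f"\n{self.prefix}✅ {message}")

LoggerFactory.register("emoji", EmojiLogger)
logger = LoggerFactory().create_logger("emoji", kwargs={"prefix": "GraphRAG "})
logger.success("custom logger registered")
```
Note that unknown types no longer raise: the factory falls back to `NullProgressLogger`, so a mistyped logger name silently disables progress output rather than failing fast.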

@ -3,19 +3,19 @@
"""Null Progress Reporter."""
from graphrag.logging.base import Progress, ProgressReporter
from graphrag.logger.base import Progress, ProgressLogger
class NullProgressReporter(ProgressReporter):
"""A progress reporter that does nothing."""
class NullProgressLogger(ProgressLogger):
"""A progress logger that does nothing."""
def __call__(self, update: Progress) -> None:
"""Update progress."""
def dispose(self) -> None:
"""Dispose of the progress reporter."""
"""Dispose of the progress logger."""
def child(self, prefix: str, transient: bool = True) -> ProgressReporter:
def child(self, prefix: str, transient: bool = True) -> ProgressLogger:
"""Create a child progress bar."""
return self
@ -23,16 +23,16 @@ class NullProgressReporter(ProgressReporter):
"""Force a refresh."""
def stop(self) -> None:
"""Stop the progress reporter."""
"""Stop the progress logger."""
def error(self, message: str) -> None:
"""Report an error."""
"""Log an error."""
def warning(self, message: str) -> None:
"""Report a warning."""
"""Log a warning."""
def info(self, message: str) -> None:
"""Report information."""
"""Log information."""
def success(self, message: str) -> None:
"""Report success."""
"""Log success."""

@ -1,18 +1,18 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Print Progress Reporter."""
"""Print Progress Logger."""
from graphrag.logging.base import Progress, ProgressReporter
from graphrag.logger.base import Progress, ProgressLogger
class PrintProgressReporter(ProgressReporter):
"""A progress reporter that does nothing."""
class PrintProgressLogger(ProgressLogger):
"""A progress logger that prints progress to stdout."""
prefix: str
def __init__(self, prefix: str):
"""Create a new progress reporter."""
"""Create a new progress logger."""
self.prefix = prefix
print(f"\n{self.prefix}", end="") # noqa T201
@ -21,30 +21,30 @@ class PrintProgressReporter(ProgressReporter):
print(".", end="") # noqa T201
def dispose(self) -> None:
"""Dispose of the progress reporter."""
"""Dispose of the progress logger."""
def child(self, prefix: str, transient: bool = True) -> "ProgressReporter":
def child(self, prefix: str, transient: bool = True) -> ProgressLogger:
"""Create a child progress bar."""
return PrintProgressReporter(prefix)
return PrintProgressLogger(prefix)
def stop(self) -> None:
"""Stop the progress reporter."""
"""Stop the progress logger."""
def force_refresh(self) -> None:
"""Force a refresh."""
def error(self, message: str) -> None:
"""Report an error."""
"""Log an error."""
print(f"\n{self.prefix}ERROR: {message}") # noqa T201
def warning(self, message: str) -> None:
"""Report a warning."""
"""Log a warning."""
print(f"\n{self.prefix}WARNING: {message}") # noqa T201
def info(self, message: str) -> None:
"""Report information."""
"""Log information."""
print(f"\n{self.prefix}INFO: {message}") # noqa T201
def success(self, message: str) -> None:
"""Report success."""
"""Log success."""
print(f"\n{self.prefix}SUCCESS: {message}") # noqa T201

@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Rich-based progress reporter for CLI use."""
"""Rich-based progress logger for CLI use."""
# Print iterations progress
import asyncio
@ -13,12 +13,12 @@ from rich.progress import Progress, TaskID, TimeElapsedColumn
from rich.spinner import Spinner
from rich.tree import Tree
from graphrag.logging.base import ProgressReporter
from graphrag.logger.base import ProgressLogger
# https://stackoverflow.com/a/34325723
class RichProgressReporter(ProgressReporter):
"""A rich-based progress reporter for CLI use."""
class RichProgressLogger(ProgressLogger):
"""A rich-based progress logger for CLI use."""
_console: Console
_group: Group
@ -32,7 +32,7 @@ class RichProgressReporter(ProgressReporter):
_last_refresh: float = 0
def dispose(self) -> None:
"""Dispose of the progress reporter."""
"""Dispose of the progress logger."""
self._disposing = True
self._live.stop()
@ -59,10 +59,10 @@ class RichProgressReporter(ProgressReporter):
def __init__(
self,
prefix: str,
parent: "RichProgressReporter | None" = None,
parent: "RichProgressLogger | None" = None,
transient: bool = True,
) -> None:
"""Create a new rich-based progress reporter."""
"""Create a new rich-based progress logger."""
self._prefix = prefix
if parent is None:
@ -115,27 +115,27 @@ class RichProgressReporter(ProgressReporter):
self.live.refresh()
def stop(self) -> None:
"""Stop the progress reporter."""
"""Stop the progress logger."""
self._live.stop()
def child(self, prefix: str, transient: bool = True) -> ProgressReporter:
def child(self, prefix: str, transient: bool = True) -> ProgressLogger:
"""Create a child progress bar."""
return RichProgressReporter(parent=self, prefix=prefix, transient=transient)
return RichProgressLogger(parent=self, prefix=prefix, transient=transient)
def error(self, message: str) -> None:
"""Report an error."""
"""Log an error."""
self._console.print(f"❌ [red]{message}[/red]")
def warning(self, message: str) -> None:
"""Report a warning."""
"""Log a warning."""
self._console.print(f"⚠️ [yellow]{message}[/yellow]")
def success(self, message: str) -> None:
"""Report success."""
"""Log success."""
self._console.print(f"🚀 [green]{message}[/green]")
def info(self, message: str) -> None:
"""Report information."""
"""Log information."""
self._console.print(message)
def __call__(self, progress_update: DSProgress) -> None:

graphrag/logger/types.py
@ -0,0 +1,22 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Logging types.
This module defines the types of loggers that can be used.
"""
from enum import Enum
# Note: Code in this module was not included in the factory module because it negatively impacts the CLI experience.
class LoggerType(str, Enum):
"""The type of logger to use."""
RICH = "rich"
PRINT = "print"
NONE = "none"
def __str__(self):
"""Return a string representation of the enum value."""
return self.value
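Because `LoggerType` is a `str` enum whose `__str__` returns the value, it surfaces directly as the CLI option added in this PR. Usage is presumably (the `--root ./ragtest` path is illustrative):
```shell
# choose the progress logger for an indexing run; "rich" is the default
poetry run poe index --root ./ragtest --logger print
```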

@ -1,36 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Factory functions for creating loggers."""
from graphrag.logging.base import ProgressReporter
from graphrag.logging.null_progress import NullProgressReporter
from graphrag.logging.print_progress import PrintProgressReporter
from graphrag.logging.rich_progress import RichProgressReporter
from graphrag.logging.types import ReporterType
def create_progress_reporter(
reporter_type: ReporterType = ReporterType.NONE,
) -> ProgressReporter:
"""Load a progress reporter.
Parameters
----------
reporter_type : {"rich", "print", "none"}, default=rich
The type of progress reporter to load.
Returns
-------
ProgressReporter
"""
match reporter_type:
case ReporterType.RICH:
return RichProgressReporter("GraphRAG Indexer ")
case ReporterType.PRINT:
return PrintProgressReporter("GraphRAG Indexer ")
case ReporterType.NONE:
return NullProgressReporter()
case _:
msg = f"Invalid progress reporter type: {reporter_type}"
raise ValueError(msg)

@ -1,18 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Types for status reporting."""
from enum import Enum
class ReporterType(str, Enum):
"""The type of reporter to use."""
RICH = "rich"
PRINT = "print"
NONE = "none"
def __str__(self):
"""Return the string representation of the enum value."""
return self.value

View File

@ -1,4 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""GraphRAG knowledge model package root."""
"""Knowledge model package."""

View File

@ -15,7 +15,7 @@ from graphrag.config.models.llm_parameters import LLMParameters
from graphrag.index.input.factory import create_input
from graphrag.index.llm.load_llm import load_llm_embeddings
from graphrag.index.operations.chunk_text import chunk_text
from graphrag.logging.base import ProgressReporter
from graphrag.logger.base import ProgressLogger
from graphrag.prompt_tune.defaults import (
MIN_CHUNK_OVERLAP,
MIN_CHUNK_SIZE,
@ -54,7 +54,7 @@ async def load_docs_in_chunks(
config: GraphRagConfig,
select_method: DocSelectionType,
limit: int,
reporter: ProgressReporter,
logger: ProgressLogger,
chunk_size: int = MIN_CHUNK_SIZE,
n_subset_max: int = N_SUBSET_MAX,
k: int = K,
@ -64,7 +64,7 @@ async def load_docs_in_chunks(
config.embeddings.resolved_strategy()["llm"]
)
dataset = await create_input(config.input, reporter, root)
dataset = await create_input(config.input, logger, root)
# convert to text units
chunk_strategy = config.chunks.resolved_strategy(defs.ENCODING_MODEL)

View File

@ -1,4 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All prompts for indexing."""
"""All prompts for the indexing engine."""

View File

@ -1,4 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All prompts for query."""
"""All prompts for the query engine."""

View File

@ -8,8 +8,8 @@ from collections.abc import Callable
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
from graphrag.logging.base import StatusLogger
from graphrag.logging.console import ConsoleReporter
from graphrag.logger.base import StatusLogger
from graphrag.logger.console import ConsoleReporter
from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType
@ -101,7 +101,7 @@ class OpenAILLMImpl(BaseOpenAILLM):
organization: str | None = None,
max_retries: int = 10,
request_timeout: float = 180.0,
reporter: StatusLogger | None = None,
logger: StatusLogger | None = None,
):
self.api_key = api_key
self.azure_ad_token_provider = azure_ad_token_provider
@ -112,7 +112,7 @@ class OpenAILLMImpl(BaseOpenAILLM):
self.organization = organization
self.max_retries = max_retries
self.request_timeout = request_timeout
self.reporter = reporter or ConsoleReporter()
self.logger = logger or ConsoleReporter()
try:
# Create OpenAI sync and async clients

View File

@ -15,7 +15,7 @@ from tenacity import (
wait_exponential_jitter,
)
from graphrag.logging.base import StatusLogger
from graphrag.logger.base import StatusLogger
from graphrag.query.llm.base import BaseLLM, BaseLLMCallback
from graphrag.query.llm.oai.base import OpenAILLMImpl
from graphrag.query.llm.oai.typing import (
@ -42,7 +42,7 @@ class ChatOpenAI(BaseLLM, OpenAILLMImpl):
max_retries: int = 10,
request_timeout: float = 180.0,
retry_error_types: tuple[type[BaseException]] = OPENAI_RETRY_ERROR_TYPES, # type: ignore
reporter: StatusLogger | None = None,
logger: StatusLogger | None = None,
):
OpenAILLMImpl.__init__(
self=self,
@ -55,7 +55,7 @@ class ChatOpenAI(BaseLLM, OpenAILLMImpl):
organization=organization,
max_retries=max_retries,
request_timeout=request_timeout,
reporter=reporter,
logger=logger,
)
self.model = model
self.retry_error_types = retry_error_types
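A hedged call-site sketch of the renamed keyword (`reporter=` becomes `logger=`); the import path is assumed and the key/model values are placeholders:

```python
from graphrag.query.llm.oai.chat_openai import ChatOpenAI

llm = ChatOpenAI(
    api_key="<OPENAI_API_KEY>",  # placeholder
    model="gpt-4",               # placeholder
    max_retries=10,
    logger=None,  # None falls back to ConsoleReporter inside OpenAILLMImpl
)
```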

View File

@ -18,7 +18,7 @@ from tenacity import (
wait_exponential_jitter,
)
from graphrag.logging.base import StatusLogger
from graphrag.logger.base import StatusLogger
from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.query.llm.oai.base import OpenAILLMImpl
from graphrag.query.llm.oai.typing import (
@ -46,7 +46,7 @@ class OpenAIEmbedding(BaseTextEmbedding, OpenAILLMImpl):
max_retries: int = 10,
request_timeout: float = 180.0,
retry_error_types: tuple[type[BaseException]] = OPENAI_RETRY_ERROR_TYPES, # type: ignore
reporter: StatusLogger | None = None,
logger: StatusLogger | None = None,
):
OpenAILLMImpl.__init__(
self=self,
@ -59,7 +59,7 @@ class OpenAIEmbedding(BaseTextEmbedding, OpenAILLMImpl):
organization=organization,
max_retries=max_retries,
request_timeout=request_timeout,
reporter=reporter,
logger=logger,
)
self.model = model

View File

@ -13,7 +13,7 @@ from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient
from datashaper import Progress
from graphrag.logging.base import ProgressReporter
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage
log = logging.getLogger(__name__)
@ -32,7 +32,7 @@ class BlobPipelineStorage(PipelineStorage):
self,
connection_string: str | None,
container_name: str,
encoding: str | None = None,
encoding: str = "utf-8",
path_prefix: str | None = None,
storage_account_blob_url: str | None = None,
):
@ -50,7 +50,7 @@ class BlobPipelineStorage(PipelineStorage):
account_url=storage_account_blob_url,
credential=DefaultAzureCredential(),
)
self._encoding = encoding or "utf-8"
self._encoding = encoding
self._container_name = container_name
self._connection_string = connection_string
self._path_prefix = path_prefix or ""
@ -95,7 +95,7 @@ class BlobPipelineStorage(PipelineStorage):
self,
file_pattern: re.Pattern[str],
base_dir: str | None = None,
progress: ProgressReporter | None = None,
progress: ProgressLogger | None = None,
file_filter: dict[str, Any] | None = None,
max_count=-1,
) -> Iterator[tuple[str, dict[str, Any]]]:
@ -179,7 +179,7 @@ class BlobPipelineStorage(PipelineStorage):
blob_client = container_client.get_blob_client(key)
blob_data = blob_client.download_blob().readall()
if not as_bytes:
coding = encoding or "utf-8"
coding = encoding or self._encoding
blob_data = blob_data.decode(coding)
except Exception:
log.exception("Error getting key %s", key)
@ -198,7 +198,7 @@ class BlobPipelineStorage(PipelineStorage):
if isinstance(value, bytes):
blob_client.upload_blob(value, overwrite=True)
else:
coding = encoding or "utf-8"
coding = encoding or self._encoding
blob_client.upload_blob(value.encode(coding), overwrite=True)
except Exception:
log.exception("Error setting key %s: %s", key)

View File

@ -5,7 +5,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, ClassVar
from graphrag.config.enums import StorageType
from graphrag.storage.blob_pipeline_storage import create_blob_storage
@ -13,29 +13,36 @@ from graphrag.storage.file_pipeline_storage import create_file_storage
from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
if TYPE_CHECKING:
from graphrag.index.config.storage import (
PipelineBlobStorageConfig,
PipelineFileStorageConfig,
PipelineStorageConfig,
)
from graphrag.storage.pipeline_storage import PipelineStorage
def create_storage(config: PipelineStorageConfig):
"""Create a storage object based on the config."""
match config.type:
case StorageType.memory:
return MemoryPipelineStorage()
case StorageType.blob:
config = cast("PipelineBlobStorageConfig", config)
return create_blob_storage(
config.connection_string,
config.storage_account_blob_url,
config.container_name,
config.base_dir,
)
case StorageType.file:
config = cast("PipelineFileStorageConfig", config)
return create_file_storage(config.base_dir)
case _:
msg = f"Unknown storage type: {config.type}"
raise ValueError(msg)
class StorageFactory:
"""A factory class for storage implementations.
Includes a method for users to register a custom storage implementation.
"""
storage_types: ClassVar[dict[str, type]] = {}
@classmethod
def register(cls, storage_type: str, storage: type):
"""Register a custom storage implementation."""
cls.storage_types[storage_type] = storage
@classmethod
def create_storage(
cls, storage_type: StorageType | str, kwargs: dict
) -> PipelineStorage:
"""Create or get a storage object from the provided type."""
match storage_type:
case StorageType.blob:
return create_blob_storage(**kwargs)
case StorageType.file:
return create_file_storage(**kwargs)
case StorageType.memory:
return MemoryPipelineStorage()
case _:
if storage_type in cls.storage_types:
return cls.storage_types[storage_type](**kwargs)
msg = f"Unknown storage type: {storage_type}"
raise ValueError(msg)
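Usage then takes one shape for built-ins and registered customs alike; a sketch (the custom class and its key are illustrative):

```python
from graphrag.config.enums import StorageType
from graphrag.storage.factory import StorageFactory
from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage

# built-in: file storage rooted at ./output (base_dir is required)
file_storage = StorageFactory.create_storage(StorageType.file, {"base_dir": "./output"})


class MyMemoryStorage(MemoryPipelineStorage):  # illustrative custom implementation
    pass


# custom: register once, then create by string key
StorageFactory.register("my_memory", MyMemoryStorage)
custom_storage = StorageFactory.create_storage("my_memory", {})
```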

View File

@ -16,7 +16,7 @@ from aiofiles.os import remove
from aiofiles.ospath import exists
from datashaper import Progress
from graphrag.logging.base import ProgressReporter
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage
log = logging.getLogger(__name__)
@ -28,17 +28,17 @@ class FilePipelineStorage(PipelineStorage):
_root_dir: str
_encoding: str
def __init__(self, root_dir: str | None = None, encoding: str | None = None):
def __init__(self, root_dir: str = "", encoding: str = "utf-8"):
"""Init method definition."""
self._root_dir = root_dir or ""
self._encoding = encoding or "utf-8"
self._root_dir = root_dir
self._encoding = encoding
Path(self._root_dir).mkdir(parents=True, exist_ok=True)
def find(
self,
file_pattern: re.Pattern[str],
base_dir: str | None = None,
progress: ProgressReporter | None = None,
progress: ProgressLogger | None = None,
file_filter: dict[str, Any] | None = None,
max_count=-1,
) -> Iterator[tuple[str, dict[str, Any]]]:
@ -152,10 +152,11 @@ def join_path(file_path: str, file_name: str) -> Path:
return Path(file_path) / Path(file_name).parent / Path(file_name).name
def create_file_storage(out_dir: str | None) -> PipelineStorage:
def create_file_storage(**kwargs: Any) -> PipelineStorage:
"""Create a file based storage."""
log.info("Creating file storage at %s", out_dir)
return FilePipelineStorage(out_dir)
base_dir = kwargs["base_dir"]
log.info("Creating file storage at %s", base_dir)
return FilePipelineStorage(root_dir=base_dir)
def _create_progress_status(
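The call-site change in one line — `base_dir` must now arrive as a keyword (sketch, path illustrative):

```python
from graphrag.storage.file_pipeline_storage import create_file_storage

storage = create_file_storage(base_dir="./output")  # KeyError if base_dir is omitted
```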

View File

@ -8,7 +8,7 @@ from abc import ABCMeta, abstractmethod
from collections.abc import Iterator
from typing import Any
from graphrag.logging.base import ProgressReporter
from graphrag.logger.base import ProgressLogger
class PipelineStorage(metaclass=ABCMeta):
@ -19,7 +19,7 @@ class PipelineStorage(metaclass=ABCMeta):
self,
file_pattern: re.Pattern[str],
base_dir: str | None = None,
progress: ProgressReporter | None = None,
progress: ProgressLogger | None = None,
file_filter: dict[str, Any] | None = None,
max_count=-1,
) -> Iterator[tuple[str, dict[str, Any]]]:

View File

@ -19,7 +19,7 @@ async def load_table_from_storage(name: str, storage: PipelineStorage) -> pd.Dat
msg = f"Could not find {name} in storage!"
raise ValueError(msg)
try:
log.info("read table from storage: %s", name)
log.info("reading table from storage: %s", name)
return pd.read_parquet(BytesIO(await storage.get(name, as_bytes=True)))
except Exception:
log.exception("error loading table from storage: %s", name)

View File

@ -7,6 +7,7 @@ from enum import Enum
from typing import ClassVar
from graphrag.vector_stores.azure_ai_search import AzureAISearch
from graphrag.vector_stores.base import BaseVectorStore
from graphrag.vector_stores.lancedb import LanceDBVectorStore
@ -24,14 +25,14 @@ class VectorStoreFactory:
@classmethod
def register(cls, vector_store_type: str, vector_store: type):
"""Register a vector store type."""
"""Register a custom vector store implementation."""
cls.vector_store_types[vector_store_type] = vector_store
@classmethod
def get_vector_store(
def create_vector_store(
cls, vector_store_type: VectorStoreType | str, kwargs: dict
) -> LanceDBVectorStore | AzureAISearch:
"""Get the vector store type from a string."""
) -> BaseVectorStore:
"""Create or get a vector store from the provided type."""
match vector_store_type:
case VectorStoreType.LanceDB:
return LanceDBVectorStore(**kwargs)
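The renamed method keeps the same kwargs-passthrough contract; a sketch (import path assumed, kwargs illustrative):

```python
from graphrag.vector_stores.factory import VectorStoreFactory, VectorStoreType

# kwargs are forwarded to the store's constructor
store = VectorStoreFactory.create_vector_store(
    VectorStoreType.LanceDB, {"collection_name": "entity_embeddings"}
)
```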

poetry.lock generated

File diff suppressed because it is too large

View File

@ -132,7 +132,6 @@ coverage_report = 'coverage report --omit "**/tests/**" --show-missing'
check_format = 'ruff format . --check'
fix = "ruff check --fix ."
fix_unsafe = "ruff check --fix --unsafe-fixes ."
_test_all = "coverage run -m pytest ./tests"
test_unit = "pytest ./tests/unit"
test_integration = "pytest ./tests/integration"
@ -146,7 +145,6 @@ query = "python -m graphrag query"
prompt_tune = "python -m graphrag prompt-tune"
# Pass in a test pattern
test_only = "pytest -s -k"
serve_docs = "mkdocs serve"
build_docs = "mkdocs build"

View File

@ -137,7 +137,7 @@ class TestIndexer:
"--verbose" if debug else None,
"--root",
root.resolve().as_posix(),
"--reporter",
"--logger",
"print",
]
command = [arg for arg in command if arg]
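The user-facing flag changes the same way; an illustrative invocation:

```shell
# old: poetry run poe index --root ./ragtest --reporter print
poetry run poe index --root ./ragtest --logger print
```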