Mirror of https://github.com/microsoft/graphrag.git (synced 2026-01-14 09:07:20 +08:00)
Cleanup factory methods (#1482)
* cleanup factory methods to have similar design pattern across codebase
* add semversioner file
* cleanup logging factory
* update developer guide
* add comment
* typo fix
* cleanup reporter terminology
* rename reporter to logger
* fix comments
* update comment
* instantiate factory classes correctly and update index api callback parameter

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
parent 04405803db
commit 823342188d
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "cleanup and refactor factory classes."
+}
@@ -10,24 +10,56 @@
 # Getting Started
 
 ## Install Dependencies
 
-```sh
-# Install Python dependencies.
+```shell
+# install python dependencies
 poetry install
 ```
 
-## Executing the Indexing Engine
-
-```sh
+## Execute the indexing engine
+```shell
 poetry run poe index <...args>
 ```
 
-## Executing Queries
-
-```sh
+## Execute prompt tuning
+```shell
+poetry run poe prompt_tune <...args>
+```
+
+## Execute Queries
+```shell
 poetry run poe query <...args>
 ```
 
+## Repository Structure
+
+An overview of the repository's top-level folder structure is provided below, detailing the overall design and purpose.
+We leverage a factory design pattern where possible, enabling a variety of implementations for each core component of graphrag.
+
+```shell
+graphrag
+├── api            # library API definitions
+├── cache          # cache module supporting several options
+│   └─ factory.py  #   └─ main entrypoint to create a cache
+├── callbacks      # a collection of commonly used callback functions
+├── cli            # library CLI
+│   └─ main.py     #   └─ primary CLI entrypoint
+├── config         # configuration management
+├── index          # indexing engine
+│   └─ run/run.py  #   └─ main entrypoint to build an index
+├── llm            # generic llm interfaces
+├── logger         # logger module supporting several options
+│   └─ factory.py  #   └─ main entrypoint to create a logger
+├── model          # data model definitions associated with the knowledge graph
+├── prompt_tune    # prompt tuning module
+├── prompts        # a collection of all the system prompts used by graphrag
+├── query          # query engine
+├── storage        # storage module supporting several options
+│   └─ factory.py  #   └─ main entrypoint to create/load a storage endpoint
+├── utils          # helper functions used throughout the library
+└── vector_stores  # vector store module containing a few options
+    └─ factory.py  #   └─ main entrypoint to create a vector store
+```
+
+Where appropriate, the factories expose a registration method for users to provide their own custom implementations if desired.
+
 ## Versioning
 
 We use [semversioner](https://github.com/raulgomis/semversioner) to automate and enforce semantic versioning in the release process. Our CI/CD pipeline checks that all PR's include a json file generated by semversioner. When submitting a PR, please run:
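The registration method mentioned above follows the same registry-plus-classmethod shape across the cache, logger, storage, and vector store factories. A minimal, self-contained sketch of that pattern (names here are illustrative, not graphrag APIs):

```python
from typing import ClassVar


class WidgetFactory:
    """Registry-based factory, mirroring the design used by this PR's factories."""

    widget_types: ClassVar[dict[str, type]] = {}

    @classmethod
    def register(cls, widget_type: str, widget: type) -> None:
        """Register a custom implementation under a string key."""
        cls.widget_types[widget_type] = widget

    @classmethod
    def create_widget(cls, widget_type: str, **kwargs) -> object:
        """Instantiate a registered implementation by key."""
        if widget_type in cls.widget_types:
            return cls.widget_types[widget_type](**kwargs)
        msg = f"Unknown widget type: {widget_type}"
        raise ValueError(msg)


class InMemoryWidget:
    """A trivial implementation to demonstrate registration."""

    def __init__(self, name: str = "default") -> None:
        self.name = name


WidgetFactory.register("memory", InMemoryWidget)
widget = WidgetFactory.create_widget("memory", name="example")
```

Because the registry is a `ClassVar`, registrations are process-wide: any module can register an implementation before the pipeline resolves its configuration.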
@@ -156,7 +156,7 @@ This section controls the storage mechanism used by the pipeline used for export.
 
 | Parameter | Description | Type | Required or Optional | Default |
 | ------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----- | -------------------- | ------- |
-| `GRAPHRAG_STORAGE_TYPE` | The type of reporter to use. Options are `file`, `memory`, or `blob` | `str` | optional | `file` |
+| `GRAPHRAG_STORAGE_TYPE` | The type of storage to use. Options are `file`, `memory`, or `blob` | `str` | optional | `file` |
 | `GRAPHRAG_STORAGE_STORAGE_ACCOUNT_BLOB_URL` | The Azure Storage blob endpoint to use when in `blob` mode and using managed identity. Will have the format `https://<storage_account_name>.blob.core.windows.net` | `str` | optional | None |
 | `GRAPHRAG_STORAGE_CONNECTION_STRING` | The Azure Storage connection string to use when in `blob` mode. | `str` | optional | None |
 | `GRAPHRAG_STORAGE_CONTAINER_NAME` | The Azure Storage container name to use when in `blob` mode. | `str` | optional | None |
@@ -17,7 +17,7 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.index.create_pipeline_config import create_pipeline_config
 from graphrag.index.run import run_pipeline_with_config
 from graphrag.index.typing import PipelineRunResult
-from graphrag.logging.base import ProgressReporter
+from graphrag.logger.base import ProgressLogger
 
 
 async def build_index(
@@ -26,7 +26,7 @@ async def build_index(
     is_resume_run: bool = False,
     memory_profile: bool = False,
     callbacks: list[WorkflowCallbacks] | None = None,
-    progress_reporter: ProgressReporter | None = None,
+    progress_logger: ProgressLogger | None = None,
 ) -> list[PipelineRunResult]:
     """Run the pipeline with the given configuration.
 
@@ -42,8 +42,8 @@ async def build_index(
         Whether to enable memory profiling.
     callbacks : list[WorkflowCallbacks] | None default=None
         A list of callbacks to register.
-    progress_reporter : ProgressReporter | None default=None
-        The progress reporter.
+    progress_logger : ProgressLogger | None default=None
+        The progress logger.
 
     Returns
     -------
@@ -60,10 +60,10 @@ async def build_index(
     pipeline_cache = (
         NoopPipelineCache() if config.cache.type == CacheType.none is None else None
     )
     # create a pipeline reporter and add to any additional callbacks
-    callbacks = (
-        [create_pipeline_reporter(config.reporting, None)] if config.reporting else None  # type: ignore
-    )  # type: ignore
+    # TODO: remove the type ignore once the new config engine has been refactored
+    callbacks = callbacks or []
+    callbacks.append(create_pipeline_reporter(config.reporting, None))  # type: ignore
     outputs: list[PipelineRunResult] = []
     async for output in run_pipeline_with_config(
         pipeline_config,
@@ -71,15 +71,15 @@ async def build_index(
         memory_profile=memory_profile,
         cache=pipeline_cache,
         callbacks=callbacks,
-        progress_reporter=progress_reporter,
+        logger=progress_logger,
         is_resume_run=is_resume_run,
         is_update_run=is_update_run,
     ):
         outputs.append(output)
-        if progress_reporter:
+        if progress_logger:
             if output.errors and len(output.errors) > 0:
-                progress_reporter.error(output.workflow)
+                progress_logger.error(output.workflow)
             else:
-                progress_reporter.success(output.workflow)
-                progress_reporter.info(str(output.result))
+                progress_logger.success(output.workflow)
+                progress_logger.info(str(output.result))
     return outputs
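With the renamed callback parameter, external callers of the index API now pass a `ProgressLogger`. A sketch of a call, assuming `config` is a loaded `GraphRagConfig` and that the remaining parameters can stay at their defaults:

```python
import asyncio

import graphrag.api as api
from graphrag.logger.factory import LoggerFactory, LoggerType


async def run(config):
    # progress_logger replaces the old progress_reporter argument
    progress_logger = LoggerFactory().create_logger(LoggerType.PRINT)
    results = await api.build_index(config=config, progress_logger=progress_logger)
    for result in results:
        print(result.workflow, "errors:", result.errors)


# asyncio.run(run(config))  # config loading elided
```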
@@ -16,7 +16,7 @@ from pydantic import PositiveInt, validate_call
 
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.index.llm.load_llm import load_llm
-from graphrag.logging.print_progress import PrintProgressReporter
+from graphrag.logger.print_progress import PrintProgressLogger
 from graphrag.prompt_tune.defaults import MAX_TOKEN_COUNT
 from graphrag.prompt_tune.generator.community_report_rating import (
     generate_community_report_rating,
@@ -80,7 +80,7 @@ async def generate_indexing_prompts(
     -------
     tuple[str, str, str]: entity extraction prompt, entity summarization prompt, community summarization prompt
     """
-    reporter = PrintProgressReporter("")
+    logger = PrintProgressLogger("")
 
     # Retrieve documents
     doc_list = await load_docs_in_chunks(
@@ -88,7 +88,7 @@ async def generate_indexing_prompts(
         config=config,
         limit=limit,
         select_method=selection_method,
-        reporter=reporter,
+        logger=logger,
         chunk_size=chunk_size,
         n_subset_max=n_subset_max,
         k=k,
@@ -103,25 +103,25 @@ async def generate_indexing_prompts(
     )
 
     if not domain:
-        reporter.info("Generating domain...")
+        logger.info("Generating domain...")
         domain = await generate_domain(llm, doc_list)
-        reporter.info(f"Generated domain: {domain}")
+        logger.info(f"Generated domain: {domain}")  # noqa
 
     if not language:
-        reporter.info("Detecting language...")
+        logger.info("Detecting language...")
         language = await detect_language(llm, doc_list)
 
-    reporter.info("Generating persona...")
+    logger.info("Generating persona...")
     persona = await generate_persona(llm, domain)
 
-    reporter.info("Generating community report ranking description...")
+    logger.info("Generating community report ranking description...")
     community_report_ranking = await generate_community_report_rating(
         llm, domain=domain, persona=persona, docs=doc_list
     )
 
     entity_types = None
     if discover_entity_types:
-        reporter.info("Generating entity types...")
+        logger.info("Generating entity types...")
         entity_types = await generate_entity_types(
             llm,
             domain=domain,
@@ -130,7 +130,7 @@ async def generate_indexing_prompts(
         json_mode=config.llm.model_supports_json or False,
     )
 
-    reporter.info("Generating entity relationship examples...")
+    logger.info("Generating entity relationship examples...")
    examples = await generate_entity_relationship_examples(
        llm,
        persona=persona,
@@ -140,7 +140,7 @@ async def generate_indexing_prompts(
         json_mode=False,  # config.llm.model_supports_json should be used, but these prompts are used in non-json mode by the index engine
     )
 
-    reporter.info("Generating entity extraction prompt...")
+    logger.info("Generating entity extraction prompt...")
     entity_extraction_prompt = create_entity_extraction_prompt(
         entity_types=entity_types,
         docs=doc_list,
@@ -152,18 +152,18 @@ async def generate_indexing_prompts(
         min_examples_required=min_examples_required,
     )
 
-    reporter.info("Generating entity summarization prompt...")
+    logger.info("Generating entity summarization prompt...")
     entity_summarization_prompt = create_entity_summarization_prompt(
         persona=persona,
         language=language,
     )
 
-    reporter.info("Generating community reporter role...")
+    logger.info("Generating community reporter role...")
     community_reporter_role = await generate_community_reporter_role(
         llm, domain=domain, persona=persona, docs=doc_list
     )
 
-    reporter.info("Generating community summarization prompt...")
+    logger.info("Generating community summarization prompt...")
     community_summarization_prompt = create_community_summarization_prompt(
         persona=persona,
         role=community_reporter_role,
@@ -19,7 +19,7 @@ Backwards compatibility is not guaranteed at this time.
 
 from collections.abc import AsyncGenerator
 from pathlib import Path
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 import pandas as pd
 from pydantic import validate_call
@@ -29,7 +29,7 @@ from graphrag.index.config.embeddings import (
     community_full_content_embedding,
     entity_description_embedding,
 )
-from graphrag.logging.print_progress import PrintProgressReporter
+from graphrag.logger.print_progress import PrintProgressLogger
 from graphrag.query.factory import (
     get_drift_search_engine,
     get_global_search_engine,
@@ -44,13 +44,15 @@ from graphrag.query.indexer_adapters import (
     read_indexer_reports,
     read_indexer_text_units,
 )
-from graphrag.query.structured_search.base import SearchResult  # noqa: TC001
 from graphrag.utils.cli import redact
 from graphrag.utils.embeddings import create_collection_name
 from graphrag.vector_stores.base import BaseVectorStore
 from graphrag.vector_stores.factory import VectorStoreFactory
 
-reporter = PrintProgressReporter("")
+if TYPE_CHECKING:
+    from graphrag.query.structured_search.base import SearchResult
+
+logger = PrintProgressLogger("")
 
 
 @validate_call(config={"arbitrary_types_allowed": True})
@@ -241,7 +243,7 @@ async def local_search(
     TODO: Document any exceptions to expect.
     """
     vector_store_args = config.embeddings.vector_store
-    reporter.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore
+    logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa
 
     description_embedding_store = _get_embedding_store(
         config_args=vector_store_args,  # type: ignore
@@ -307,7 +309,7 @@ async def local_search_streaming(
     TODO: Document any exceptions to expect.
     """
     vector_store_args = config.embeddings.vector_store
-    reporter.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore
+    logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa
 
     description_embedding_store = _get_embedding_store(
         config_args=vector_store_args,  # type: ignore
@@ -380,7 +382,7 @@ async def drift_search(
     TODO: Document any exceptions to expect.
     """
     vector_store_args = config.embeddings.vector_store
-    reporter.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore
+    logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa
 
     description_embedding_store = _get_embedding_store(
         config_args=vector_store_args,  # type: ignore
@@ -430,7 +432,7 @@ def _get_embedding_store(
     collection_name = create_collection_name(
         config_args.get("container_name", "default"), embedding_name
     )
-    embedding_store = VectorStoreFactory.get_vector_store(
+    embedding_store = VectorStoreFactory().create_vector_store(
         vector_store_type=vector_store_type,
         kwargs={**config_args, "collection_name": collection_name},
     )
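The `_get_embedding_store` change above is the pattern for all call sites: instantiate the factory and call `create_vector_store` instead of the old static `get_vector_store`. A sketch with illustrative values (the kwargs keys depend on the configured store):

```python
from graphrag.vector_stores.factory import VectorStoreFactory

# before: VectorStoreFactory.get_vector_store(...)
store = VectorStoreFactory().create_vector_store(
    vector_store_type="lancedb",  # a built-in store type
    kwargs={"collection_name": "entity-description"},  # illustrative keys
)
```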
graphrag/cache/factory.py

@@ -5,7 +5,7 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, ClassVar
 
 from graphrag.config.enums import CacheType
 from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
@@ -13,41 +13,45 @@ from graphrag.storage.file_pipeline_storage import FilePipelineStorage
 
 if TYPE_CHECKING:
     from graphrag.cache.pipeline_cache import PipelineCache
-    from graphrag.index.config.cache import (
-        PipelineBlobCacheConfig,
-        PipelineCacheConfig,
-        PipelineFileCacheConfig,
-    )
 
 from graphrag.cache.json_pipeline_cache import JsonPipelineCache
 from graphrag.cache.memory_pipeline_cache import InMemoryCache
 from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
 
 
-def create_cache(
-    config: PipelineCacheConfig | None, root_dir: str | None
-) -> PipelineCache:
-    """Create a cache from the given config."""
-    if config is None:
-        return NoopPipelineCache()
-
-    match config.type:
-        case CacheType.none:
-            return NoopPipelineCache()
-        case CacheType.memory:
-            return InMemoryCache()
-        case CacheType.file:
-            config = cast("PipelineFileCacheConfig", config)
-            storage = FilePipelineStorage(root_dir).child(config.base_dir)
-            return JsonPipelineCache(storage)
-        case CacheType.blob:
-            config = cast("PipelineBlobCacheConfig", config)
-            storage = BlobPipelineStorage(
-                config.connection_string,
-                config.container_name,
-                storage_account_blob_url=config.storage_account_blob_url,
-            ).child(config.base_dir)
-            return JsonPipelineCache(storage)
-        case _:
-            msg = f"Unknown cache type: {config.type}"
-            raise ValueError(msg)
+class CacheFactory:
+    """A factory class for cache implementations.
+
+    Includes a method for users to register a custom cache implementation.
+    """
+
+    cache_types: ClassVar[dict[str, type]] = {}
+
+    @classmethod
+    def register(cls, cache_type: str, cache: type):
+        """Register a custom cache implementation."""
+        cls.cache_types[cache_type] = cache
+
+    @classmethod
+    def create_cache(
+        cls, cache_type: CacheType | str | None, root_dir: str, kwargs: dict
+    ) -> PipelineCache:
+        """Create or get a cache from the provided type."""
+        if not cache_type:
+            return NoopPipelineCache()
+        match cache_type:
+            case CacheType.none:
+                return NoopPipelineCache()
+            case CacheType.memory:
+                return InMemoryCache()
+            case CacheType.file:
+                return JsonPipelineCache(
+                    FilePipelineStorage(root_dir=root_dir).child(kwargs["base_dir"])
+                )
+            case CacheType.blob:
+                return JsonPipelineCache(BlobPipelineStorage(**kwargs))
+            case _:
+                if cache_type in cls.cache_types:
+                    return cls.cache_types[cache_type](**kwargs)
+                msg = f"Unknown cache type: {cache_type}"
+                raise ValueError(msg)
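A usage sketch for the new `CacheFactory`; the `file` branch is built in, while the commented lines show the registration path for a hypothetical custom backend (a real one must implement the `PipelineCache` interface):

```python
from graphrag.cache.factory import CacheFactory

# Built-in type: dispatched by the match statement above.
cache = CacheFactory().create_cache(
    cache_type="file",
    root_dir="./ragtest",
    kwargs={"base_dir": "cache"},  # consumed by the file branch
)

# Custom type: register a PipelineCache implementation, then create it by key.
# CacheFactory.register("redis", RedisCache)  # RedisCache is hypothetical
# cache = CacheFactory().create_cache("redis", root_dir="", kwargs={})
```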
graphrag/cache/json_pipeline_cache.py

@@ -1,7 +1,7 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
 
-"""A module containing 'FilePipelineCache' model."""
+"""A module containing 'JsonPipelineCache' model."""
 
 import json
 from typing import Any
@@ -1,7 +1,7 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
 
-"""Create a pipeline reporter."""
+"""Create a pipeline logger."""
 
 from pathlib import Path
 from typing import cast
@@ -22,7 +22,7 @@ from graphrag.index.config.reporting import (
 def create_pipeline_reporter(
     config: PipelineReportingConfig | None, root_dir: str | None
 ) -> WorkflowCallbacks:
-    """Create a reporter for the given pipeline config."""
+    """Create a logger for the given pipeline config."""
     config = config or PipelineFileReportingConfig(base_dir="logs")
 
     match config.type:
@@ -7,16 +7,16 @@ from typing import Any
 
 from datashaper import ExecutionNode, NoopWorkflowCallbacks, Progress, TableContainer
 
-from graphrag.logging.base import ProgressReporter
+from graphrag.logger.base import ProgressLogger
 
 
 class ProgressWorkflowCallbacks(NoopWorkflowCallbacks):
-    """A callbackmanager that delegates to a ProgressReporter."""
+    """A callbackmanager that delegates to a ProgressLogger."""
 
-    _root_progress: ProgressReporter
-    _progress_stack: list[ProgressReporter]
+    _root_progress: ProgressLogger
+    _progress_stack: list[ProgressLogger]
 
-    def __init__(self, progress: ProgressReporter) -> None:
+    def __init__(self, progress: ProgressLogger) -> None:
         """Create a new ProgressWorkflowCallbacks."""
         self._progress = progress
         self._progress_stack = [progress]
@@ -28,7 +28,7 @@ class ProgressWorkflowCallbacks(NoopWorkflowCallbacks):
         self._progress_stack.append(self._latest.child(name))
 
     @property
-    def _latest(self) -> ProgressReporter:
+    def _latest(self) -> ProgressLogger:
         return self._progress_stack[-1]
 
     def on_workflow_start(self, name: str, instance: object) -> None:
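Wiring is unchanged in shape; only the delegate type moved from `ProgressReporter` to `ProgressLogger`. A brief sketch (the import path of `ProgressWorkflowCallbacks` is assumed from the repository layout shown earlier):

```python
from graphrag.callbacks.progress_workflow_callbacks import ProgressWorkflowCallbacks
from graphrag.logger.print_progress import PrintProgressLogger

# Forwards workflow progress events to any ProgressLogger implementation.
callbacks = ProgressWorkflowCallbacks(PrintProgressLogger("GraphRAG "))
```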
@@ -1,7 +1,7 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
 
-"""CLI implementation of index subcommand."""
+"""CLI implementation of the index subcommand."""
 
 import asyncio
 import logging
@@ -16,9 +16,8 @@ from graphrag.config.load_config import load_config
 from graphrag.config.logging import enable_logging_with_config
 from graphrag.config.resolve_path import resolve_paths
 from graphrag.index.validate_config import validate_config_names
-from graphrag.logging.base import ProgressReporter
-from graphrag.logging.factory import create_progress_reporter
-from graphrag.logging.types import ReporterType
+from graphrag.logger.base import ProgressLogger
+from graphrag.logger.factory import LoggerFactory, LoggerType
 from graphrag.utils.cli import redact
 
 # Ignore warnings from numba
@@ -27,35 +26,35 @@ warnings.filterwarnings("ignore", message=".*NumbaDeprecationWarning.*")
 log = logging.getLogger(__name__)
 
 
-def _logger(reporter: ProgressReporter):
+def _logger(logger: ProgressLogger):
     def info(msg: str, verbose: bool = False):
         log.info(msg)
         if verbose:
-            reporter.info(msg)
+            logger.info(msg)
 
     def error(msg: str, verbose: bool = False):
         log.error(msg)
         if verbose:
-            reporter.error(msg)
+            logger.error(msg)
 
     def success(msg: str, verbose: bool = False):
         log.info(msg)
         if verbose:
-            reporter.success(msg)
+            logger.success(msg)
 
     return info, error, success
 
 
-def _register_signal_handlers(reporter: ProgressReporter):
+def _register_signal_handlers(logger: ProgressLogger):
     import signal
 
     def handle_signal(signum, _):
         # Handle the signal here
-        reporter.info(f"Received signal {signum}, exiting...")
-        reporter.dispose()
+        logger.info(f"Received signal {signum}, exiting...")  # noqa: G004
+        logger.dispose()
         for task in asyncio.all_tasks():
             task.cancel()
-        reporter.info("All tasks cancelled. Exiting...")
+        logger.info("All tasks cancelled. Exiting...")
 
     # Register signal handlers for SIGINT and SIGHUP
     signal.signal(signal.SIGINT, handle_signal)
@@ -70,7 +69,7 @@ def index_cli(
     resume: str | None,
     memprofile: bool,
     cache: bool,
-    reporter: ReporterType,
+    logger: LoggerType,
     config_filepath: Path | None,
     dry_run: bool,
     skip_validation: bool,
@@ -85,7 +84,7 @@
         resume=resume,
         memprofile=memprofile,
         cache=cache,
-        reporter=reporter,
+        logger=logger,
         dry_run=dry_run,
         skip_validation=skip_validation,
         output_dir=output_dir,
@@ -97,7 +96,7 @@ def update_cli(
     verbose: bool,
     memprofile: bool,
     cache: bool,
-    reporter: ReporterType,
+    logger: LoggerType,
     config_filepath: Path | None,
     skip_validation: bool,
     output_dir: Path | None,
@@ -121,7 +120,7 @@ def update_cli(
         resume=False,
         memprofile=memprofile,
         cache=cache,
-        reporter=reporter,
+        logger=logger,
         dry_run=False,
         skip_validation=skip_validation,
         output_dir=output_dir,
@@ -134,13 +133,13 @@ def _run_index(
     resume,
     memprofile,
     cache,
-    reporter,
+    logger,
     dry_run,
     skip_validation,
     output_dir,
 ):
-    progress_reporter = create_progress_reporter(reporter)
-    info, error, success = _logger(progress_reporter)
+    progress_logger = LoggerFactory().create_logger(logger)
+    info, error, success = _logger(progress_logger)
     run_id = resume or time.strftime("%Y%m%d-%H%M%S")
 
     config.storage.base_dir = str(output_dir) if output_dir else config.storage.base_dir
@@ -162,7 +161,7 @@
     )
 
     if skip_validation:
-        validate_config_names(progress_reporter, config)
+        validate_config_names(progress_logger, config)
 
     info(f"Starting pipeline run for: {run_id}, {dry_run=}", verbose)
     info(
@@ -174,7 +173,7 @@
         info("Dry run complete, exiting...", True)
         sys.exit(0)
 
-    _register_signal_handlers(progress_reporter)
+    _register_signal_handlers(progress_logger)
 
     outputs = asyncio.run(
         api.build_index(
@@ -182,14 +181,14 @@
             run_id=run_id,
             is_resume_run=bool(resume),
             memory_profile=memprofile,
-            progress_reporter=progress_reporter,
+            progress_logger=progress_logger,
         )
     )
     encountered_errors = any(
         output.errors and len(output.errors) > 0 for output in outputs
     )
 
-    progress_reporter.stop()
+    progress_logger.stop()
     if encountered_errors:
         error(
             "Errors occurred during the pipeline run, see logs for more details.", True
@@ -1,13 +1,12 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
 
-"""CLI implementation of initialization subcommand."""
+"""CLI implementation of the initialization subcommand."""
 
 from pathlib import Path
 
 from graphrag.config.init_content import INIT_DOTENV, INIT_YAML
-from graphrag.logging.factory import create_progress_reporter
-from graphrag.logging.types import ReporterType
+from graphrag.logger.factory import LoggerFactory, LoggerType
 from graphrag.prompts.index.claim_extraction import CLAIM_EXTRACTION_PROMPT
 from graphrag.prompts.index.community_report import (
     COMMUNITY_REPORT_PROMPT,
@@ -28,8 +27,8 @@ from graphrag.prompts.query.question_gen_system_prompt import QUESTION_SYSTEM_PROMPT
 
 def initialize_project_at(path: Path) -> None:
     """Initialize the project at the given path."""
-    progress_reporter = create_progress_reporter(ReporterType.RICH)
-    progress_reporter.info(f"Initializing project at {path}")
+    progress_logger = LoggerFactory().create_logger(LoggerType.RICH)
+    progress_logger.info(f"Initializing project at {path}")  # noqa: G004
     root = Path(path)
     if not root.exists():
         root.mkdir(parents=True, exist_ok=True)
@@ -12,7 +12,7 @@ from typing import Annotated
 
 import typer
 
-from graphrag.logging.types import ReporterType
+from graphrag.logger.types import LoggerType
 from graphrag.prompt_tune.defaults import (
     MAX_TOKEN_COUNT,
     MIN_CHUNK_SIZE,
@@ -145,9 +145,9 @@ def _index_cli(
     resume: Annotated[
         str | None, typer.Option(help="Resume a given indexing run")
     ] = None,
-    reporter: Annotated[
-        ReporterType, typer.Option(help="The progress reporter to use.")
-    ] = ReporterType.RICH,
+    logger: Annotated[
+        LoggerType, typer.Option(help="The progress logger to use.")
+    ] = LoggerType.RICH,
     dry_run: Annotated[
         bool,
         typer.Option(
@@ -180,7 +180,7 @@ def _index_cli(
         resume=resume,
         memprofile=memprofile,
         cache=cache,
-        reporter=ReporterType(reporter),
+        logger=LoggerType(logger),
         config_filepath=config,
         dry_run=dry_run,
         skip_validation=skip_validation,
@@ -212,9 +212,9 @@ def _update_cli(
     memprofile: Annotated[
         bool, typer.Option(help="Run the indexing pipeline with memory profiling")
     ] = False,
-    reporter: Annotated[
-        ReporterType, typer.Option(help="The progress reporter to use.")
-    ] = ReporterType.RICH,
+    logger: Annotated[
+        LoggerType, typer.Option(help="The progress logger to use.")
+    ] = LoggerType.RICH,
     cache: Annotated[bool, typer.Option(help="Use LLM cache.")] = True,
     skip_validation: Annotated[
         bool,
@@ -244,7 +244,7 @@ def _update_cli(
         verbose=verbose,
         memprofile=memprofile,
         cache=cache,
-        reporter=ReporterType(reporter),
+        logger=LoggerType(logger),
         config_filepath=config,
         skip_validation=skip_validation,
         output_dir=output,
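At the command line this renames the option itself: typer derives flags from the parameter names above, so `--reporter` becomes `--logger`. Assuming the poe tasks from the developer guide:

```shell
# before: poetry run poe index --root ./ragtest --reporter rich
poetry run poe index --root ./ragtest --logger rich
```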
@@ -1,13 +1,13 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
 
-"""CLI implementation of prompt-tune subcommand."""
+"""CLI implementation of the prompt-tune subcommand."""
 
 from pathlib import Path
 
 import graphrag.api as api
 from graphrag.config.load_config import load_config
-from graphrag.logging.print_progress import PrintProgressReporter
+from graphrag.logger.print_progress import PrintProgressLogger
 from graphrag.prompt_tune.generator.community_report_summarization import (
     COMMUNITY_SUMMARIZATION_FILENAME,
 )
@@ -52,7 +52,7 @@ async def prompt_tune(
     - k: The number of documents to select when using auto selection method.
     - min_examples_required: The minimum number of examples required for entity extraction prompts.
     """
-    reporter = PrintProgressReporter("")
+    logger = PrintProgressLogger("")
     root_path = Path(root).resolve()
     graph_config = load_config(root_path, config)
 
@@ -73,7 +73,7 @@ async def prompt_tune(
 
     output_path = output.resolve()
     if output_path:
-        reporter.info(f"Writing prompts to {output_path}")
+        logger.info(f"Writing prompts to {output_path}")  # noqa: G004
         output_path.mkdir(parents=True, exist_ok=True)
         entity_extraction_prompt_path = output_path / ENTITY_EXTRACTION_FILENAME
         entity_summarization_prompt_path = output_path / ENTITY_SUMMARIZATION_FILENAME
@@ -1,7 +1,7 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
 
-"""CLI implementation of query subcommand."""
+"""CLI implementation of the query subcommand."""
 
 import asyncio
 import sys
@@ -14,11 +14,11 @@ from graphrag.config.load_config import load_config
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.config.resolve_path import resolve_paths
 from graphrag.index.create_pipeline_config import create_pipeline_config
-from graphrag.logging.print_progress import PrintProgressReporter
-from graphrag.storage.factory import create_storage
+from graphrag.logger.print_progress import PrintProgressLogger
+from graphrag.storage.factory import StorageFactory
 from graphrag.utils.storage import load_table_from_storage
 
-reporter = PrintProgressReporter("")
+logger = PrintProgressLogger("")
 
 
 def run_global_search(
@@ -40,9 +40,9 @@ def run_global_search(
     config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
     resolve_paths(config)
 
-    dataframe_dict = _resolve_parquet_files(
+    dataframe_dict = _resolve_output_files(
         config=config,
-        parquet_list=[
+        output_list=[
             "create_final_nodes.parquet",
             "create_final_entities.parquet",
             "create_final_communities.parquet",
@@ -100,7 +100,7 @@ def run_global_search(
             query=query,
         )
     )
-    reporter.success(f"Global Search Response:\n{response}")
+    logger.success(f"Global Search Response:\n{response}")
     # NOTE: we return the response and context data here purely as a complete demonstration of the API.
     # External users should use the API directly to get the response and context data.
     return response, context_data
@@ -124,9 +124,9 @@ def run_local_search(
     config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
     resolve_paths(config)
 
-    dataframe_dict = _resolve_parquet_files(
+    dataframe_dict = _resolve_output_files(
         config=config,
-        parquet_list=[
+        output_list=[
             "create_final_nodes.parquet",
             "create_final_community_reports.parquet",
             "create_final_text_units.parquet",
@@ -191,7 +191,7 @@ def run_local_search(
             query=query,
         )
     )
-    reporter.success(f"Local Search Response:\n{response}")
+    logger.success(f"Local Search Response:\n{response}")
     # NOTE: we return the response and context data here purely as a complete demonstration of the API.
     # External users should use the API directly to get the response and context data.
     return response, context_data
@@ -214,9 +214,9 @@ def run_drift_search(
     config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
     resolve_paths(config)
 
-    dataframe_dict = _resolve_parquet_files(
+    dataframe_dict = _resolve_output_files(
         config=config,
-        parquet_list=[
+        output_list=[
             "create_final_nodes.parquet",
             "create_final_community_reports.parquet",
             "create_final_text_units.parquet",
@@ -250,30 +250,33 @@ def run_drift_search(
             query=query,
         )
     )
-    reporter.success(f"DRIFT Search Response:\n{response}")
+    logger.success(f"DRIFT Search Response:\n{response}")
     # NOTE: we return the response and context data here purely as a complete demonstration of the API.
     # External users should use the API directly to get the response and context data.
     # TODO: Map/Reduce Drift Search answer to a single response
     return response, context_data
 
 
-def _resolve_parquet_files(
+def _resolve_output_files(
     config: GraphRagConfig,
-    parquet_list: list[str],
+    output_list: list[str],
     optional_list: list[str] | None = None,
 ) -> dict[str, pd.DataFrame]:
-    """Read parquet files to a dataframe dict."""
+    """Read indexing output files to a dataframe dict."""
     dataframe_dict = {}
     pipeline_config = create_pipeline_config(config)
-    storage_obj = create_storage(pipeline_config.storage)  # type: ignore
-    for parquet_file in parquet_list:
-        df_key = parquet_file.split(".")[0]
+    storage_config = pipeline_config.storage.model_dump()  # type: ignore
+    storage_obj = StorageFactory().create_storage(
+        storage_type=storage_config["type"], kwargs=storage_config
+    )
+    for output_file in output_list:
+        df_key = output_file.split(".")[0]
         df_value = asyncio.run(
-            load_table_from_storage(name=parquet_file, storage=storage_obj)
+            load_table_from_storage(name=output_file, storage=storage_obj)
         )
         dataframe_dict[df_key] = df_value
 
-    # for optional parquet files, set the dict entry to None instead of erroring out if it does not exist
+    # for optional output files, set the dict entry to None instead of erroring out if it does not exist
     if optional_list:
         for optional_file in optional_list:
             file_exists = asyncio.run(storage_obj.has(optional_file))
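The dump-and-dispatch idiom in `_resolve_output_files` generalizes: dump the pydantic storage config and pass the whole dict to the factory, which selects its branch from the `type` key. A sketch mirroring that call (field values illustrative):

```python
from graphrag.storage.factory import StorageFactory

storage_config = {"type": "file", "base_dir": "./ragtest/output"}  # illustrative dump
storage = StorageFactory().create_storage(
    storage_type=storage_config["type"],
    kwargs=storage_config,
)
```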
@@ -33,7 +33,7 @@ async_mode: {defs.ASYNC_MODE.value} # or asyncio
 
 embeddings:
   async_mode: {defs.ASYNC_MODE.value} # or asyncio
-  vector_store:{defs.VECTOR_STORE}
+  vector_store: {defs.VECTOR_STORE}
   llm:
     api_key: ${{GRAPHRAG_API_KEY}}
     type: {defs.EMBEDDING_TYPE.value} # or azure_openai_embedding
@@ -1,7 +1,7 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
 
-"""A module containing 'PipelineCacheConfig', 'PipelineFileCacheConfig' and 'PipelineMemoryCacheConfig' models."""
+"""A module containing 'PipelineCacheConfig', 'PipelineFileCacheConfig', 'PipelineMemoryCacheConfig', 'PipelineBlobCacheConfig' models."""
 
 from __future__ import annotations
@@ -12,7 +12,7 @@ import pandas as pd
 
 from graphrag.index.config.input import PipelineCSVInputConfig, PipelineInputConfig
 from graphrag.index.utils.hashing import gen_sha512_hash
-from graphrag.logging.base import ProgressReporter
+from graphrag.logger.base import ProgressLogger
 from graphrag.storage.pipeline_storage import PipelineStorage
 
 log = logging.getLogger(__name__)
@@ -24,7 +24,7 @@ input_type = "csv"
 
 async def load(
     config: PipelineInputConfig,
-    progress: ProgressReporter | None,
+    progress: ProgressLogger | None,
     storage: PipelineStorage,
 ) -> pd.DataFrame:
     """Load csv inputs from a directory."""
@@ -17,8 +17,8 @@ from graphrag.index.input.csv import input_type as csv
 from graphrag.index.input.csv import load as load_csv
 from graphrag.index.input.text import input_type as text
 from graphrag.index.input.text import load as load_text
-from graphrag.logging.base import ProgressReporter
-from graphrag.logging.null_progress import NullProgressReporter
+from graphrag.logger.base import ProgressLogger
+from graphrag.logger.null_progress import NullProgressLogger
 from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
 from graphrag.storage.file_pipeline_storage import FilePipelineStorage
 
@@ -31,13 +31,13 @@ loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = {
 
 async def create_input(
     config: PipelineInputConfig | InputConfig,
-    progress_reporter: ProgressReporter | None = None,
+    progress_reporter: ProgressLogger | None = None,
     root_dir: str | None = None,
 ) -> pd.DataFrame:
     """Instantiate input data for a pipeline."""
     root_dir = root_dir or ""
     log.info("loading input from root_dir=%s", config.base_dir)
-    progress_reporter = progress_reporter or NullProgressReporter()
+    progress_reporter = progress_reporter or NullProgressLogger()
 
     match config.type:
         case InputType.blob:
@@ -12,7 +12,7 @@ import pandas as pd
 
 from graphrag.index.config.input import PipelineInputConfig
 from graphrag.index.utils.hashing import gen_sha512_hash
-from graphrag.logging.base import ProgressReporter
+from graphrag.logger.base import ProgressLogger
 from graphrag.storage.pipeline_storage import PipelineStorage
 
 DEFAULT_FILE_PATTERN = re.compile(
@@ -24,7 +24,7 @@ log = logging.getLogger(__name__)
 
 async def load(
     config: PipelineInputConfig,
-    progress: ProgressReporter | None,
+    progress: ProgressLogger | None,
     storage: PipelineStorage,
 ) -> pd.DataFrame:
     """Load text inputs from a directory."""
@@ -217,7 +217,7 @@ def _create_vector_store(
     if collection_name:
         vector_store_config.update({"collection_name": collection_name})
 
-    vector_store = VectorStoreFactory.get_vector_store(
+    vector_store = VectorStoreFactory().create_vector_store(
         vector_store_type, kwargs=vector_store_config
     )
@@ -8,21 +8,18 @@ import logging
 import time
 import traceback
 from collections.abc import AsyncIterable
-from pathlib import Path
 from typing import cast
 
 import pandas as pd
 from datashaper import NoopVerbCallbacks, WorkflowCallbacks
 
-from graphrag.cache.factory import create_cache
+from graphrag.cache.factory import CacheFactory
 from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
-from graphrag.index.config.cache import PipelineMemoryCacheConfig
 from graphrag.index.config.pipeline import (
     PipelineConfig,
     PipelineWorkflowReference,
 )
-from graphrag.index.config.storage import PipelineFileStorageConfig
 from graphrag.index.config.workflow import PipelineWorkflowStep
 from graphrag.index.exporter import ParquetExporter
 from graphrag.index.input.factory import create_input
@@ -51,9 +48,9 @@ from graphrag.index.workflows import (
     WorkflowDefinitions,
     load_workflows,
 )
-from graphrag.logging.base import ProgressReporter
-from graphrag.logging.null_progress import NullProgressReporter
-from graphrag.storage.factory import create_storage
+from graphrag.logger.base import ProgressLogger
+from graphrag.logger.null_progress import NullProgressLogger
+from graphrag.storage.factory import StorageFactory
 from graphrag.storage.pipeline_storage import PipelineStorage
 
 log = logging.getLogger(__name__)
@@ -67,7 +64,7 @@ async def run_pipeline_with_config(
     update_index_storage: PipelineStorage | None = None,
     cache: PipelineCache | None = None,
     callbacks: list[WorkflowCallbacks] | None = None,
-    progress_reporter: ProgressReporter | None = None,
+    logger: ProgressLogger | None = None,
     input_post_process_steps: list[PipelineWorkflowStep] | None = None,
     additional_verbs: VerbDefinitions | None = None,
     additional_workflows: WorkflowDefinitions | None = None,
@@ -85,7 +82,7 @@ async def run_pipeline_with_config(
     - dataset - The dataset to run the pipeline on (this overrides the config)
     - storage - The storage to use for the pipeline (this overrides the config)
     - cache - The cache to use for the pipeline (this overrides the config)
-    - reporter - The reporter to use for the pipeline (this overrides the config)
+    - logger - The logger to use for the pipeline (this overrides the config)
     - input_post_process_steps - The post process steps to run on the input data (this overrides the config)
     - additional_verbs - The custom verbs to use for the pipeline.
     - additional_workflows - The custom workflows to use for the pipeline.
@@ -102,23 +99,32 @@ async def run_pipeline_with_config(
     config = _apply_substitutions(config, run_id)
     root_dir = config.root_dir or ""
 
-    progress_reporter = progress_reporter or NullProgressReporter()
-    storage = storage = create_storage(config.storage)  # type: ignore
+    progress_logger = logger or NullProgressLogger()
+    storage_config = config.storage.model_dump()  # type: ignore
+    storage = storage or StorageFactory().create_storage(
+        storage_type=storage_config["type"],  # type: ignore
+        kwargs=storage_config,
+    )
 
     if is_update_run:
-        # TODO: remove the default choice (PipelineFileStorageConfig) once the new config system enforces a correct update-index-storage config when used.
-        update_index_storage = update_index_storage or create_storage(
-            config.update_index_storage
-            or PipelineFileStorageConfig(base_dir=str(Path(root_dir) / "output"))
+        update_storage_config = config.update_index_storage.model_dump()  # type: ignore
+        update_index_storage = update_index_storage or StorageFactory().create_storage(
+            storage_type=update_storage_config["type"],  # type: ignore
+            kwargs=update_storage_config,
         )
 
-    # TODO: remove the default choice (PipelineMemoryCacheConfig) when the new config system guarantees the existence of a cache config
-    cache = cache or create_cache(config.cache or PipelineMemoryCacheConfig(), root_dir)
+    # TODO: remove the type ignore when the new config system guarantees the existence of a cache config
+    cache_config = config.cache.model_dump()  # type: ignore
+    cache = cache or CacheFactory().create_cache(
+        cache_type=cache_config["type"],  # type: ignore
+        root_dir=root_dir,
+        kwargs=cache_config,
+    )
     # TODO: remove the type ignore when the new config system guarantees the existence of an input config
     dataset = (
         dataset
         if dataset is not None
-        else await create_input(config.input, progress_reporter, root_dir)  # type: ignore
+        else await create_input(config.input, progress_logger, root_dir)  # type: ignore
     )
 
     post_process_steps = input_post_process_steps or _create_postprocess_steps(
@@ -148,12 +154,12 @@ async def run_pipeline_with_config(
             memory_profile=memory_profile,
             additional_verbs=additional_verbs,
             additional_workflows=additional_workflows,
-            progress_reporter=progress_reporter,
+            progress_logger=progress_logger,
             is_resume_run=False,
         ):
             tables_dict[table.workflow] = table.result
 
-        progress_reporter.success("Finished running workflows on new documents.")
+        progress_logger.success("Finished running workflows on new documents.")
         await update_dataframe_outputs(
             dataframe_dict=tables_dict,
             storage=storage,
@@ -161,7 +167,7 @@ async def run_pipeline_with_config(
             config=config,
             cache=cache,
             callbacks=NoopVerbCallbacks(),
-            progress_reporter=progress_reporter,
+            progress_logger=progress_logger,
         )
 
     else:
@@ -175,7 +181,7 @@ async def run_pipeline_with_config(
             memory_profile=memory_profile,
             additional_verbs=additional_verbs,
             additional_workflows=additional_workflows,
-            progress_reporter=progress_reporter,
+            progress_logger=progress_logger,
             is_resume_run=is_resume_run,
         ):
             yield table
@@ -187,7 +193,7 @@ async def run_pipeline(
     storage: PipelineStorage | None = None,
     cache: PipelineCache | None = None,
     callbacks: list[WorkflowCallbacks] | None = None,
-    progress_reporter: ProgressReporter | None = None,
+    progress_logger: ProgressLogger | None = None,
     input_post_process_steps: list[PipelineWorkflowStep] | None = None,
     additional_verbs: VerbDefinitions | None = None,
     additional_workflows: WorkflowDefinitions | None = None,
@@ -206,7 +212,7 @@ async def run_pipeline(
       These must exist after any post process steps are run if there are any!
     - storage - The storage to use for the pipeline
     - cache - The cache to use for the pipeline
-    - reporter - The reporter to use for the pipeline
+    - progress_logger - The logger to use for the pipeline
    - input_post_process_steps - The post process steps to run on the input data
    - additional_verbs - The custom verbs to use for the pipeline
    - additional_workflows - The custom workflows to use for the pipeline
@@ -216,7 +222,7 @@ async def run_pipeline(
     """
     start_time = time.time()
 
-    progress_reporter = progress_reporter or NullProgressReporter()
+    progress_reporter = progress_logger or NullProgressLogger()
     callbacks = callbacks or [ConsoleWorkflowCallbacks()]
     callback_chain = _create_callback_chain(callbacks, progress_reporter)
     context = create_run_context(storage=storage, cache=cache, stats=None)
@@ -21,7 +21,7 @@ from graphrag.index.context import PipelineRunContext
 from graphrag.index.exporter import ParquetExporter
 from graphrag.index.run.profiling import _write_workflow_stats
 from graphrag.index.typing import PipelineRunResult
-from graphrag.logging.base import ProgressReporter
+from graphrag.logger.base import ProgressLogger
 from graphrag.storage.pipeline_storage import PipelineStorage
 from graphrag.utils.storage import load_table_from_storage
 
@@ -68,7 +68,7 @@ async def _export_workflow_output(
 
 
 def _create_callback_chain(
-    callbacks: list[WorkflowCallbacks] | None, progress: ProgressReporter | None
+    callbacks: list[WorkflowCallbacks] | None, progress: ProgressLogger | None
 ) -> WorkflowCallbacks:
     """Create a callback manager that encompasses multiple callbacks."""
     manager = WorkflowCallbacksManager()
@@ -23,7 +23,7 @@ from graphrag.index.update.entities import (
     _run_entity_summarization,
 )
 from graphrag.index.update.relationships import _update_and_merge_relationships
-from graphrag.logging.print_progress import ProgressReporter
+from graphrag.logger.print_progress import ProgressLogger
 from graphrag.storage.pipeline_storage import PipelineStorage
 from graphrag.utils.storage import load_table_from_storage
 
@@ -85,7 +85,7 @@ async def update_dataframe_outputs(
     config: PipelineConfig,
     cache: PipelineCache,
     callbacks: VerbCallbacks,
-    progress_reporter: ProgressReporter,
+    progress_logger: ProgressLogger,
 ) -> None:
     """Update the mergeable outputs.
 
@@ -96,25 +96,25 @@ async def update_dataframe_outputs(
     storage : PipelineStorage
         The storage used to store the dataframes.
     """
-    progress_reporter.info("Updating Final Documents")
+    progress_logger.info("Updating Final Documents")
     final_documents_df = await _concat_dataframes(
         "create_final_documents", dataframe_dict, storage, update_storage
     )
 
     # Update entities and merge them
-    progress_reporter.info("Updating Final Entities")
+    progress_logger.info("Updating Final Entities")
     merged_entities_df, entity_id_mapping = await _update_entities(
         dataframe_dict, storage, update_storage, config, cache, callbacks
     )
 
     # Update relationships with the entities id mapping
-    progress_reporter.info("Updating Final Relationships")
+    progress_logger.info("Updating Final Relationships")
     merged_relationships_df = await _update_relationships(
         dataframe_dict, storage, update_storage
     )
 
     # Update and merge final text units
-    progress_reporter.info("Updating Final Text Units")
+    progress_logger.info("Updating Final Text Units")
     merged_text_units = await _update_text_units(
         dataframe_dict, storage, update_storage, entity_id_mapping
     )
@@ -124,23 +124,23 @@ async def update_dataframe_outputs(
         await storage.has("create_final_covariates.parquet")
         and "create_final_covariates" in dataframe_dict
     ):
-        progress_reporter.info("Updating Final Covariates")
+        progress_logger.info("Updating Final Covariates")
         await _update_covariates(dataframe_dict, storage, update_storage)
 
     # Merge final nodes and update community ids
-    progress_reporter.info("Updating Final Nodes")
+    progress_logger.info("Updating Final Nodes")
     _, community_id_mapping = await _update_nodes(
         dataframe_dict, storage, update_storage, merged_entities_df
     )
 
     # Merge final communities
-    progress_reporter.info("Updating Final Communities")
+    progress_logger.info("Updating Final Communities")
     await _update_communities(
         dataframe_dict, storage, update_storage, community_id_mapping
     )
 
     # Merge community reports
-    progress_reporter.info("Updating Final Community Reports")
+    progress_logger.info("Updating Final Community Reports")
     merged_community_reports = await _update_community_reports(
         dataframe_dict, storage, update_storage, community_id_mapping
     )
@@ -151,7 +151,7 @@ async def update_dataframe_outputs(
     )
 
     # Generate text embeddings
-    progress_reporter.info("Updating Text Embeddings")
+    progress_logger.info("Updating Text Embeddings")
     await generate_text_embeddings(
         final_documents=final_documents_df,
         final_relationships=merged_relationships_df,
||||
@ -10,12 +10,10 @@ from datashaper import NoopVerbCallbacks
|
||||
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.llm.load_llm import load_llm, load_llm_embeddings
|
||||
from graphrag.logging.print_progress import ProgressReporter
|
||||
from graphrag.logger.print_progress import ProgressLogger
|
||||
|
||||
|
||||
def validate_config_names(
|
||||
reporter: ProgressReporter, parameters: GraphRagConfig
|
||||
) -> None:
|
||||
def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) -> None:
|
||||
"""Validate config file for LLM deployment name typos."""
|
||||
# Validate Chat LLM configs
|
||||
llm = load_llm(
|
||||
@ -26,9 +24,9 @@ def validate_config_names(
|
||||
)
|
||||
try:
|
||||
asyncio.run(llm("This is an LLM connectivity test. Say Hello World"))
|
||||
reporter.success("LLM Config Params Validated")
|
||||
logger.success("LLM Config Params Validated")
|
||||
except Exception as e: # noqa: BLE001
|
||||
reporter.error(f"LLM configuration error detected. Exiting...\n{e}")
|
||||
logger.error(f"LLM configuration error detected. Exiting...\n{e}") # noqa
|
||||
sys.exit(1)
|
||||
|
||||
# Validate Embeddings LLM configs
|
||||
@ -40,7 +38,7 @@ def validate_config_names(
|
||||
)
|
||||
try:
|
||||
asyncio.run(embed_llm(["This is an LLM Embedding Test String"]))
|
||||
reporter.success("Embedding LLM Config Params Validated")
|
||||
logger.success("Embedding LLM Config Params Validated")
|
||||
except Exception as e: # noqa: BLE001
|
||||
reporter.error(f"Embedding LLM configuration error detected. Exiting...\n{e}")
|
||||
logger.error(f"Embedding LLM configuration error detected. Exiting...\n{e}") # noqa
|
||||
sys.exit(1)
|
||||
|
||||
@@ -1,4 +1,4 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
 
-"""Logging utilities and implementations."""
+"""Logger utilities and implementations."""
@@ -10,24 +10,24 @@ from datashaper.progress.types import Progress
 
 
 class StatusLogger(ABC):
-    """Provides a way to report status updates from the pipeline."""
+    """Provides a way to log status updates from the pipeline."""
 
     @abstractmethod
     def error(self, message: str, details: dict[str, Any] | None = None):
-        """Report an error."""
+        """Log an error."""
 
     @abstractmethod
     def warning(self, message: str, details: dict[str, Any] | None = None):
-        """Report a warning."""
+        """Log a warning."""
 
     @abstractmethod
     def log(self, message: str, details: dict[str, Any] | None = None):
         """Report a log."""
 
 
-class ProgressReporter(ABC):
+class ProgressLogger(ABC):
     """
-    Abstract base class for progress reporters.
+    Abstract base class for progress loggers.
 
     This is used to report workflow processing progress via mechanisms like progress-bars.
     """
@@ -38,10 +38,10 @@ class ProgressReporter(ABC):
 
     @abstractmethod
     def dispose(self):
-        """Dispose of the progress reporter."""
+        """Dispose of the progress logger."""
 
     @abstractmethod
-    def child(self, prefix: str, transient=True) -> "ProgressReporter":
+    def child(self, prefix: str, transient=True) -> "ProgressLogger":
         """Create a child progress bar."""
 
     @abstractmethod
@@ -50,20 +50,20 @@ class ProgressReporter(ABC):
 
     @abstractmethod
     def stop(self) -> None:
-        """Stop the progress reporter."""
+        """Stop the progress logger."""
 
     @abstractmethod
     def error(self, message: str) -> None:
-        """Report an error."""
+        """Log an error."""
 
     @abstractmethod
     def warning(self, message: str) -> None:
-        """Report a warning."""
+        """Log a warning."""
 
     @abstractmethod
     def info(self, message: str) -> None:
-        """Report information."""
+        """Log information."""
 
     @abstractmethod
     def success(self, message: str) -> None:
-        """Report success."""
+        """Log success."""
@@ -1,26 +1,26 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
 
-"""Console Reporter."""
+"""Console Log."""
 
 from typing import Any
 
-from graphrag.logging.base import StatusLogger
+from graphrag.logger.base import StatusLogger
 
 
 class ConsoleReporter(StatusLogger):
-    """A reporter that writes to a console."""
+    """A logger that writes to a console."""
 
     def error(self, message: str, details: dict[str, Any] | None = None):
-        """Report an error."""
+        """Log an error."""
         print(message, details)  # noqa T201
 
     def warning(self, message: str, details: dict[str, Any] | None = None):
-        """Report a warning."""
+        """Log a warning."""
         _print_warning(message)
 
     def log(self, message: str, details: dict[str, Any] | None = None):
-        """Report a log."""
+        """Log a log."""
         print(message, details)  # noqa T201
graphrag/logger/factory.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Factory functions for creating loggers."""
+
+from typing import ClassVar
+
+from graphrag.logger.base import ProgressLogger
+from graphrag.logger.null_progress import NullProgressLogger
+from graphrag.logger.print_progress import PrintProgressLogger
+from graphrag.logger.rich_progress import RichProgressLogger
+from graphrag.logger.types import LoggerType
+
+
+class LoggerFactory:
+    """A factory class for loggers."""
+
+    logger_types: ClassVar[dict[str, type]] = {}
+
+    @classmethod
+    def register(cls, logger_type: str, logger: type):
+        """Register a custom logger implementation."""
+        cls.logger_types[logger_type] = logger
+
+    @classmethod
+    def create_logger(
+        cls, logger_type: LoggerType | str, kwargs: dict | None = None
+    ) -> ProgressLogger:
+        """Create a logger based on the provided type."""
+        if kwargs is None:
+            kwargs = {}
+        match logger_type:
+            case LoggerType.RICH:
+                return RichProgressLogger("GraphRAG Indexer ")
+            case LoggerType.PRINT:
+                return PrintProgressLogger("GraphRAG Indexer ")
+            case LoggerType.NONE:
+                return NullProgressLogger()
+            case _:
+                if logger_type in cls.logger_types:
+                    return cls.logger_types[logger_type](**kwargs)
+                # default to null logger if no other logger is found
+                return NullProgressLogger()
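
Based on the factory above, usage might look like the following sketch; the `"silent"` registration name is hypothetical:

```python
from graphrag.logger.factory import LoggerFactory
from graphrag.logger.null_progress import NullProgressLogger
from graphrag.logger.types import LoggerType

# built-in types are resolved by the match statement
progress = LoggerFactory.create_logger(LoggerType.PRINT)
progress.info("indexing started")

# custom implementations are instantiated from the registry by name;
# here the built-in null logger is registered under a hypothetical name
LoggerFactory.register("silent", NullProgressLogger)
silent = LoggerFactory.create_logger("silent")

# unregistered names fall back to NullProgressLogger rather than raising
fallback = LoggerFactory.create_logger("no-such-logger")
```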
@@ -3,19 +3,19 @@

 """Null Progress Reporter."""

-from graphrag.logging.base import Progress, ProgressReporter
+from graphrag.logger.base import Progress, ProgressLogger


-class NullProgressReporter(ProgressReporter):
-    """A progress reporter that does nothing."""
+class NullProgressLogger(ProgressLogger):
+    """A progress logger that does nothing."""

     def __call__(self, update: Progress) -> None:
         """Update progress."""

     def dispose(self) -> None:
-        """Dispose of the progress reporter."""
+        """Dispose of the progress logger."""

-    def child(self, prefix: str, transient: bool = True) -> ProgressReporter:
+    def child(self, prefix: str, transient: bool = True) -> ProgressLogger:
         """Create a child progress bar."""
         return self

@@ -23,16 +23,16 @@ class NullProgressReporter(ProgressReporter):
         """Force a refresh."""

     def stop(self) -> None:
-        """Stop the progress reporter."""
+        """Stop the progress logger."""

     def error(self, message: str) -> None:
-        """Report an error."""
+        """Log an error."""

     def warning(self, message: str) -> None:
-        """Report a warning."""
+        """Log a warning."""

     def info(self, message: str) -> None:
-        """Report information."""
+        """Log information."""

     def success(self, message: str) -> None:
-        """Report success."""
+        """Log success."""
@@ -1,18 +1,18 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-"""Print Progress Reporter."""
+"""Print Progress Logger."""

-from graphrag.logging.base import Progress, ProgressReporter
+from graphrag.logger.base import Progress, ProgressLogger


-class PrintProgressReporter(ProgressReporter):
-    """A progress reporter that does nothing."""
+class PrintProgressLogger(ProgressLogger):
+    """A progress logger that prints progress to stdout."""

     prefix: str

     def __init__(self, prefix: str):
-        """Create a new progress reporter."""
+        """Create a new progress logger."""
         self.prefix = prefix
         print(f"\n{self.prefix}", end="")  # noqa T201

@@ -21,30 +21,30 @@ class PrintProgressReporter(ProgressReporter):
         print(".", end="")  # noqa T201

     def dispose(self) -> None:
-        """Dispose of the progress reporter."""
+        """Dispose of the progress logger."""

-    def child(self, prefix: str, transient: bool = True) -> "ProgressReporter":
+    def child(self, prefix: str, transient: bool = True) -> ProgressLogger:
         """Create a child progress bar."""
-        return PrintProgressReporter(prefix)
+        return PrintProgressLogger(prefix)

     def stop(self) -> None:
-        """Stop the progress reporter."""
+        """Stop the progress logger."""

     def force_refresh(self) -> None:
         """Force a refresh."""

     def error(self, message: str) -> None:
-        """Report an error."""
+        """Log an error."""
         print(f"\n{self.prefix}ERROR: {message}")  # noqa T201

     def warning(self, message: str) -> None:
-        """Report a warning."""
+        """Log a warning."""
         print(f"\n{self.prefix}WARNING: {message}")  # noqa T201

     def info(self, message: str) -> None:
-        """Report information."""
+        """Log information."""
         print(f"\n{self.prefix}INFO: {message}")  # noqa T201

     def success(self, message: str) -> None:
-        """Report success."""
+        """Log success."""
         print(f"\n{self.prefix}SUCCESS: {message}")  # noqa T201

@@ -1,7 +1,7 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-"""Rich-based progress reporter for CLI use."""
+"""Rich-based progress logger for CLI use."""

 # Print iterations progress
 import asyncio
@@ -13,12 +13,12 @@ from rich.progress import Progress, TaskID, TimeElapsedColumn
 from rich.spinner import Spinner
 from rich.tree import Tree

-from graphrag.logging.base import ProgressReporter
+from graphrag.logger.base import ProgressLogger


 # https://stackoverflow.com/a/34325723
-class RichProgressReporter(ProgressReporter):
-    """A rich-based progress reporter for CLI use."""
+class RichProgressLogger(ProgressLogger):
+    """A rich-based progress logger for CLI use."""

     _console: Console
     _group: Group
@@ -32,7 +32,7 @@ class RichProgressReporter(ProgressReporter):
     _last_refresh: float = 0

     def dispose(self) -> None:
-        """Dispose of the progress reporter."""
+        """Dispose of the progress logger."""
         self._disposing = True
         self._live.stop()

@@ -59,10 +59,10 @@ class RichProgressReporter(ProgressReporter):
     def __init__(
         self,
         prefix: str,
-        parent: "RichProgressReporter | None" = None,
+        parent: "RichProgressLogger | None" = None,
         transient: bool = True,
     ) -> None:
-        """Create a new rich-based progress reporter."""
+        """Create a new rich-based progress logger."""
         self._prefix = prefix

         if parent is None:
@@ -115,27 +115,27 @@ class RichProgressReporter(ProgressReporter):
         self.live.refresh()

     def stop(self) -> None:
-        """Stop the progress reporter."""
+        """Stop the progress logger."""
         self._live.stop()

-    def child(self, prefix: str, transient: bool = True) -> ProgressReporter:
+    def child(self, prefix: str, transient: bool = True) -> ProgressLogger:
         """Create a child progress bar."""
-        return RichProgressReporter(parent=self, prefix=prefix, transient=transient)
+        return RichProgressLogger(parent=self, prefix=prefix, transient=transient)

     def error(self, message: str) -> None:
-        """Report an error."""
+        """Log an error."""
         self._console.print(f"❌ [red]{message}[/red]")

     def warning(self, message: str) -> None:
-        """Report a warning."""
+        """Log a warning."""
         self._console.print(f"⚠️ [yellow]{message}[/yellow]")

     def success(self, message: str) -> None:
-        """Report success."""
+        """Log success."""
         self._console.print(f"🚀 [green]{message}[/green]")

     def info(self, message: str) -> None:
-        """Report information."""
+        """Log information."""
         self._console.print(message)

     def __call__(self, progress_update: DSProgress) -> None:
graphrag/logger/types.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Logging types.
+
+This module defines the types of loggers that can be used.
+"""
+
+from enum import Enum
+
+
+# Note: Code in this module was not included in the factory module because it negatively impacts the CLI experience.
+class LoggerType(str, Enum):
+    """The type of logger to use."""
+
+    RICH = "rich"
+    PRINT = "print"
+    NONE = "none"
+
+    def __str__(self):
+        """Return a string representation of the enum value."""
+        return self.value
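
Since `LoggerType` subclasses `str`, its members compare and render as plain strings, which keeps CLI argument parsing simple; a quick sketch:

```python
from graphrag.logger.types import LoggerType

assert LoggerType.PRINT == "print"  # str subclass: compares equal to its raw value
assert str(LoggerType.RICH) == "rich"  # __str__ returns the value itself
```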
@@ -1,36 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""Factory functions for creating loggers."""
-
-from graphrag.logging.base import ProgressReporter
-from graphrag.logging.null_progress import NullProgressReporter
-from graphrag.logging.print_progress import PrintProgressReporter
-from graphrag.logging.rich_progress import RichProgressReporter
-from graphrag.logging.types import ReporterType
-
-
-def create_progress_reporter(
-    reporter_type: ReporterType = ReporterType.NONE,
-) -> ProgressReporter:
-    """Load a progress reporter.
-
-    Parameters
-    ----------
-    reporter_type : {"rich", "print", "none"}, default=rich
-        The type of progress reporter to load.
-
-    Returns
-    -------
-    ProgressReporter
-    """
-    match reporter_type:
-        case ReporterType.RICH:
-            return RichProgressReporter("GraphRAG Indexer ")
-        case ReporterType.PRINT:
-            return PrintProgressReporter("GraphRAG Indexer ")
-        case ReporterType.NONE:
-            return NullProgressReporter()
-        case _:
-            msg = f"Invalid progress reporter type: {reporter_type}"
-            raise ValueError(msg)
@@ -1,18 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""Types for status reporting."""
-
-from enum import Enum
-
-
-class ReporterType(str, Enum):
-    """The type of reporter to use."""
-
-    RICH = "rich"
-    PRINT = "print"
-    NONE = "none"
-
-    def __str__(self):
-        """Return the string representation of the enum value."""
-        return self.value
@@ -1,4 +1,4 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-"""GraphRAG knowledge model package root."""
+"""Knowledge model package."""

@@ -15,7 +15,7 @@ from graphrag.config.models.llm_parameters import LLMParameters
 from graphrag.index.input.factory import create_input
 from graphrag.index.llm.load_llm import load_llm_embeddings
 from graphrag.index.operations.chunk_text import chunk_text
-from graphrag.logging.base import ProgressReporter
+from graphrag.logger.base import ProgressLogger
 from graphrag.prompt_tune.defaults import (
     MIN_CHUNK_OVERLAP,
     MIN_CHUNK_SIZE,
@@ -54,7 +54,7 @@ async def load_docs_in_chunks(
     config: GraphRagConfig,
     select_method: DocSelectionType,
     limit: int,
-    reporter: ProgressReporter,
+    logger: ProgressLogger,
     chunk_size: int = MIN_CHUNK_SIZE,
     n_subset_max: int = N_SUBSET_MAX,
     k: int = K,
@@ -64,7 +64,7 @@ async def load_docs_in_chunks(
         config.embeddings.resolved_strategy()["llm"]
     )

-    dataset = await create_input(config.input, reporter, root)
+    dataset = await create_input(config.input, logger, root)

     # convert to text units
     chunk_strategy = config.chunks.resolved_strategy(defs.ENCODING_MODEL)

@@ -1,4 +1,4 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-"""All prompts for indexing."""
+"""All prompts for the indexing engine."""

@@ -1,4 +1,4 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-"""All prompts for query."""
+"""All prompts for the query engine."""

@@ -8,8 +8,8 @@ from collections.abc import Callable

 from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI

-from graphrag.logging.base import StatusLogger
-from graphrag.logging.console import ConsoleReporter
+from graphrag.logger.base import StatusLogger
+from graphrag.logger.console import ConsoleReporter
 from graphrag.query.llm.base import BaseTextEmbedding
 from graphrag.query.llm.oai.typing import OpenaiApiType

@@ -101,7 +101,7 @@ class OpenAILLMImpl(BaseOpenAILLM):
         organization: str | None = None,
         max_retries: int = 10,
         request_timeout: float = 180.0,
-        reporter: StatusLogger | None = None,
+        logger: StatusLogger | None = None,
     ):
         self.api_key = api_key
         self.azure_ad_token_provider = azure_ad_token_provider
@@ -112,7 +112,7 @@ class OpenAILLMImpl(BaseOpenAILLM):
         self.organization = organization
         self.max_retries = max_retries
         self.request_timeout = request_timeout
-        self.reporter = reporter or ConsoleReporter()
+        self.logger = logger or ConsoleReporter()

         try:
             # Create OpenAI sync and async clients

@@ -15,7 +15,7 @@ from tenacity import (
     wait_exponential_jitter,
 )

-from graphrag.logging.base import StatusLogger
+from graphrag.logger.base import StatusLogger
 from graphrag.query.llm.base import BaseLLM, BaseLLMCallback
 from graphrag.query.llm.oai.base import OpenAILLMImpl
 from graphrag.query.llm.oai.typing import (
@@ -42,7 +42,7 @@ class ChatOpenAI(BaseLLM, OpenAILLMImpl):
         max_retries: int = 10,
         request_timeout: float = 180.0,
         retry_error_types: tuple[type[BaseException]] = OPENAI_RETRY_ERROR_TYPES,  # type: ignore
-        reporter: StatusLogger | None = None,
+        logger: StatusLogger | None = None,
     ):
         OpenAILLMImpl.__init__(
             self=self,
@@ -55,7 +55,7 @@ class ChatOpenAI(BaseLLM, OpenAILLMImpl):
             organization=organization,
             max_retries=max_retries,
             request_timeout=request_timeout,
-            reporter=reporter,
+            logger=logger,
         )
         self.model = model
         self.retry_error_types = retry_error_types
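
Callers that previously passed `reporter=` now pass `logger=`; a minimal construction sketch, where the key and model values are placeholders and the module path is assumed from the imports above:

```python
from graphrag.query.llm.oai.chat_openai import ChatOpenAI  # module path assumed

llm = ChatOpenAI(
    api_key="<placeholder>",
    model="gpt-4-turbo-preview",  # placeholder model name
    logger=None,  # falls back to ConsoleReporter() inside OpenAILLMImpl
)
```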
@@ -18,7 +18,7 @@ from tenacity import (
     wait_exponential_jitter,
 )

-from graphrag.logging.base import StatusLogger
+from graphrag.logger.base import StatusLogger
 from graphrag.query.llm.base import BaseTextEmbedding
 from graphrag.query.llm.oai.base import OpenAILLMImpl
 from graphrag.query.llm.oai.typing import (
@@ -46,7 +46,7 @@ class OpenAIEmbedding(BaseTextEmbedding, OpenAILLMImpl):
         max_retries: int = 10,
         request_timeout: float = 180.0,
         retry_error_types: tuple[type[BaseException]] = OPENAI_RETRY_ERROR_TYPES,  # type: ignore
-        reporter: StatusLogger | None = None,
+        logger: StatusLogger | None = None,
     ):
         OpenAILLMImpl.__init__(
             self=self,
@@ -59,7 +59,7 @@ class OpenAIEmbedding(BaseTextEmbedding, OpenAILLMImpl):
             organization=organization,
             max_retries=max_retries,
             request_timeout=request_timeout,
-            reporter=reporter,
+            logger=logger,
         )

         self.model = model

@@ -13,7 +13,7 @@ from azure.identity import DefaultAzureCredential
 from azure.storage.blob import BlobServiceClient
 from datashaper import Progress

-from graphrag.logging.base import ProgressReporter
+from graphrag.logger.base import ProgressLogger
 from graphrag.storage.pipeline_storage import PipelineStorage

 log = logging.getLogger(__name__)
@@ -32,7 +32,7 @@ class BlobPipelineStorage(PipelineStorage):
         self,
         connection_string: str | None,
         container_name: str,
-        encoding: str | None = None,
+        encoding: str = "utf-8",
         path_prefix: str | None = None,
         storage_account_blob_url: str | None = None,
     ):
@@ -50,7 +50,7 @@ class BlobPipelineStorage(PipelineStorage):
             account_url=storage_account_blob_url,
             credential=DefaultAzureCredential(),
         )
-        self._encoding = encoding or "utf-8"
+        self._encoding = encoding
         self._container_name = container_name
         self._connection_string = connection_string
         self._path_prefix = path_prefix or ""
@@ -95,7 +95,7 @@ class BlobPipelineStorage(PipelineStorage):
         self,
         file_pattern: re.Pattern[str],
         base_dir: str | None = None,
-        progress: ProgressReporter | None = None,
+        progress: ProgressLogger | None = None,
         file_filter: dict[str, Any] | None = None,
         max_count=-1,
     ) -> Iterator[tuple[str, dict[str, Any]]]:
@@ -179,7 +179,7 @@ class BlobPipelineStorage(PipelineStorage):
             blob_client = container_client.get_blob_client(key)
             blob_data = blob_client.download_blob().readall()
             if not as_bytes:
-                coding = encoding or "utf-8"
+                coding = encoding or self._encoding
                 blob_data = blob_data.decode(coding)
         except Exception:
             log.exception("Error getting key %s", key)
@@ -198,7 +198,7 @@ class BlobPipelineStorage(PipelineStorage):
             if isinstance(value, bytes):
                 blob_client.upload_blob(value, overwrite=True)
             else:
-                coding = encoding or "utf-8"
+                coding = encoding or self._encoding
                 blob_client.upload_blob(value.encode(coding), overwrite=True)
         except Exception:
             log.exception("Error setting key %s: %s", key)
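
With the new `"utf-8"` default, callers no longer need to pass an explicit encoding; a construction sketch, where the container name and account URL are placeholders:

```python
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage

# no connection string, so the DefaultAzureCredential path shown above is taken
storage = BlobPipelineStorage(
    connection_string=None,
    container_name="my-container",  # placeholder
    storage_account_blob_url="https://myaccount.blob.core.windows.net",  # placeholder
)
```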
@@ -5,7 +5,7 @@

 from __future__ import annotations

-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, ClassVar

 from graphrag.config.enums import StorageType
 from graphrag.storage.blob_pipeline_storage import create_blob_storage
@@ -13,29 +13,36 @@ from graphrag.storage.file_pipeline_storage import create_file_storage
 from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage

 if TYPE_CHECKING:
-    from graphrag.index.config.storage import (
-        PipelineBlobStorageConfig,
-        PipelineFileStorageConfig,
-        PipelineStorageConfig,
-    )
     from graphrag.storage.pipeline_storage import PipelineStorage


-def create_storage(config: PipelineStorageConfig):
-    """Create a storage object based on the config."""
-    match config.type:
-        case StorageType.memory:
-            return MemoryPipelineStorage()
-        case StorageType.blob:
-            config = cast("PipelineBlobStorageConfig", config)
-            return create_blob_storage(
-                config.connection_string,
-                config.storage_account_blob_url,
-                config.container_name,
-                config.base_dir,
-            )
-        case StorageType.file:
-            config = cast("PipelineFileStorageConfig", config)
-            return create_file_storage(config.base_dir)
-        case _:
-            msg = f"Unknown storage type: {config.type}"
-            raise ValueError(msg)
+class StorageFactory:
+    """A factory class for storage implementations.
+
+    Includes a method for users to register a custom storage implementation.
+    """
+
+    storage_types: ClassVar[dict[str, type]] = {}
+
+    @classmethod
+    def register(cls, storage_type: str, storage: type):
+        """Register a custom storage implementation."""
+        cls.storage_types[storage_type] = storage
+
+    @classmethod
+    def create_storage(
+        cls, storage_type: StorageType | str, kwargs: dict
+    ) -> PipelineStorage:
+        """Create or get a storage object from the provided type."""
+        match storage_type:
+            case StorageType.blob:
+                return create_blob_storage(**kwargs)
+            case StorageType.file:
+                return create_file_storage(**kwargs)
+            case StorageType.memory:
+                return MemoryPipelineStorage()
+            case _:
+                if storage_type in cls.storage_types:
+                    return cls.storage_types[storage_type](**kwargs)
+                msg = f"Unknown storage type: {storage_type}"
+                raise ValueError(msg)

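Factory usage might look like this sketch; the `base_dir` value and the `"s3"` example are hypothetical, and `create_file_storage` expects `base_dir` in its kwargs, as shown in `file_pipeline_storage.py` below:

```python
from graphrag.config.enums import StorageType
from graphrag.storage.factory import StorageFactory

# built-in type: kwargs are forwarded to create_file_storage(**kwargs)
storage = StorageFactory.create_storage(StorageType.file, {"base_dir": "./output"})


class S3PipelineStorage:
    """Hypothetical custom backend; a real one would subclass PipelineStorage."""

    def __init__(self, bucket: str):
        self.bucket = bucket


# custom type: register the class, then create it by its registered name
StorageFactory.register("s3", S3PipelineStorage)
s3 = StorageFactory.create_storage("s3", {"bucket": "my-bucket"})
```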
@@ -16,7 +16,7 @@ from aiofiles.os import remove
 from aiofiles.ospath import exists
 from datashaper import Progress

-from graphrag.logging.base import ProgressReporter
+from graphrag.logger.base import ProgressLogger
 from graphrag.storage.pipeline_storage import PipelineStorage

 log = logging.getLogger(__name__)
@@ -28,17 +28,17 @@ class FilePipelineStorage(PipelineStorage):
     _root_dir: str
     _encoding: str

-    def __init__(self, root_dir: str | None = None, encoding: str | None = None):
+    def __init__(self, root_dir: str = "", encoding: str = "utf-8"):
         """Init method definition."""
-        self._root_dir = root_dir or ""
-        self._encoding = encoding or "utf-8"
+        self._root_dir = root_dir
+        self._encoding = encoding
         Path(self._root_dir).mkdir(parents=True, exist_ok=True)

     def find(
         self,
         file_pattern: re.Pattern[str],
         base_dir: str | None = None,
-        progress: ProgressReporter | None = None,
+        progress: ProgressLogger | None = None,
         file_filter: dict[str, Any] | None = None,
         max_count=-1,
     ) -> Iterator[tuple[str, dict[str, Any]]]:
@@ -152,10 +152,11 @@ def join_path(file_path: str, file_name: str) -> Path:
     return Path(file_path) / Path(file_name).parent / Path(file_name).name


-def create_file_storage(out_dir: str | None) -> PipelineStorage:
+def create_file_storage(**kwargs: Any) -> PipelineStorage:
     """Create a file based storage."""
-    log.info("Creating file storage at %s", out_dir)
-    return FilePipelineStorage(out_dir)
+    base_dir = kwargs["base_dir"]
+    log.info("Creating file storage at %s", base_dir)
+    return FilePipelineStorage(root_dir=base_dir)


 def _create_progress_status(

@@ -8,7 +8,7 @@ from abc import ABCMeta, abstractmethod
 from collections.abc import Iterator
 from typing import Any

-from graphrag.logging.base import ProgressReporter
+from graphrag.logger.base import ProgressLogger


 class PipelineStorage(metaclass=ABCMeta):
@@ -19,7 +19,7 @@ class PipelineStorage(metaclass=ABCMeta):
         self,
         file_pattern: re.Pattern[str],
         base_dir: str | None = None,
-        progress: ProgressReporter | None = None,
+        progress: ProgressLogger | None = None,
         file_filter: dict[str, Any] | None = None,
         max_count=-1,
     ) -> Iterator[tuple[str, dict[str, Any]]]:

@@ -19,7 +19,7 @@ async def load_table_from_storage(name: str, storage: PipelineStorage) -> pd.Dat
         msg = f"Could not find {name} in storage!"
         raise ValueError(msg)
     try:
-        log.info("read table from storage: %s", name)
+        log.info("reading table from storage: %s", name)
         return pd.read_parquet(BytesIO(await storage.get(name, as_bytes=True)))
     except Exception:
         log.exception("error loading table from storage: %s", name)
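
`load_table_from_storage` composes naturally with `StorageFactory`; a sketch, where the `graphrag.utils.storage` module path and the parquet name are assumptions for illustration:

```python
import asyncio

from graphrag.config.enums import StorageType
from graphrag.storage.factory import StorageFactory
from graphrag.utils.storage import load_table_from_storage  # module path assumed

storage = StorageFactory.create_storage(StorageType.file, {"base_dir": "./output"})
df = asyncio.run(load_table_from_storage("create_final_entities.parquet", storage))
```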
@@ -7,6 +7,7 @@ from enum import Enum
 from typing import ClassVar

 from graphrag.vector_stores.azure_ai_search import AzureAISearch
+from graphrag.vector_stores.base import BaseVectorStore
 from graphrag.vector_stores.lancedb import LanceDBVectorStore


@@ -24,14 +25,14 @@ class VectorStoreFactory:

     @classmethod
     def register(cls, vector_store_type: str, vector_store: type):
-        """Register a vector store type."""
+        """Register a custom vector store implementation."""
         cls.vector_store_types[vector_store_type] = vector_store

     @classmethod
-    def get_vector_store(
+    def create_vector_store(
         cls, vector_store_type: VectorStoreType | str, kwargs: dict
-    ) -> LanceDBVectorStore | AzureAISearch:
-        """Get the vector store type from a string."""
+    ) -> BaseVectorStore:
+        """Create or get a vector store from the provided type."""
         match vector_store_type:
             case VectorStoreType.LanceDB:
                 return LanceDBVectorStore(**kwargs)
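
Call sites move from `get_vector_store` to `create_vector_store`, and the return type widens to `BaseVectorStore`; a usage sketch, where the `collection_name` kwarg is an assumed constructor parameter shown for illustration only:

```python
from graphrag.vector_stores.factory import VectorStoreFactory, VectorStoreType

# kwargs are passed straight through to the selected store's constructor
store = VectorStoreFactory.create_vector_store(
    VectorStoreType.LanceDB, {"collection_name": "entities"}
)
```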
poetry.lock (generated, 1576 lines; diff suppressed because it is too large)
@@ -132,7 +132,6 @@ coverage_report = 'coverage report --omit "**/tests/**" --show-missing'
 check_format = 'ruff format . --check'
 fix = "ruff check --fix ."
 fix_unsafe = "ruff check --fix --unsafe-fixes ."

 _test_all = "coverage run -m pytest ./tests"
 test_unit = "pytest ./tests/unit"
 test_integration = "pytest ./tests/integration"
@@ -146,7 +145,6 @@ query = "python -m graphrag query"
 prompt_tune = "python -m graphrag prompt-tune"
 # Pass in a test pattern
 test_only = "pytest -s -k"

 serve_docs = "mkdocs serve"
 build_docs = "mkdocs build"

@@ -137,7 +137,7 @@ class TestIndexer:
             "--verbose" if debug else None,
             "--root",
             root.resolve().as_posix(),
-            "--reporter",
+            "--logger",
             "print",
         ]
         command = [arg for arg in command if arg]